
Code cleaned up; multihead attention added

master · Yassaman Ommi · 4 years ago
parent commit aafa1ed8fb
1 changed file with 36 additions and 4 deletions

GraphTransformer.py  (+36 −4)


"""
Layers:
    Graph Convolution
    Graph Multihead Attention
"""

import torch
import torch.nn as nn


class GraphConv(nn.Module):


    def forward(self, x, adj):
        '''
        x is the feature matrix constructed in the feature_matrix function
        adj is the adjacency matrix of the graph
        '''
        y = torch.matmul(adj, x)
        # print(y.shape)
        # print(self.weight.shape)
        y = torch.matmul(y, self.weight.double())
        return y
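Note: the hunk shows GraphConv.forward but not its constructor, so only the existence of self.weight can be inferred. Below is a minimal sketch of a constructor consistent with that forward pass; the in_features and out_features names are assumptions, not part of this commit.

import torch
import torch.nn as nn

class GraphConv(nn.Module):
    # Hypothetical constructor; only self.weight is implied by forward() above.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))

    def forward(self, x, adj):
        y = torch.matmul(adj, x)                   # aggregate neighbour features: (N, N) x (N, in_features)
        y = torch.matmul(y, self.weight.double())  # project: (N, in_features) x (in_features, out_features)
        return y

Under this sketch, an (N, N) adjacency matrix and an (N, in_features) feature matrix produce an (N, out_features) output.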

class GraphAttn(nn.Module):
    def __init__(self, heads, model_dim, dropout=0.1):
        super().__init__()
        self.model_dim = model_dim
        self.key_dim = model_dim // heads   # per-head dimension
        self.heads = heads

        self.q_linear = nn.Linear(model_dim, model_dim).cuda()
        self.v_linear = nn.Linear(model_dim, model_dim).cuda()
        self.k_linear = nn.Linear(model_dim, model_dim).cuda()

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(model_dim, model_dim).cuda()

    def forward(self, query, key, value):
        # print(q, k, v)
        bs = query.size(0)  # size of the graph

        # project and split into heads: (bs, -1, heads, key_dim)
        key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim)
        query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim)
        value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim)

        # move the head dimension forward: (bs, heads, -1, key_dim)
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)

        # per-head attention, then concatenate heads and project back
        scores = attention(query, key, value, self.key_dim)
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.model_dim)
        output = self.out(concat)

        return output
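GraphAttn.forward calls an attention helper that is not part of this hunk. A minimal sketch, assuming standard scaled dot-product attention consistent with the attention(query, key, value, self.key_dim) call site above; the scaling factor and softmax axis are assumptions, the real helper lives elsewhere in GraphTransformer.py.

import math
import torch
import torch.nn.functional as F

def attention(query, key, value, key_dim):
    # query, key, value: (bs, heads, n, key_dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key_dim)  # (bs, heads, n, n)
    weights = F.softmax(scores, dim=-1)   # attention weights over the last axis
    return torch.matmul(weights, value)   # (bs, heads, n, key_dim)

Under that assumption, GraphAttn maps a (bs, n, model_dim) input back to (bs, n, model_dim), which matches how concat and self.out are shaped above.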
