Browse Source

cleaned multihead attention added

master
Yassaman Ommi 3 years ago
parent
commit
aafa1ed8fb
1 changed files with 36 additions and 4 deletions
  1. 36
    4
      GraphTransformer.py

+ 36
- 4
GraphTransformer.py View File

@@ -95,7 +95,8 @@ def remove_random_node(graph, max_size=40, min_size=10):

""""
Layers:
GCN
Graph Convolution
Graph Multihead Attention
"""

class GraphConv(nn.Module):
@@ -108,11 +109,42 @@ class GraphConv(nn.Module):

def forward(self, x, adj):
'''
x is hamun feature matrix
x is the feature matrix constructed in feature_matrix function
adj ham ke is adjacency matrix of the graph
'''
y = torch.matmul(adj, x)
print(y.shape)
print(self.weight.shape)
# print(y.shape)
# print(self.weight.shape)
y = torch.matmul(y, self.weight.double())
return y

class GraphAttn(nn.Module):
def __init__(self, heads, model_dim, dropout=0.1):
super().__init__()
self.model_dim = model_dim
self.key_dim = model_dim // heads
self.heads = heads

self.q_linear = nn.Linear(model_dim, model_dim).cuda()
self.v_linear = nn.Linear(model_dim, model_dim).cuda()
self.k_linear = nn.Linear(model_dim, model_dim).cuda()

self.dropout = nn.Dropout(dropout)
self.out = nn.Linear(model_dim, model_dim).cuda()

def forward(self, query, key, value):
# print(q, k, v)
bs = query.size(0) # size of the graph
key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim)
query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim)
value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim)

key = key.transpose(1,2)
query = query.transpose(1,2)
value = value.transpose(1,2)

scores = attention(query, key, value, self.key_dim)
concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)
output = self.out(concat)

return output

Loading…
Cancel
Save