|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""" |
|
|
"""" |
|
|
Layers: |
|
|
Layers: |
|
|
GCN |
|
|
|
|
|
|
|
|
Graph Convolution |
|
|
|
|
|
Graph Multihead Attention |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
class GraphConv(nn.Module): |
|
|
class GraphConv(nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, adj): |
|
|
def forward(self, x, adj): |
|
|
''' |
|
|
''' |
|
|
x is hamun feature matrix |
|
|
|
|
|
|
|
|
x is the feature matrix constructed in feature_matrix function |
|
|
adj ham ke is adjacency matrix of the graph |
|
|
adj ham ke is adjacency matrix of the graph |
|
|
''' |
|
|
''' |
|
|
y = torch.matmul(adj, x) |
|
|
y = torch.matmul(adj, x) |
|
|
print(y.shape) |
|
|
|
|
|
print(self.weight.shape) |
|
|
|
|
|
|
|
|
# print(y.shape) |
|
|
|
|
|
# print(self.weight.shape) |
|
|
y = torch.matmul(y, self.weight.double()) |
|
|
y = torch.matmul(y, self.weight.double()) |
|
|
return y |
|
|
return y |
|
|
|
|
|
|
|
|
|
|
|
class GraphAttn(nn.Module): |
|
|
|
|
|
def __init__(self, heads, model_dim, dropout=0.1): |
|
|
|
|
|
super().__init__() |
|
|
|
|
|
self.model_dim = model_dim |
|
|
|
|
|
self.key_dim = model_dim // heads |
|
|
|
|
|
self.heads = heads |
|
|
|
|
|
|
|
|
|
|
|
self.q_linear = nn.Linear(model_dim, model_dim).cuda() |
|
|
|
|
|
self.v_linear = nn.Linear(model_dim, model_dim).cuda() |
|
|
|
|
|
self.k_linear = nn.Linear(model_dim, model_dim).cuda() |
|
|
|
|
|
|
|
|
|
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
|
|
self.out = nn.Linear(model_dim, model_dim).cuda() |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, query, key, value): |
|
|
|
|
|
# print(q, k, v) |
|
|
|
|
|
bs = query.size(0) # size of the graph |
|
|
|
|
|
key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim) |
|
|
|
|
|
query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim) |
|
|
|
|
|
value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim) |
|
|
|
|
|
|
|
|
|
|
|
key = key.transpose(1,2) |
|
|
|
|
|
query = query.transpose(1,2) |
|
|
|
|
|
value = value.transpose(1,2) |
|
|
|
|
|
|
|
|
|
|
|
scores = attention(query, key, value, self.key_dim) |
|
|
|
|
|
concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim) |
|
|
|
|
|
output = self.out(concat) |
|
|
|
|
|
|
|
|
|
|
|
return output |