| @@ -95,7 +95,8 @@ def remove_random_node(graph, max_size=40, min_size=10): | |||
| """" | |||
| Layers: | |||
| GCN | |||
| Graph Convolution | |||
| Graph Multihead Attention | |||
| """ | |||
| class GraphConv(nn.Module): | |||
| @@ -108,11 +109,42 @@ class GraphConv(nn.Module): | |||
| def forward(self, x, adj): | |||
| ''' | |||
| x is hamun feature matrix | |||
| x is the feature matrix constructed in feature_matrix function | |||
| adj ham ke is adjacency matrix of the graph | |||
| ''' | |||
| y = torch.matmul(adj, x) | |||
| print(y.shape) | |||
| print(self.weight.shape) | |||
| # print(y.shape) | |||
| # print(self.weight.shape) | |||
| y = torch.matmul(y, self.weight.double()) | |||
| return y | |||
| class GraphAttn(nn.Module): | |||
| def __init__(self, heads, model_dim, dropout=0.1): | |||
| super().__init__() | |||
| self.model_dim = model_dim | |||
| self.key_dim = model_dim // heads | |||
| self.heads = heads | |||
| self.q_linear = nn.Linear(model_dim, model_dim).cuda() | |||
| self.v_linear = nn.Linear(model_dim, model_dim).cuda() | |||
| self.k_linear = nn.Linear(model_dim, model_dim).cuda() | |||
| self.dropout = nn.Dropout(dropout) | |||
| self.out = nn.Linear(model_dim, model_dim).cuda() | |||
| def forward(self, query, key, value): | |||
| # print(q, k, v) | |||
| bs = query.size(0) # size of the graph | |||
| key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim) | |||
| query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim) | |||
| value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim) | |||
| key = key.transpose(1,2) | |||
| query = query.transpose(1,2) | |||
| value = value.transpose(1,2) | |||
| scores = attention(query, key, value, self.key_dim) | |||
| concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim) | |||
| output = self.out(concat) | |||
| return output | |||