PyTorch implementation of Dynamic Graph Convolutional Network
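In brief, and as implemented below, wd_gc.py defines a WD_GCN module for edge classification/link prediction on a sequence of graph snapshots: each time step t is first passed through a graph convolution Y[t] = ReLU(A[t] X[t] W); the resulting node embeddings are then fed to a single LSTM cell that is shared across all nodes and unrolled over time to produce Z[t]; finally, every labelled edge (t, src, trg) is scored as [Z[t, src] || Z[t, trg]] U, where || denotes concatenation of the source- and target-node embeddings.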

wd_gc.py 3.9KB

import torch
import torch.nn as nn


class WD_GCN(nn.Module):
    def __init__(self, A, X, edges, hidden_feat=[2, 2]):
        super(WD_GCN, self).__init__()
        self.device = torch.device("cuda")  # the module assumes a CUDA device is available
        self.A = A
        self.X = X
        self.T, self.N = X.shape[0], X.shape[1]  # T = number of time steps, N = number of nodes
        # v maps a (time, node) pair to the row index time * N + node of the flattened embedding matrix
        self.v = torch.tensor([self.N, 1], dtype=torch.float, device=self.device)
        # edges is expected to have shape (3, E) with rows [time step, source node, target node]
        self.edge_src_nodes = torch.matmul(edges[[0, 1]].transpose(1, 0).float().to(self.device), self.v)
        self.edge_trg_nodes = torch.matmul(edges[[0, 2]].transpose(1, 0).float().to(self.device), self.v)
        self.tanh = torch.nn.Tanh()
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU(inplace=False)
        self.AX = self.compute_AX(A, X)
        # GCN parameters
        self.W = nn.Parameter(torch.randn(X.shape[-1], hidden_feat[0]).cuda())
        # Edge classification/link prediction parameters (wrapped in nn.Parameter so the head is trained)
        self.U = nn.Parameter(torch.randn(2 * hidden_feat[0], hidden_feat[1]).cuda())
        # LSTM parameters
        self.Wf = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Wj = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Wc = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Wo = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uf = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uj = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uc = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uo = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.bf = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.bj = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.bc = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.bo = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        # Initial LSTM hidden and cell states (fixed random vectors, not learned)
        self.h_init = torch.randn(hidden_feat[0]).cuda()
        self.c_init = torch.randn(hidden_feat[0]).cuda()

    def forward(self, A=None, X=None, edges=None):
        # If a new snapshot sequence is passed in, recompute AX and the edge indices;
        # otherwise reuse the tensors precomputed in __init__.
        if isinstance(A, list):
            AX = self.compute_AX(A, X)
            edge_src_nodes = torch.matmul(edges[[0, 1]].transpose(1, 0).float().to(self.device), self.v)
            edge_trg_nodes = torch.matmul(edges[[0, 2]].transpose(1, 0).float().to(self.device), self.v)
        else:
            AX = self.AX
            edge_src_nodes = self.edge_src_nodes
            edge_trg_nodes = self.edge_trg_nodes
        # Per-time-step graph convolution: Y[t] = ReLU(A[t] X[t] W)
        Y = self.relu(torch.matmul(AX.cuda(), self.W.cuda()))
        # Node-wise LSTM over the time dimension
        Z = self.LSTM(Y)
        # Gather the embeddings of the source and target node of each labelled edge
        Z_mat_edge_src_nodes = Z.reshape(-1, Z.shape[-1])[edge_src_nodes.long()]
        Z_mat_edge_trg_nodes = Z.reshape(-1, Z.shape[-1])[edge_trg_nodes.long()]
        Z_mat = torch.cat((Z_mat_edge_src_nodes, Z_mat_edge_trg_nodes), dim=1).cuda()
        output = torch.matmul(Z_mat, self.U)  # prediction head for edge classification/link prediction
        return output

    def compute_AX(self, A, X):
        # AX[k] = A[k] @ X[k], one sparse-dense matmul per time step
        AX = torch.zeros(X.shape[0], X.shape[1], X.shape[-1]).cuda()
        for k in range(len(A)):
            AX[k] = torch.sparse.mm(A[k].cuda(), X[k].cuda())
        return AX

    def LSTM(self, Y):
        # A single LSTM cell shared by all nodes, unrolled over the time dimension of Y (T, N, hidden)
        c = self.c_init.repeat(self.N, 1)
        h = self.h_init.repeat(self.N, 1)
        Z = torch.zeros_like(Y)  # allocated on the same device as Y to avoid a device mismatch
        for time in range(Y.shape[0]):
            f = self.sigmoid(torch.matmul(Y[time], self.Wf) + torch.matmul(h, self.Uf) + self.bf.repeat(self.N, 1))
            j = self.sigmoid(torch.matmul(Y[time], self.Wj) + torch.matmul(h, self.Uj) + self.bj.repeat(self.N, 1))
            o = self.sigmoid(torch.matmul(Y[time], self.Wo) + torch.matmul(h, self.Uo) + self.bo.repeat(self.N, 1))
            # Candidate cell state; a standard LSTM uses tanh here
            ct = self.tanh(torch.matmul(Y[time], self.Wc) + torch.matmul(h, self.Uc) + self.bc.repeat(self.N, 1))
            c = j * ct + f * c
            h = o * self.tanh(c)
            Z[time] = h
        return Z
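
For reference, a minimal usage sketch under stated assumptions: the snapshot count, node count, feature size, and the way the inputs are built (A as a list of sparse adjacency matrices, X as a dense (T, N, F) feature tensor, edges as a (3, E) tensor of [time, source, target] columns) are illustrative choices, not fixed by wd_gc.py, and a CUDA device is required.

import torch
from wd_gc import WD_GCN

T, N, F, E = 4, 10, 3, 20  # time steps, nodes, node features, labelled edges (toy sizes)

# One sparse adjacency matrix per time step; the identity is a placeholder for a real (normalized) adjacency
A = [torch.eye(N).to_sparse() for _ in range(T)]
# Dense node-feature tensor of shape (T, N, F)
X = torch.randn(T, N, F).cuda()
# Edges to score, one column per edge: [time step, source node, target node]
edges = torch.stack((torch.randint(0, T, (E,)),
                     torch.randint(0, N, (E,)),
                     torch.randint(0, N, (E,)))).cuda()

model = WD_GCN(A, X, edges, hidden_feat=[2, 2])
scores = model()        # reuses the tensors passed to __init__
print(scores.shape)     # torch.Size([E, 2]): one score vector per edge

Passing a new list of adjacency matrices together with matching X and edges to model(A, X, edges) recomputes AX and the edge indices instead of reusing the cached ones.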