import torch
import torch.nn as nn


class WD_GCN(nn.Module):
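    """Dynamic-graph model: a graph convolution applied to each time step's snapshot,
    a node-wise LSTM over the resulting sequence of embeddings, and a linear scoring
    head over concatenated (source, target) embeddings for edge classification /
    link prediction.

    Expected inputs (inferred from how they are used in this class):
        A     : list of T sparse (N, N) adjacency matrices
        X     : (T, N, F) node-feature tensor
        edges : (3, E) tensor whose rows are (time step, source node, target node)
    """
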
    def __init__(self, A, X, edges, hidden_feat=[2, 2]):
        super(WD_GCN, self).__init__()
        self.device = 1  # GPU index; not used below, tensors are moved with .cuda() directly
        self.A = A
        self.X = X
        self.T, self.N = X.shape[0], X.shape[1]  # T = number of time steps, N = number of nodes
        # self.v flattens (time step, node) index pairs into row indices of the
        # (T*N, hidden) embedding matrix: index = time * N + node
        self.v = torch.cuda.FloatTensor([self.N, 1])
        self.edge_src_nodes = torch.matmul(edges[[0, 1]].transpose(1, 0).float(), self.v).cuda()
        self.edge_trg_nodes = torch.matmul(edges[[0, 2]].transpose(1, 0).float(), self.v).cuda()
        self.tanh = torch.nn.Tanh()
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU(inplace=False)
        self.AX = self.compute_AX(A, X)

        # GCN parameters
        self.W = nn.Parameter(torch.randn(X.shape[-1], hidden_feat[0]).cuda())

        # Edge classification/link prediction parameters
        # (wrapped in nn.Parameter so the scoring matrix is registered and trained with
        # the rest of the model; the original created a plain tensor here)
        self.U = nn.Parameter(torch.randn(2*hidden_feat[0], hidden_feat[1]).cuda())

        # LSTM parameters
        self.Wf = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Wj = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Wc = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Wo = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uf = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uj = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uc = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.Uo = nn.Parameter(torch.randn(hidden_feat[0], hidden_feat[0]).cuda())
        self.bf = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.bj = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.bc = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.bo = nn.Parameter(torch.randn(hidden_feat[0]).cuda())
        self.h_init = torch.randn(hidden_feat[0]).cuda()
        self.c_init = torch.randn(hidden_feat[0]).cuda()

    def __call__(self, A=None, X=None, edges=None):
        # NOTE: defining __call__ here overrides nn.Module.__call__, so forward hooks
        # registered on the module will not run; nn.Module would already dispatch
        # model(...) to forward().
        return self.forward(A, X, edges)
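
    # forward() either recomputes A_t @ X_t and the flattened edge indices from its
    # arguments (when A is passed as a list of per-time-step sparse adjacency
    # matrices), or falls back to the tensors cached in __init__.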
    def forward(self, A=None, X=None, edges=None):
        if isinstance(A, list):
            AX = self.compute_AX(A, X)
            edge_src_nodes = torch.matmul(edges[[0, 1]].transpose(1, 0).float(), self.v)
            edge_trg_nodes = torch.matmul(edges[[0, 2]].transpose(1, 0).float(), self.v)
        else:
            AX = self.AX
            edge_src_nodes = self.edge_src_nodes
            edge_trg_nodes = self.edge_trg_nodes

        # Graph convolution applied to every time step: Y has shape (T, N, hidden_feat[0])
        Y = self.relu(torch.matmul(AX.cuda(), self.W.cuda()))
        # Node-wise LSTM over the T time steps; Z has the same shape as Y
        Z = self.LSTM(Y)
        # Look up the (time, node) embedding of each edge's endpoints and score the edge
        Z_mat_edge_src_nodes = Z.reshape(-1, Z.shape[-1])[edge_src_nodes.long()]
        Z_mat_edge_trg_nodes = Z.reshape(-1, Z.shape[-1])[edge_trg_nodes.long()]
        Z_mat = torch.cat((Z_mat_edge_src_nodes, Z_mat_edge_trg_nodes), dim=1).cuda()
        output = torch.matmul(Z_mat, self.U)  # prediction head: hidden_feat[1] scores per edge

        return output

    def compute_AX(self, A, X):
        # Multiply each time step's sparse adjacency matrix with that step's node features
        AX = torch.zeros(self.T, self.N, X.shape[-1]).cuda()
        for k in range(len(A)):
            AX[k] = torch.sparse.mm(A[k].cuda(), X[k])

        return AX

    def LSTM(self, Y):
        # Standard LSTM recurrence run node-wise over the T time steps, with the same
        # cell weights shared by all N nodes (f: forget, j: input, o: output gate)
        c = self.c_init.repeat(self.N, 1)
        h = self.h_init.repeat(self.N, 1)
        Z = torch.zeros(Y.shape, device=Y.device)  # keep Z on the same device as Y and h
        for time in range(Y.shape[0]):
            f = self.sigmoid(torch.matmul(Y[time], self.Wf) + torch.matmul(h, self.Uf) + self.bf.repeat(self.N, 1))
            j = self.sigmoid(torch.matmul(Y[time], self.Wj) + torch.matmul(h, self.Uj) + self.bj.repeat(self.N, 1))
            o = self.sigmoid(torch.matmul(Y[time], self.Wo) + torch.matmul(h, self.Uo) + self.bo.repeat(self.N, 1))
            # candidate cell state: tanh, not sigmoid, in the standard LSTM formulation
            ct = self.tanh(torch.matmul(Y[time], self.Wc) + torch.matmul(h, self.Uc) + self.bc.repeat(self.N, 1))
            c = j * ct + f * c
            h = o * self.tanh(c)
            Z[time] = h

        return Z
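

# ---------------------------------------------------------------------------
# Minimal usage sketch (editor's addition, not part of the original module).
# It assumes a CUDA device is available and uses tiny, randomly generated
# placeholder data just to illustrate the expected tensor layout:
#   A     - list of T sparse (N, N) adjacency matrices
#   X     - (T, N, F) node-feature tensor
#   edges - (3, E) integer tensor with rows (time step, source node, target node)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    T, N, F, E = 4, 5, 3, 6
    A = [torch.eye(N).to_sparse() for _ in range(T)]
    X = torch.randn(T, N, F).cuda()
    edges = torch.stack([
        torch.randint(0, T, (E,)),   # time step of each edge
        torch.randint(0, N, (E,)),   # source node
        torch.randint(0, N, (E,)),   # target node
    ]).cuda()

    model = WD_GCN(A, X, edges)
    scores = model()                 # uses the A, X, edges cached at construction
    print(scores.shape)              # (E, hidden_feat[1]) -> torch.Size([6, 2])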