import torch import torch.nn.functional as F from torch.nn.modules.module import Module from torch.nn.parameter import Parameter class GraphConvolution(Module): """ Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 """ def __init__(self, in_features, out_features, dropout=0., act=F.relu): super(GraphConvolution, self).__init__() self.in_features = in_features self.out_features = out_features self.dropout = dropout self.act = act self.weight = Parameter(torch.FloatTensor(in_features, out_features)) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.weight) def forward(self, input, adj): input = F.dropout(input, self.dropout, self.training) support = torch.mm(input, self.weight) output = torch.spmm(adj, support) output = self.act(output) return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')'