You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

layers.py 1.1KB

4 years ago
12345678910111213141516171819202122232425262728293031323334
  1. import torch
  2. import torch.nn.functional as F
  3. from torch.nn.modules.module import Module
  4. from torch.nn.parameter import Parameter
  5. class GraphConvolution(Module):
  6. """
  7. Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
  8. """
  9. def __init__(self, in_features, out_features, dropout=0., act=F.relu):
  10. super(GraphConvolution, self).__init__()
  11. self.in_features = in_features
  12. self.out_features = out_features
  13. self.dropout = dropout
  14. self.act = act
  15. self.weight = Parameter(torch.FloatTensor(in_features, out_features))
  16. self.reset_parameters()
  17. def reset_parameters(self):
  18. torch.nn.init.xavier_uniform_(self.weight)
  19. def forward(self, input, adj):
  20. input = F.dropout(input, self.dropout, self.training)
  21. support = torch.mm(input, self.weight)
  22. output = torch.spmm(adj, support)
  23. output = self.act(output)
  24. return output
  25. def __repr__(self):
  26. return self.__class__.__name__ + ' (' \
  27. + str(self.in_features) + ' -> ' \
  28. + str(self.out_features) + ')'