You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

layer.py 1.7KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import math
  2. import torch
  3. from torch.nn.parameter import Parameter
  4. from torch.nn.modules.module import Module
  5. class GraphConvolution(Module):
  6. """
  7. Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
  8. """
  9. def __init__(self, in_features, out_features, order, bias=True):
  10. super(GraphConvolution, self).__init__()
  11. self.in_features = in_features
  12. self.out_features = out_features
  13. self.order = order
  14. self.weight = torch.nn.ParameterList([])
  15. for i in range(self.order):
  16. self.weight.append(Parameter(torch.FloatTensor(in_features, out_features)))
  17. if bias:
  18. self.bias = Parameter(torch.FloatTensor(out_features))
  19. else:
  20. self.register_parameter('bias', None)
  21. self.reset_parameters()
  22. def reset_parameters(self):
  23. for i in range(self.order):
  24. stdv = 1. / math.sqrt(self.weight[i].size(1))
  25. self.weight[i].data.uniform_(-stdv, stdv)
  26. if self.bias is not None:
  27. self.bias.data.uniform_(-stdv, stdv)
  28. def forward(self, input, adj):
  29. output = []
  30. if self.order == 1 and type(adj) != list:
  31. adj = [adj]
  32. for i in range(self.order):
  33. support = torch.mm(input, self.weight[i])
  34. # output.append(support)
  35. output.append(torch.mm(adj[i], support))
  36. output = sum(output)
  37. if self.bias is not None:
  38. return output + self.bias
  39. else:
  40. return output
  41. def __repr__(self):
  42. return self.__class__.__name__ + ' (' \
  43. + str(self.in_features) + ' -> ' \
  44. + str(self.out_features) + ')'