
layers.py

import math

import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform initialization scaled by the weight's output dimension
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        # support = X W, then output = adj (X W); torch.spmm expects its
        # first operand to be sparse, so input and adj should be sparse here
        support = torch.spmm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
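For context, here is a minimal usage sketch of GraphConvolution (not part of layers.py). The forward pass computes adj · (input · weight) + bias, so the sketch builds a small random feature matrix and a symmetrically normalized adjacency, converts both to sparse tensors (torch.spmm requires a sparse first operand), and stacks two layers. All names and sizes below (X, A_hat, n_nodes, etc.) are illustrative assumptions, not part of this repository.

# Illustrative sketch, assuming dense random inputs converted to sparse tensors.
import torch
import torch.nn.functional as F

n_nodes, n_feats, n_hidden, n_classes = 5, 8, 16, 3
X = torch.rand(n_nodes, n_feats)                      # hypothetical node features

# Hypothetical symmetric normalization: A_hat = D^-1/2 (A + I) D^-1/2
A = (torch.rand(n_nodes, n_nodes) > 0.7).float()
A = ((A + A.t() + torch.eye(n_nodes)) > 0).float()    # symmetric, with self-loops
deg_inv_sqrt = A.sum(1).pow(-0.5)
A_hat = deg_inv_sqrt.unsqueeze(1) * A * deg_inv_sqrt.unsqueeze(0)

gc1 = GraphConvolution(n_feats, n_hidden)
gc2 = GraphConvolution(n_hidden, n_classes)

# forward uses torch.spmm, so the first operands are passed as sparse COO tensors
h = F.relu(gc1(X.to_sparse(), A_hat.to_sparse()))
out = gc2(h.to_sparse(), A_hat.to_sparse())
print(out.shape)    # torch.Size([5, 3])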
class CrossLayer(Module):
    """
    Cross-layer module for a multilayer network: scores node pairs drawn from
    two different layers, optionally through a learnable bilinear weight.
    """

    def __init__(self, L1_dim, L2_dim, bias=True, bet_weight=True):
        super(CrossLayer, self).__init__()
        self.L1_dim = L1_dim
        self.L2_dim = L2_dim
        self.bet_weight = bet_weight
        self.weight = Parameter(torch.FloatTensor(L1_dim, L2_dim))
        if bias:
            self.bias = Parameter(torch.FloatTensor(L2_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, L1_features, L2_features):
        if self.bet_weight:
            # Bilinear score: L1_features @ W @ L2_features^T
            temp = torch.mm(L1_features, self.weight)
            output = torch.mm(temp, torch.t(L2_features))
            if self.bias is not None:
                # Note: bias has shape (L2_dim,), so broadcasting over the
                # (n1, n2) score matrix only works when n2 == L2_dim
                output = output + self.bias
        else:
            # Plain dot-product score: L1_features @ L2_features^T
            output = torch.mm(L1_features, torch.t(L2_features))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.L1_dim) + ' -> ' \
               + str(self.L2_dim) + ')'
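And a hedged sketch of CrossLayer in use (also not part of layers.py): given per-layer node embeddings Z1 and Z2, it returns an n1 × n2 cross-layer score matrix, bilinear when bet_weight=True and a plain inner product otherwise. Z1, Z2 and the dimensions below are illustrative placeholders; bias is left off because the (L2_dim,)-shaped bias only broadcasts when it matches the number of layer-2 nodes.

# Illustrative sketch with hypothetical per-layer embeddings Z1 and Z2.
import torch

n1, n2 = 6, 4        # number of nodes in layer 1 and layer 2
d1, d2 = 16, 16      # embedding dimensions of the two layers

Z1 = torch.rand(n1, d1)
Z2 = torch.rand(n2, d2)

cross = CrossLayer(d1, d2, bias=False, bet_weight=True)
scores = cross(Z1, Z2)            # bilinear scores: Z1 @ W @ Z2^T
print(scores.shape)               # torch.Size([6, 4])

# With bet_weight=False the score is a plain inner product Z1 @ Z2^T,
# which additionally requires d1 == d2.
dot_cross = CrossLayer(d1, d2, bias=False, bet_weight=False)
print(dot_cross(Z1, Z2).shape)    # torch.Size([6, 4])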