You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 3.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from layers import GraphConvolution, CrossLayer
  5. from utils import between_index
  6. class GCN(nn.Module):
  7. def __init__(self, nfeat, nhid, nclass, dropout):
  8. super(GCN, self).__init__()
  9. self.nclass = nclass
  10. layers = []
  11. layers.append(GraphConvolution(nfeat, nhid[0]))
  12. for i in range(len(nhid)-1):
  13. layers.append(GraphConvolution(nhid[i], nhid[i+1]))
  14. if nclass > 1:
  15. layers.append(GraphConvolution(nhid[-1], nclass))
  16. self.gc = nn.ModuleList(layers)
  17. self.dropout = dropout
  18. def forward(self, x, adj):
  19. end_layer = len(self.gc)-1 if self.nclass > 1 else len(self.gc)
  20. for i in range(end_layer):
  21. x = F.relu(self.gc[i](x, adj))
  22. x = F.dropout(x, self.dropout, training=self.training)
  23. if self.nclass > 1:
  24. classifier = self.gc[-1](x, adj)
  25. return F.log_softmax(classifier, dim=1), x
  26. else:
  27. return None, x
  28. class CCN(nn.Module):
  29. def __init__(self, n_inputs, inputs_nfeat, inputs_nhid, inputs_nclass, dropout):
  30. super(CCN, self).__init__()
  31. # Features of network
  32. self.n_inputs = n_inputs
  33. self.inputs_nfeat = inputs_nfeat
  34. self.inputs_nhid = inputs_nhid
  35. self.inputs_nclass = inputs_nclass
  36. self.dropout = dropout
  37. # Every single layer
  38. temp_in_layer = []
  39. temp_bet_layer = []
  40. for i in range(n_inputs):
  41. temp_in_layer.append(GCN(inputs_nfeat[i], inputs_nhid[i], inputs_nclass[i], dropout))
  42. self.in_layer = nn.ModuleList(temp_in_layer)
  43. # Between layers
  44. for i in range(n_inputs):
  45. for j in range(i+1,n_inputs):
  46. temp_bet_layer.append(CrossLayer(inputs_nhid[i][-1], inputs_nhid[j][-1], bias=True, bet_weight=False))
  47. self.bet_layer = nn.ModuleList(temp_bet_layer)
  48. def forward(self, xs, adjs):
  49. classes_labels = []
  50. bet_layers_output = []
  51. in_layers_output = []
  52. features = []
  53. # Single layer forward to find features and classes
  54. for i in range(self.n_inputs):
  55. class_labels, feature = self.in_layer[i](xs[i], adjs[i])
  56. classes_labels.append(class_labels)
  57. features.append(feature)
  58. temp = torch.mm(feature, torch.t(feature))
  59. in_layers_output.append(temp)
  60. # Find between layers with CCN
  61. for i in range(self.n_inputs):
  62. for j in range(i+1,self.n_inputs):
  63. temp = self.bet_layer[between_index(self.n_inputs, i, j)](features[i], features[j])
  64. bet_layers_output.append(temp)
  65. return classes_labels, bet_layers_output, in_layers_output
  66. class AggCCN(nn.Module):
  67. def __init__(self, n_inputs, inputs_nfeat, inputs_nhid, inputs_nclass, dropout):
  68. super(AggCCN, self).__init__()
  69. # Features of network
  70. self.n_inputs = n_inputs
  71. self.inputs_nfeat = inputs_nfeat
  72. self.inputs_nhid = inputs_nhid
  73. self.inputs_nclass = inputs_nclass
  74. self.dropout = dropout
  75. # Every single layer
  76. temp_in_layer = []
  77. temp_bet_layer = []
  78. for i in range(n_inputs):
  79. temp_in_layer.append(GCN(inputs_nfeat[i], inputs_nhid[i], inputs_nclass[i], dropout))
  80. self.in_layer = nn.ModuleList(temp_in_layer)
  81. def forward(self, xs, adjs):
  82. classes_labels = []
  83. bet_layers_output = []
  84. in_layers_output = []
  85. features = []
  86. # Single layer forward to find features and classes
  87. for i in range(self.n_inputs):
  88. class_labels, feature = self.in_layer[i](xs[i], adjs[i])
  89. classes_labels.append(class_labels)
  90. features.append(feature)
  91. temp = torch.mm(feature, torch.t(feature))
  92. in_layers_output.append(temp)
  93. return classes_labels, in_layers_output