
model.py 7.2KB

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from code.layer import GraphConvolution


def my_sigmoid(mx, dim):
    """Apply an element-wise sigmoid, then L1-normalize along `dim` so the entries sum to 1."""
    mx = torch.sigmoid(mx)
    return F.normalize(mx, p=1, dim=dim)
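
# Note: with dim=0 the sigmoid scores are rescaled to sum to 1 over the node
# dimension, which is how Weighing.forward (below) turns per-node scores into a
# weight distribution; my_sigmoid is one possible choice for the `act` argument of RAGCN.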

class Discriminator(nn.Module):
    """GCN classifier: a stack of GraphConvolution layers ending in a log-softmax over `nclass` classes."""

    def __init__(self, nfeat, nhid, nclass, dropout, order):
        super(Discriminator, self).__init__()
        layers = []
        if len(nhid) == 0:
            layers.append(GraphConvolution(nfeat, nclass, order=order))
        else:
            layers.append(GraphConvolution(nfeat, nhid[0], order=order))
            for i in range(len(nhid) - 1):
                layers.append(GraphConvolution(nhid[i], nhid[i + 1], order=order))
            if nclass > 1:
                layers.append(GraphConvolution(nhid[-1], nclass, order=order))
        self.gc = nn.ModuleList(layers)
        self.dropout = dropout
        self.nclass = nclass

    def forward(self, x, adj, samples=-1, func=F.relu):
        # All layers except the final classification layer produce the node embedding.
        end_layer = len(self.gc) - 1 if self.nclass > 1 else len(self.gc)
        for i in range(end_layer):
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.gc[i](x, adj)
            x = func(x)
        if self.nclass > 1:
            classifier = self.gc[-1](x, adj)
            classifier = F.log_softmax(classifier, dim=1)
            return classifier[samples, :], x
        else:
            return None, x

class Weighing(nn.Module):
    """GCN that produces a weight for every input node (one scalar per node when weighing_output_dim == 1)."""

    def __init__(self, nfeat, nhid, weighing_output_dim, dropout, order):
        super(Weighing, self).__init__()
        layers = []
        layers.append(GraphConvolution(nfeat, nhid[0], order=order))
        for i in range(len(nhid) - 1):
            layers.append(GraphConvolution(nhid[i], nhid[i + 1], order=order))
        layers.append(GraphConvolution(nhid[-1], weighing_output_dim, order=order))
        self.gc = nn.ModuleList(layers)
        self.dropout = dropout

    def forward(self, x, adj, func, samples=-1):
        end_layer = len(self.gc) - 1
        for i in range(end_layer):
            x = self.gc[i](x, adj)
            x = F.leaky_relu(x, inplace=False)
            x = F.dropout(x, self.dropout, training=self.training)
        # `func` normalizes the raw scores over the node dimension (dim=0).
        if type(samples) is int:
            weights = func(self.gc[-1](x, adj), dim=0)
        else:
            weights = func(self.gc[-1](x, adj)[samples, :], dim=0)
        if len(weights.shape) == 1:
            weights = weights[None, :]
        return weights, x

class RAGCN:
    """RA-GCN wrapper holding one Discriminator and `n_ws` class-specific Weighing networks with their optimizers."""

    def __init__(self, features, adj, nclass, struc_D, struc_Ws, n_ws, weighing_output_dim, act, gamma=1):
        nfeat = features.shape[1]
        # The adjacency for each network may be a single matrix or a list of matrices (higher-order propagation).
        if type(adj['D']) is list:
            order_D = len(adj['D'])
        else:
            order_D = 1
        if type(adj['W'][0]) is list:
            order_W = len(adj['W'][0])
        else:
            order_W = 1
        self.n_ws = n_ws
        self.wod = weighing_output_dim
        self.net_D = Discriminator(nfeat, struc_D['nhid'], nclass, struc_D['dropout'], order=order_D)
        self.gamma = gamma
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=struc_D['lr'], weight_decay=struc_D['wd'])
        self.net_Ws = []
        self.opt_Ws = []
        for i in range(n_ws):
            self.net_Ws.append(Weighing(nfeat=nfeat, nhid=struc_Ws[i]['nhid'],
                                        weighing_output_dim=weighing_output_dim,
                                        dropout=struc_Ws[i]['dropout'], order=order_W))
            self.opt_Ws.append(
                optim.Adam(self.net_Ws[-1].parameters(), lr=struc_Ws[i]['lr'], weight_decay=struc_Ws[i]['wd']))
        self.adj_D = adj['D']
        self.adj_W = adj['W']
        self.features = features
        self.act = act

    def run_D(self, samples):
        # Discriminator forward pass: class log-probabilities for `samples` plus node embeddings.
        class_prob, embed = self.net_D(self.features, self.adj_D, samples)
        return class_prob, embed

    def run_W(self, samples, labels, args_cuda=False, equal_weights=False):
        # Compute a weight per sample, either uniform within each class (equal_weights=True)
        # or produced by the class-specific Weighing networks.
        batch_size = samples.shape[0]
        embed = None
        if equal_weights:
            max_label = int(labels.max().item() + 1)
            weight = torch.empty(batch_size)
            for i in range(max_label):
                labels_indices = (labels == i).nonzero().squeeze()
                if len(labels_indices.shape) == 0:
                    batch_size = 1
                else:
                    batch_size = len(labels_indices)
                if labels_indices is not None:
                    weight[labels_indices] = 1 / batch_size * torch.ones(batch_size)
            weight = weight / max_label
        else:
            max_label = int(labels.max().item() + 1)
            weight = torch.empty(batch_size)
            if args_cuda:
                weight = weight.cuda()
            for i in range(max_label):
                labels_indices = (labels == i).nonzero().squeeze()
                if labels_indices is not None:
                    sub_samples = samples[labels_indices]
                    weight_, embed = self.net_Ws[i](x=self.features[sub_samples, :], adj=self.adj_W[i],
                                                    samples=-1, func=self.act)
                    weight[labels_indices] = weight_.squeeze() if self.wod == 1 else weight_[:, i]
            weight = weight / max_label
        if args_cuda:
            weight = weight.cuda()
        return weight, embed

    def loss_function_D(self, output, labels, weights):
        # Weighted negative log-likelihood: `output` holds log-probabilities, `labels` is one-hot.
        return torch.sum(-weights * (labels.float() * output).sum(1), -1)

    def loss_function_G(self, output, labels, weights):
        # Weighted NLL plus an entropy regularizer on the weights, scaled by gamma.
        return torch.sum(-weights * (labels.float() * output).sum(1), -1) \
               - self.gamma * torch.sum(weights * torch.log(weights + 1e-20))

    def zero_grad_both(self):
        # Clear gradients of the discriminator and of all Weighing networks.
        self.opt_D.zero_grad()
        for opt in self.opt_Ws:
            opt.zero_grad()

    def run_both(self, epoch_for_D, epoch_for_W, labels_one_hot, samples=-1,
                 args_cuda=False, equal_weights=False):
        # Alternating (adversarial) updates: first train the discriminator on weighted samples,
        # then train the Weighing networks against the negated generator loss.
        labels_not_onehot = labels_one_hot.max(1)[1].type_as(labels_one_hot)
        for e_D in range(epoch_for_D):
            self.zero_grad_both()
            class_prob_1, embed = self.run_D(samples)
            weights_1, _ = self.run_W(samples=samples, labels=labels_not_onehot,
                                      args_cuda=args_cuda, equal_weights=equal_weights)
            loss_D = self.loss_function_D(output=class_prob_1, labels=labels_one_hot, weights=weights_1)
            loss_D.backward()
            self.opt_D.step()
        for e_W in range(epoch_for_W):
            self.zero_grad_both()
            class_prob_2, embed = self.run_D(samples)
            weights_2, embeds = self.run_W(samples=samples, labels=labels_not_onehot,
                                           args_cuda=args_cuda, equal_weights=equal_weights)
            loss_G = -self.loss_function_G(output=class_prob_2, labels=labels_one_hot, weights=weights_2)
            loss_G.backward()
            for opt in self.opt_Ws:
                opt.step()

    def train(self):
        # Put the discriminator and every Weighing network in training mode.
        self.net_D.train()
        for i in range(self.n_ws):
            self.net_Ws[i].train()

    def eval(self):
        # Put the discriminator and every Weighing network in evaluation mode.
        self.net_D.eval()
        for i in range(self.n_ws):
            self.net_Ws[i].eval()

    def cuda(self):
        # Move features, adjacency matrices, and all networks to the GPU.
        self.features = self.features.to(device='cuda')
        for i in range(len(self.adj_D)):
            self.adj_D[i] = self.adj_D[i].to(device='cuda')
        for i in range(len(self.adj_W)):
            self.adj_W[i] = self.adj_W[i].to(device='cuda')
        self.net_D.cuda()
        for i in range(self.n_ws):
            self.net_Ws[i].cuda()
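
A minimal usage sketch follows, for illustration only: the graph size, the hyper-parameter dictionaries, and the identity matrix standing in for a normalized adjacency are assumptions, not values taken from this repository.

if __name__ == '__main__':
    # Hypothetical toy setup: random features and labels on a 100-node graph.
    n_nodes, n_feat, n_class = 100, 16, 2
    features = torch.randn(n_nodes, n_feat)
    adj_hat = torch.eye(n_nodes)  # stand-in for a normalized adjacency matrix
    labels = torch.randint(0, n_class, (n_nodes,))
    labels_one_hot = F.one_hot(labels, n_class).float()
    train_idx = torch.arange(n_nodes)

    # Placeholder hyper-parameters for the discriminator and the per-class weighing networks.
    struc_D = {'nhid': [8], 'dropout': 0.5, 'lr': 0.01, 'wd': 5e-4}
    struc_Ws = [{'nhid': [8], 'dropout': 0.5, 'lr': 0.01, 'wd': 5e-4} for _ in range(n_class)]

    model = RAGCN(features=features, adj={'D': adj_hat, 'W': [adj_hat] * n_class},
                  nclass=n_class, struc_D=struc_D, struc_Ws=struc_Ws, n_ws=n_class,
                  weighing_output_dim=1, act=my_sigmoid)
    model.train()
    for epoch in range(10):
        # One adversarial round: update the discriminator, then the weighing networks.
        model.run_both(epoch_for_D=1, epoch_for_W=1, labels_one_hot=labels_one_hot,
                       samples=train_idx, args_cuda=False, equal_weights=False)
    model.eval()
    log_probs, _ = model.run_D(train_idx)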