
# model.py

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from collections import OrderedDict
import math
import numpy as np
import time

def binary_cross_entropy_weight(y_pred, y, has_weight=False, weight_length=1, weight_max=10):
    '''
    :param y_pred:
    :param y:
    :param weight_length: number of steps before the end of the sequence that receive extra weight
    :param weight_max: the maximum weight, reached at the last step
    :return:
    '''
    if has_weight:
        weight = torch.ones(y.size(0), y.size(1), y.size(2))
        weight_linear = torch.arange(1, weight_length + 1) / weight_length * weight_max
        weight_linear = weight_linear.view(1, weight_length, 1).repeat(y.size(0), 1, y.size(2))
        weight[:, -1 * weight_length:, :] = weight_linear
        loss = F.binary_cross_entropy(y_pred, y, weight=weight.cuda())
    else:
        loss = F.binary_cross_entropy(y_pred, y)
    return loss
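
# Rough usage sketch (added, not from the original code; shapes are assumed to be
# batch * seq_len * features, and a CUDA device is required because the weight tensor
# is moved to the GPU above):
# y_pred = Variable(torch.rand(4, 25, 40)).cuda()
# y = Variable(torch.bernoulli(torch.rand(4, 25, 40))).cuda()
# loss = binary_cross_entropy_weight(y_pred, y, has_weight=True, weight_length=10, weight_max=10)
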
def sample_tensor(y, sample=True, thresh=0.5):
    # do sampling
    if sample:
        y_thresh = Variable(torch.rand(y.size())).cuda()
        y_result = torch.gt(y, y_thresh).float()
    # do max likelihood based on some threshold
    else:
        y_thresh = Variable(torch.ones(y.size()) * thresh).cuda()
        y_result = torch.gt(y, y_thresh).float()
    return y_result

def gumbel_softmax(logits, temperature, eps=1e-9):
    '''
    :param logits: shape: N*L
    :param temperature:
    :param eps:
    :return:
    '''
    # get gumbel noise
    noise = torch.rand(logits.size())
    noise.add_(eps).log_().neg_()
    noise.add_(eps).log_().neg_()
    noise = Variable(noise).cuda()
    x = (logits + noise) / temperature
    x = F.softmax(x)
    return x

# for i in range(10):
#     x = Variable(torch.randn(1,10)).cuda()
#     y = gumbel_softmax(x, temperature=0.01)
#     print(x)
#     print(y)
#     _,id = y.topk(1)
#     print(id)

def gumbel_sigmoid(logits, temperature):
    '''
    :param logits:
    :param temperature:
    :return:
    '''
    # get gumbel noise
    noise = torch.rand(logits.size())  # uniform(0,1)
    noise_logistic = torch.log(noise) - torch.log(1 - noise)  # logistic(0,1)
    noise = Variable(noise_logistic).cuda()
    x = (logits + noise) / temperature
    x = F.sigmoid(x)
    return x

# x = Variable(torch.randn(100)).cuda()
# y = gumbel_sigmoid(x,temperature=0.01)
# print(x)
# print(y)

def sample_sigmoid(y, sample, thresh=0.5, sample_time=2):
    '''
    do sampling over unnormalized score
    :param y: input
    :param sample: Bool
    :param thresh: if not sample, the threshold
    :param sample_time: how many times do we sample, if =1, do single sample
    :return: sampled result
    '''
    # do sigmoid first
    y = torch.sigmoid(y)
    # do sampling
    if sample:
        if sample_time > 1:
            y_result = Variable(torch.rand(y.size(0), y.size(1), y.size(2))).cuda()
            # loop over all batches
            for i in range(y_result.size(0)):
                # sample up to 'sample_time' times
                for j in range(sample_time):
                    y_thresh = Variable(torch.rand(y.size(1), y.size(2))).cuda()
                    y_result[i] = torch.gt(y[i], y_thresh).float()
                    if (torch.sum(y_result[i]).data > 0).any():
                        break
                    # else:
                    #     print('all zero',j)
        else:
            y_thresh = Variable(torch.rand(y.size(0), y.size(1), y.size(2))).cuda()
            y_result = torch.gt(y, y_thresh).float()
    # do max likelihood based on some threshold
    else:
        y_thresh = Variable(torch.ones(y.size(0), y.size(1), y.size(2)) * thresh).cuda()
        y_result = torch.gt(y, y_thresh).float()
    return y_result
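
# Rough usage sketch (added; sizes are hypothetical): y is assumed to hold raw scores of shape
# batch * 1 * max_prev_node, as produced by an edge-level output module.
# scores = Variable(torch.randn(8, 1, 40)).cuda()
# edges = sample_sigmoid(scores, sample=True, sample_time=2)    # binary tensor, same shape
# edges_ml = sample_sigmoid(scores, sample=False, thresh=0.5)   # deterministic thresholding
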
def sample_sigmoid_supervised(y_pred, y, current, y_len, sample_time=2):
    '''
    do sampling over unnormalized score
    :param y_pred: input
    :param y: supervision
    :param current: index of the step currently being generated
    :param y_len: ground-truth sequence length for each graph in the batch
    :param sample_time: how many times do we sample, if =1, do single sample
    :return: sampled result
    '''
    # do sigmoid first
    y_pred = F.sigmoid(y_pred)
    # do sampling
    y_result = Variable(torch.rand(y_pred.size(0), y_pred.size(1), y_pred.size(2))).cuda()
    # loop over all batches
    for i in range(y_result.size(0)):
        # using supervision
        if current < y_len[i]:
            while True:
                y_thresh = Variable(torch.rand(y_pred.size(1), y_pred.size(2))).cuda()
                y_result[i] = torch.gt(y_pred[i], y_thresh).float()
                # print('current',current)
                # print('y_result',y_result[i].data)
                # print('y',y[i])
                y_diff = y_result[i].data - y[i]
                if (y_diff >= 0).all():
                    break
        # supervision done
        else:
            # sample up to 'sample_time' times
            for j in range(sample_time):
                y_thresh = Variable(torch.rand(y_pred.size(1), y_pred.size(2))).cuda()
                y_result[i] = torch.gt(y_pred[i], y_thresh).float()
                if (torch.sum(y_result[i]).data > 0).any():
                    break
    return y_result

def sample_sigmoid_supervised_simple(y_pred, y, current, y_len, sample_time=2):
    '''
    do sampling over unnormalized score
    :param y_pred: input
    :param y: supervision
    :param current: index of the step currently being generated
    :param y_len: ground-truth sequence length for each graph in the batch
    :param sample_time: how many times do we sample, if =1, do single sample
    :return: sampled result
    '''
    # do sigmoid first
    y_pred = F.sigmoid(y_pred)
    # do sampling
    y_result = Variable(torch.rand(y_pred.size(0), y_pred.size(1), y_pred.size(2))).cuda()
    # loop over all batches
    for i in range(y_result.size(0)):
        # using supervision
        if current < y_len[i]:
            y_result[i] = y[i]
        # supervision done
        else:
            # sample up to 'sample_time' times
            for j in range(sample_time):
                y_thresh = Variable(torch.rand(y_pred.size(1), y_pred.size(2))).cuda()
                y_result[i] = torch.gt(y_pred[i], y_thresh).float()
                if (torch.sum(y_result[i]).data > 0).any():
                    break
    return y_result
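
# Rough usage sketch (added; variable names are hypothetical): during training-time generation,
# `current` is the index of the step being generated and y_len holds per-graph sequence lengths,
# so rows that are still inside the ground-truth sequence are copied from y (teacher forcing)
# and the remaining rows are sampled from y_pred.
# y_step = sample_sigmoid_supervised_simple(y_pred_step, y_truth_step, current=i, y_len=y_len, sample_time=2)
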
################### current adopted model, LSTM+MLP || LSTM+VAE || LSTM+LSTM (where LSTM can be GRU as well)
#####
# definition of terms
# h: hidden state of LSTM
# y: edge prediction, model output
# n: noise for generator
# l: whether an output is real or not, binary

# plain LSTM model
class LSTM_plain(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, has_input=True, has_output=False,
                 output_size=None):
        super(LSTM_plain, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output
        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                               batch_first=True)
        else:
            self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )
        self.relu = nn.ReLU()
        # initialize
        self.hidden = None  # need initialize before forward run
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.25)
            elif 'weight' in name:
                nn.init.xavier_uniform(param, gain=nn.init.calculate_gain('sigmoid'))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda(),
                Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda())

    def forward(self, input_raw, pack=False, input_len=None):
        if self.has_input:
            input = self.input(input_raw)
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            input = pack_padded_sequence(input, input_len, batch_first=True)
        output_raw, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        if self.has_output:
            output_raw = self.output(output_raw)
        # return hidden state at each time step
        return output_raw

# plain GRU model
class GRU_plain(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, has_input=True, has_output=False,
                 output_size=None):
        super(GRU_plain, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output
        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                              batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )
        self.relu = nn.ReLU()
        # initialize
        self.hidden = None  # need initialize before forward run
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.25)
            elif 'weight' in name:
                nn.init.xavier_uniform(param, gain=nn.init.calculate_gain('sigmoid'))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self, batch_size):
        return Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda()

    def forward(self, input_raw, pack=False, input_len=None):
        if self.has_input:
            input = self.input(input_raw)
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            input = pack_padded_sequence(input, input_len, batch_first=True)
        output_raw, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        if self.has_output:
            output_raw = self.output(output_raw)
        # return hidden state at each time step
        return output_raw
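
# Rough usage sketch (added; sizes are hypothetical): GRU_plain is the generic RNN wrapper used
# by the models in this file; its hidden state must be (re)initialized before each forward pass.
# rnn = GRU_plain(input_size=40, embedding_size=64, hidden_size=128, num_layers=4,
#                 has_input=True, has_output=True, output_size=16).cuda()
# x = Variable(torch.rand(32, 25, 40)).cuda()   # batch * seq_len * input_size
# rnn.hidden = rnn.init_hidden(batch_size=32)
# out = rnn(x)                                  # batch * seq_len * output_size
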
# a deterministic linear output
class MLP_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_plain, self).__init__()
        self.deterministic_output = nn.Sequential(
            nn.Linear(h_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, y_size)
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        y = self.deterministic_output(h)
        return y


# a deterministic linear output; an additional output indicates whether the sequence should keep growing
class MLP_token_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_token_plain, self).__init__()
        self.deterministic_output = nn.Sequential(
            nn.Linear(h_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, y_size)
        )
        self.token_output = nn.Sequential(
            nn.Linear(h_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, 1)
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        y = self.deterministic_output(h)
        t = self.token_output(h)
        return y, t

class Main_MLP_VAE_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(Main_MLP_VAE_plain, self).__init__()
        self.encode_11 = nn.Linear(h_size, embedding_size)  # mu
        self.encode_12 = nn.Linear(h_size, embedding_size)  # lsgms
        self.decode_1 = nn.Linear(embedding_size, embedding_size)
        self.decode_2 = nn.Linear(embedding_size, y_size)  # make edge prediction (reconstruct)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        # encoder
        z_mu = self.encode_11(h)
        z_lsgms = self.encode_12(h)
        # reparameterize
        z_sgm = z_lsgms.mul(0.5).exp_()
        eps = Variable(torch.randn(z_sgm.size())).cuda()
        z = eps * z_sgm + z_mu
        # decoder
        # print("### z ")
        # print(z)
        y = self.decode_1(z)
        # print("### decode_1 ")
        # print(y)
        y = self.relu(y)
        y = self.decode_2(y)
        # print("### decode_2 ")
        # print(y)
        return y, z_mu, z_lsgms
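
# Note (added): z_mu and z_lsgms (log sigma^2) are returned so the caller can add the usual VAE
# KL term to the reconstruction loss, typically something like
#     loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
# (suitably normalized) alongside a binary cross-entropy reconstruction term on y.
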
# a deterministic linear output (update: add noise)
class MLP_VAE_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_VAE_plain, self).__init__()
        self.encode_11 = nn.Linear(h_size, embedding_size)  # mu
        self.encode_12 = nn.Linear(h_size, embedding_size)  # lsgms
        self.decode_1 = nn.Linear(embedding_size, 32)
        self.linear = nn.Linear(h_size, y_size)
        self.bn1 = nn.BatchNorm1d(32)
        # self.decode_1_2 = nn.Linear(256, 512)
        # self.bn2 = nn.BatchNorm1d(512)
        self.decode_2 = nn.Linear(32, y_size)  # make edge prediction (reconstruct)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        # encoder
        z_mu = self.encode_11(h)
        z_lsgms = self.encode_12(h)
        # reparameterize
        z_sgm = z_lsgms.mul(0.5).exp_()
        eps = Variable(torch.randn(z_sgm.size())).cuda()
        # z = eps*z_sgm + z_mu
        z = z_mu
        # decoder
        y = self.decode_1(z)
        # y = self.bn1(y)
        y = self.relu(y)
        # y = self.decode_1_2(y)
        # y = self.bn2(y)
        y = self.relu(y)
        y = self.decode_2(y)
        # reconstruction = True
        # if reconstruction:
        #     y = self.linear(h)
        return y, z_mu, z_lsgms

class GRAN(nn.Module):
    def __init__(self, max_num_nodes):
        super(GRAN, self).__init__()
        self.max_num_nodes = max_num_nodes
        self.hidden_dim = 16
        self.num_mix_component = 1
        self.output_dim = 1
        self.output_theta = nn.Sequential(
            nn.Linear(2 * self.hidden_dim, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.hidden_dim, self.output_dim * self.num_mix_component))

    def forward(self, h, batch_num_nodes, random_index_list, GRAN_arbitrary_node):
        batch_size = h.size()[0]
        diff = Variable(torch.zeros(batch_size, self.max_num_nodes - 1, h.size()[2]).cuda())
        concat = Variable(torch.zeros(batch_size, self.max_num_nodes - 1, 2 * h.size()[2]).cuda())
        for i in range(batch_size):
            diff[i, :batch_num_nodes[i] - 1, :] = h[i, :batch_num_nodes[i] - 1, :] - h[i, batch_num_nodes[i] - 1, :]
            diff[i, batch_num_nodes[i] - 1:, :] = -100
            h_duplicate = h[i, batch_num_nodes[i] - 1, :].unsqueeze(0).repeat(1, batch_num_nodes[i] - 1, 1).squeeze()
            concat[i, :batch_num_nodes[i] - 1, :] = torch.cat(
                (h[i, :batch_num_nodes[i] - 1, :], h_duplicate), 1)
            if GRAN_arbitrary_node:
                h_duplicate = h[i, random_index_list[i], :].unsqueeze(0).repeat(1, batch_num_nodes[i] - 1, 1).squeeze()
                concat[i, :random_index_list[i], :] = torch.cat(
                    (h[i, :random_index_list[i], :], h_duplicate[:random_index_list[i], :]), 1)
                concat[i, random_index_list[i]:batch_num_nodes[i] - 1, :] = torch.cat(
                    (h[i, random_index_list[i] + 1:batch_num_nodes[i], :],
                     h_duplicate[random_index_list[i]:batch_num_nodes[i] - 1, :]), 1)
        # return self.output_theta(diff)
        # print("*** h")
        # print(h)
        # print("*** concat")
        # print(concat)
        return self.output_theta(concat)
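
# Rough shape sketch (added; values are hypothetical): h is batch * max_num_nodes * hidden_dim
# (hidden_dim is fixed to 16 above), batch_num_nodes gives the number of nodes per graph, and
# the returned theta has shape batch * (max_num_nodes - 1) * (output_dim * num_mix_component).
# gran = GRAN(max_num_nodes=20).cuda()
# h = Variable(torch.rand(4, 20, 16)).cuda()
# theta = gran(h, batch_num_nodes=[20, 18, 15, 12], random_index_list=[0, 0, 0, 0],
#              GRAN_arbitrary_node=False)
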
# a deterministic linear output (update: add noise)
class MLP_VAE_conditional_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_VAE_conditional_plain, self).__init__()
        self.encode_11 = nn.Linear(h_size, embedding_size)  # mu
        self.encode_12 = nn.Linear(h_size, embedding_size)  # lsgms
        self.decode_1 = nn.Linear(embedding_size + h_size, embedding_size)
        self.decode_2 = nn.Linear(embedding_size, y_size)  # make edge prediction (reconstruct)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        # encoder
        z_mu = self.encode_11(h)
        z_lsgms = self.encode_12(h)
        # reparameterize
        z_sgm = z_lsgms.mul(0.5).exp_()
        eps = Variable(torch.randn(z_sgm.size(0), z_sgm.size(1), z_sgm.size(2))).cuda()
        z = eps * z_sgm + z_mu
        # decoder
        y = self.decode_1(torch.cat((h, z), dim=2))
        y = self.relu(y)
        y = self.decode_2(y)
        return y, z_mu, z_lsgms

########### baseline model 1: Learning deep generative model of graphs
class DGM_graphs(nn.Module):
    def __init__(self, h_size):
        # h_size: node embedding size
        # h_size*2: graph embedding size
        super(DGM_graphs, self).__init__()
        ### all modules used by the model
        ## 1 message passing, 2 times
        self.m_uv_1 = nn.Linear(h_size * 2, h_size * 2)
        self.f_n_1 = nn.GRUCell(h_size * 2, h_size)  # input_size, hidden_size
        self.m_uv_2 = nn.Linear(h_size * 2, h_size * 2)
        self.f_n_2 = nn.GRUCell(h_size * 2, h_size)  # input_size, hidden_size
        ## 2 graph embedding and new node embedding
        # for graph embedding
        self.f_m = nn.Linear(h_size, h_size * 2)
        self.f_gate = nn.Sequential(
            nn.Linear(h_size, 1),
            nn.Sigmoid()
        )
        # for new node embedding
        self.f_m_init = nn.Linear(h_size, h_size * 2)
        self.f_gate_init = nn.Sequential(
            nn.Linear(h_size, 1),
            nn.Sigmoid()
        )
        self.f_init = nn.Linear(h_size * 2, h_size)
        ## 3 f_addnode
        self.f_an = nn.Sequential(
            nn.Linear(h_size * 2, 1),
            nn.Sigmoid()
        )
        ## 4 f_addedge
        self.f_ae = nn.Sequential(
            nn.Linear(h_size * 2, 1),
            nn.Sigmoid()
        )
        ## 5 f_nodes
        self.f_s = nn.Linear(h_size * 2, 1)


def message_passing(node_neighbor, node_embedding, model):
    node_embedding_new = []
    for i in range(len(node_neighbor)):
        neighbor_num = len(node_neighbor[i])
        if neighbor_num > 0:
            node_self = node_embedding[i].expand(neighbor_num, node_embedding[i].size(1))
            node_self_neighbor = torch.cat([node_embedding[j] for j in node_neighbor[i]], dim=0)
            message = torch.sum(model.m_uv_1(torch.cat((node_self, node_self_neighbor), dim=1)), dim=0, keepdim=True)
            node_embedding_new.append(model.f_n_1(message, node_embedding[i]))
        else:
            message_null = Variable(torch.zeros((node_embedding[i].size(0), node_embedding[i].size(1) * 2))).cuda()
            node_embedding_new.append(model.f_n_1(message_null, node_embedding[i]))
    node_embedding = node_embedding_new
    node_embedding_new = []
    for i in range(len(node_neighbor)):
        neighbor_num = len(node_neighbor[i])
        if neighbor_num > 0:
            node_self = node_embedding[i].expand(neighbor_num, node_embedding[i].size(1))
            node_self_neighbor = torch.cat([node_embedding[j] for j in node_neighbor[i]], dim=0)
            message = torch.sum(model.m_uv_1(torch.cat((node_self, node_self_neighbor), dim=1)), dim=0, keepdim=True)
            node_embedding_new.append(model.f_n_1(message, node_embedding[i]))
        else:
            message_null = Variable(torch.zeros((node_embedding[i].size(0), node_embedding[i].size(1) * 2))).cuda()
            node_embedding_new.append(model.f_n_1(message_null, node_embedding[i]))
    return node_embedding_new
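
# Rough usage sketch (added; variables are hypothetical): node_embedding is a list of 1 * h_size
# Variables (one per node), node_neighbor is an adjacency list, and model is a DGM_graphs instance.
# Note that both propagation rounds above reuse m_uv_1 / f_n_1, so m_uv_2 / f_n_2 stay unused here.
# model = DGM_graphs(h_size=64).cuda()
# node_embedding = [Variable(torch.rand(1, 64)).cuda() for _ in range(3)]
# node_neighbor = [[1], [0, 2], [1]]   # a path graph on 3 nodes
# node_embedding = message_passing(node_neighbor, node_embedding, model)
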
def calc_graph_embedding(node_embedding_cat, model):
    node_embedding_graph = model.f_m(node_embedding_cat)
    node_embedding_graph_gate = model.f_gate(node_embedding_cat)
    graph_embedding = torch.sum(torch.mul(node_embedding_graph, node_embedding_graph_gate), dim=0, keepdim=True)
    return graph_embedding


def calc_init_embedding(node_embedding_cat, model):
    node_embedding_init = model.f_m_init(node_embedding_cat)
    node_embedding_init_gate = model.f_gate_init(node_embedding_cat)
    init_embedding = torch.sum(torch.mul(node_embedding_init, node_embedding_init_gate), dim=0, keepdim=True)
    init_embedding = model.f_init(init_embedding)
    return init_embedding


################################################## code that is NOT used in the final version #############
# RNN that updates according to graph structure, new proposed model
class Graph_RNN_structure(nn.Module):
    def __init__(self, hidden_size, batch_size, output_size, num_layers, is_dilation=True, is_bn=True):
        super(Graph_RNN_structure, self).__init__()
        ## model configuration
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.output_size = output_size
        self.num_layers = num_layers  # num_layers of cnn_output
        self.is_bn = is_bn
        ## model
        self.relu = nn.ReLU()
        # self.linear_output = nn.Linear(hidden_size, 1)
        # self.linear_output_simple = nn.Linear(hidden_size, output_size)
        # for state transition use only, input is null
        # self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        # use CNN to produce output prediction
        # self.cnn_output = nn.Sequential(
        #     nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1),
        #     # nn.BatchNorm1d(hidden_size),
        #     nn.ReLU(),
        #     nn.Conv1d(hidden_size, 1, kernel_size=3, dilation=1, padding=1)
        # )
        if is_dilation:
            self.conv_block = nn.ModuleList(
                [nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=2 ** i, padding=2 ** i) for i in
                 range(num_layers - 1)])
        else:
            self.conv_block = nn.ModuleList(
                [nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1) for i in
                 range(num_layers - 1)])
        self.bn_block = nn.ModuleList([nn.BatchNorm1d(hidden_size) for i in range(num_layers - 1)])
        self.conv_out = nn.Conv1d(hidden_size, 1, kernel_size=3, dilation=1, padding=1)
        # # use CNN to do state transition
        # self.cnn_transition = nn.Sequential(
        #     nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1),
        #     # nn.BatchNorm1d(hidden_size),
        #     nn.ReLU(),
        #     nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1)
        # )
        # use linear to do transition, same as GCN mean aggregator
        self.linear_transition = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        # GRU based output, output a single edge prediction at a time
        # self.gru_output = nn.GRU(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        # use a list to keep all generated hidden vectors, each hidden has size batch*hidden_dim*1, and the list size is expanding
        # when using convolution to compute attention weight, we need to first concat the list into a pytorch variable: batch*hidden_dim*current_num_nodes
        self.hidden_all = []
        ## initialize
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # print('linear')
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
                # print(m.weight.data.size())
            if isinstance(m, nn.Conv1d):
                # print('conv1d')
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
                # print(m.weight.data.size())
            if isinstance(m, nn.BatchNorm1d):
                # print('batchnorm1d')
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                # print(m.weight.data.size())
            if isinstance(m, nn.GRU):
                # print('gru')
                m.weight_ih_l0.data = init.xavier_uniform(m.weight_ih_l0.data,
                                                          gain=nn.init.calculate_gain('sigmoid'))
                m.weight_hh_l0.data = init.xavier_uniform(m.weight_hh_l0.data,
                                                          gain=nn.init.calculate_gain('sigmoid'))
                m.bias_ih_l0.data = torch.ones(m.bias_ih_l0.data.size(0)) * 0.25
                m.bias_hh_l0.data = torch.ones(m.bias_hh_l0.data.size(0)) * 0.25

    def init_hidden(self, len=None):
        if len is None:
            return Variable(torch.ones(self.batch_size, self.hidden_size, 1)).cuda()
        else:
            hidden_list = []
            for i in range(len):
                hidden_list.append(Variable(torch.ones(self.batch_size, self.hidden_size, 1)).cuda())
            return hidden_list
    # only run a single forward step
    def forward(self, x, teacher_forcing, temperature=0.5, bptt=True, bptt_len=20, flexible=True, max_prev_node=100):
        # x: batch*1*self.output_size, the ground truth
        # todo: currently only looks back at self.output_size nodes; try looking back according to the bfs sequence
        # 1 first compute new state
        # print('hidden_all', self.hidden_all[-1*self.output_size:])
        # hidden_all_cat = torch.cat(self.hidden_all[-1*self.output_size:], dim=2)
        # # # add BPTT, detach the first variable
        # if bptt:
        #     self.hidden_all[0] = Variable(self.hidden_all[0].data).cuda()
        hidden_all_cat = torch.cat(self.hidden_all, dim=2)
        # print(hidden_all_cat.size())
        # print('hidden_all_cat',hidden_all_cat.size())
        # att_weight size: batch*1*current_num_nodes
        for i in range(self.num_layers - 1):
            hidden_all_cat = self.conv_block[i](hidden_all_cat)
            if self.is_bn:
                hidden_all_cat = self.bn_block[i](hidden_all_cat)
            hidden_all_cat = self.relu(hidden_all_cat)
        x_pred = self.conv_out(hidden_all_cat)
        # 2 then compute output, using a gru
        # first try the simple version, directly give the edge prediction
        # x_pred = self.linear_output_simple(hidden_new)
        # x_pred = x_pred.view(x_pred.size(0),1,x_pred.size(1))
        # todo: use a gru version output
        # if sample==False:
        #     # when training: we know the ground truth, input the sequence at once
        #     y_pred,_ = self.gru_output(x, hidden_new.permute(2,0,1))
        #     y_pred = self.linear_output(y_pred)
        # else:
        #     # when validating, we need to sample at each time step
        #     y_pred = Variable(torch.zeros(x.size(0), x.size(1), x.size(2))).cuda()
        #     y_pred_long = Variable(torch.zeros(x.size(0), x.size(1), x.size(2))).cuda()
        #     x_step = x[:, 0:1, :]
        #     for i in range(x.size(1)):
        #         y_step,_ = self.gru_output(x_step)
        #         y_step = self.linear_output(y_step)
        #         y_pred[:, i, :] = y_step
        #         y_step = F.sigmoid(y_step)
        #         x_step = sample(y_step, sample=True, thresh=0.45)
        #         y_pred_long[:, i, :] = x_step
        #     pass
        # 3 then update self.hidden_all list
        # i.e., model will use ground truth to update new node
        # x_pred_sample = gumbel_sigmoid(x_pred, temperature=temperature)
        x_pred_sample = sample_tensor(F.sigmoid(x_pred), sample=True)
        thresh = 0.5
        x_thresh = Variable(
            torch.ones(x_pred_sample.size(0), x_pred_sample.size(1), x_pred_sample.size(2)) * thresh).cuda()
        x_pred_sample_long = torch.gt(x_pred_sample, x_thresh).long()
        if teacher_forcing:
            # first mask previous hidden states
            hidden_all_cat_select = hidden_all_cat * x
            x_sum = torch.sum(x, dim=2, keepdim=True).float()
        # i.e., the model will use its own prediction to attend
        else:
            # first mask previous hidden states
            hidden_all_cat_select = hidden_all_cat * x_pred_sample
            x_sum = torch.sum(x_pred_sample_long, dim=2, keepdim=True).float()
        # update hidden vector for new nodes
        hidden_new = torch.sum(hidden_all_cat_select, dim=2, keepdim=True) / x_sum
        hidden_new = self.linear_transition(hidden_new.permute(0, 2, 1))
        hidden_new = hidden_new.permute(0, 2, 1)
        if flexible:
            # use ground truth to maintain the history state
            if teacher_forcing:
                x_id = torch.min(torch.nonzero(torch.squeeze(x.data)))
                self.hidden_all = self.hidden_all[x_id:]
            # use prediction to maintain the history state
            else:
                x_id = torch.min(torch.nonzero(torch.squeeze(x_pred_sample_long.data)))
                start = max(len(self.hidden_all) - max_prev_node + 1, x_id)
                self.hidden_all = self.hidden_all[start:]
        # maintain a fixed-size history state
        else:
            # self.hidden_all.pop(0)
            self.hidden_all = self.hidden_all[1:]
        self.hidden_all.append(hidden_new)
        # 4 return prediction
        # print('x_pred',x_pred)
        # print('x_pred_mean', torch.mean(x_pred))
        # print('x_pred_sample_mean', torch.mean(x_pred_sample))
        return x_pred, x_pred_sample

# batch_size = 8
# output_size = 4
# generator = Graph_RNN_structure(hidden_size=16, batch_size=batch_size, output_size=output_size, num_layers=1).cuda()
# for i in range(4):
#     generator.hidden_all.append(generator.init_hidden())
#
# x = Variable(torch.rand(batch_size,1,output_size)).cuda()
# x_pred = generator(x,teacher_forcing=True, sample=True)
# print(x_pred)

# current baseline model, generating a graph by lstm
class Graph_generator_LSTM(nn.Module):
    def __init__(self, feature_size, input_size, hidden_size, output_size, batch_size, num_layers):
        super(Graph_generator_LSTM, self).__init__()
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear_input = nn.Linear(feature_size, input_size)
        self.linear_output = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        # initialize
        # self.hidden,self.cell = self.init_hidden()
        self.hidden = self.init_hidden()
        self.lstm.weight_ih_l0.data = init.xavier_uniform(self.lstm.weight_ih_l0.data,
                                                          gain=nn.init.calculate_gain('sigmoid'))
        self.lstm.weight_hh_l0.data = init.xavier_uniform(self.lstm.weight_hh_l0.data,
                                                          gain=nn.init.calculate_gain('sigmoid'))
        self.lstm.bias_ih_l0.data = torch.ones(self.lstm.bias_ih_l0.data.size(0)) * 0.25
        self.lstm.bias_hh_l0.data = torch.ones(self.lstm.bias_hh_l0.data.size(0)) * 0.25
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self):
        return (Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)).cuda(),
                Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)).cuda())

    def forward(self, input_raw, pack=False, len=None):
        input = self.linear_input(input_raw)
        input = self.relu(input)
        if pack:
            input = pack_padded_sequence(input, len, batch_first=True)
        output_raw, self.hidden = self.lstm(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        output = self.linear_output(output_raw)
        return output

# a simple MLP generator output
class Graph_generator_LSTM_output_generator(nn.Module):
    def __init__(self, h_size, n_size, y_size):
        super(Graph_generator_LSTM_output_generator, self).__init__()
        # one layer MLP
        self.generator_output = nn.Sequential(
            nn.Linear(h_size + n_size, 64),
            nn.ReLU(),
            nn.Linear(64, y_size),
            nn.Sigmoid()
        )

    def forward(self, h, n, temperature):
        y_cat = torch.cat((h, n), dim=2)
        y = self.generator_output(y_cat)
        # y = gumbel_sigmoid(y,temperature=temperature)
        return y


# a simple MLP discriminator
class Graph_generator_LSTM_output_discriminator(nn.Module):
    def __init__(self, h_size, y_size):
        super(Graph_generator_LSTM_output_discriminator, self).__init__()
        # one layer MLP
        self.discriminator_output = nn.Sequential(
            nn.Linear(h_size + y_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, h, y):
        y_cat = torch.cat((h, y), dim=2)
        l = self.discriminator_output(y_cat)
        return l
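
# Rough usage sketch (added; sizes are hypothetical): in this GAN-style baseline the LSTM hidden
# state h is concatenated with noise n for the generator and with a candidate edge vector y for
# the discriminator.
# gen = Graph_generator_LSTM_output_generator(h_size=64, n_size=16, y_size=40).cuda()
# dis = Graph_generator_LSTM_output_discriminator(h_size=64, y_size=40).cuda()
# h = Variable(torch.rand(8, 1, 64)).cuda()
# n = Variable(torch.rand(8, 1, 16)).cuda()
# y = gen(h, n, temperature=0.5)   # 8 * 1 * 40, values in (0, 1)
# l = dis(h, y)                    # 8 * 1 * 1, probability that the edges are real
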
# GCN basic operation
class GraphConv(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GraphConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
        # self.relu = nn.ReLU()

    def forward(self, x, adj):
        y = torch.matmul(adj, x)
        y = torch.matmul(y, self.weight)
        return y
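
# Rough usage sketch (added; sizes are hypothetical): a single graph-convolution step computes
# adj @ x @ W, so x is batch * num_nodes * input_dim and adj is batch * num_nodes * num_nodes;
# the weight is left to the caller to initialize (the encoder classes below use xavier_uniform).
# conv = GraphConv(input_dim=10, output_dim=16)
# x = Variable(torch.rand(2, 8, 10)).cuda()
# adj = Variable(torch.eye(8).view(1, 8, 8).repeat(2, 1, 1)).cuda()
# y = conv(x, adj)   # 2 * 8 * 16
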
  780. # vanilla GCN encoder
  781. class GCN_encoder(nn.Module):
  782. def __init__(self, input_dim, hidden_dim, output_dim):
  783. super(GCN_encoder, self).__init__()
  784. self.conv1 = GraphConv(input_dim=input_dim, output_dim=hidden_dim)
  785. self.conv2 = GraphConv(input_dim=hidden_dim, output_dim=output_dim)
  786. # self.bn1 = nn.BatchNorm1d(output_dim)
  787. # self.bn2 = nn.BatchNorm1d(output_dim)
  788. self.relu = nn.ReLU()
  789. for m in self.modules():
  790. if isinstance(m, GraphConv):
  791. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  792. # init_range = np.sqrt(6.0 / (m.input_dim + m.output_dim))
  793. # m.weight.data = torch.rand([m.input_dim, m.output_dim]).cuda()*init_range
  794. # print('find!')
  795. elif isinstance(m, nn.BatchNorm1d):
  796. m.weight.data.fill_(1)
  797. m.bias.data.zero_()
  798. def forward(self, x, adj):
  799. x = self.conv1(x, adj)
  800. # x = x/torch.sum(x, dim=2, keepdim=True)
  801. x = self.relu(x)
  802. # x = self.bn1(x)
  803. x = self.conv2(x, adj)
  804. # x = x / torch.sum(x, dim=2, keepdim=True)
  805. return x
  806. # vanilla GCN decoder
  807. class GCN_decoder(nn.Module):
  808. def __init__(self):
  809. super(GCN_decoder, self).__init__()
  810. # self.act = nn.Sigmoid()
  811. def forward(self, x):
  812. # x_t = x.view(-1,x.size(2),x.size(1))
  813. x_t = x.permute(0, 2, 1)
  814. # print('x',x)
  815. # print('x_t',x_t)
  816. y = torch.matmul(x, x_t)
  817. return y
  818. # GCN based graph embedding
  819. # allowing for arbitrary num of nodes
  820. class GCN_encoder_graph(nn.Module):
  821. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  822. super(GCN_encoder_graph, self).__init__()
  823. self.num_layers = num_layers
  824. self.conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim)
  825. # self.conv_hidden1 = GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
  826. # self.conv_hidden2 = GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
  827. self.conv_block = nn.ModuleList(
  828. [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim) for i in range(num_layers)])
  829. self.conv_last = GraphConv(input_dim=hidden_dim, output_dim=output_dim)
  830. self.act = nn.ReLU()
  831. for m in self.modules():
  832. if isinstance(m, GraphConv):
  833. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  834. # init_range = np.sqrt(6.0 / (m.input_dim + m.output_dim))
  835. # m.weight.data = torch.rand([m.input_dim, m.output_dim]).cuda()*init_range
  836. # print('find!')
  837. def forward(self, x, adj):
  838. x = self.conv_first(x, adj)
  839. x = self.act(x)
  840. out_all = []
  841. out, _ = torch.max(x, dim=1, keepdim=True)
  842. out_all.append(out)
  843. for i in range(self.num_layers - 2):
  844. x = self.conv_block[i](x, adj)
  845. x = self.act(x)
  846. out, _ = torch.max(x, dim=1, keepdim=True)
  847. out_all.append(out)
  848. x = self.conv_last(x, adj)
  849. x = self.act(x)
  850. out, _ = torch.max(x, dim=1, keepdim=True)
  851. out_all.append(out)
  852. output = torch.cat(out_all, dim=1)
  853. output = output.permute(1, 0, 2)
  854. # print(out)
  855. return output
  856. # x = Variable(torch.rand(1,8,10)).cuda()
  857. # adj = Variable(torch.rand(1,8,8)).cuda()
  858. # model = GCN_encoder_graph(10,10,10).cuda()
  859. # y = model(x,adj)
  860. # print(y.size())
  861. def preprocess(A):
  862. # Get size of the adjacency matrix
  863. size = A.size(1)
  864. # Get the degrees for each node
  865. degrees = torch.sum(A, dim=2)
  866. # Create diagonal matrix D from the degrees of the nodes
  867. D = Variable(torch.zeros(A.size(0), A.size(1), A.size(2))).cuda()
  868. for i in range(D.size(0)):
  869. D[i, :, :] = torch.diag(torch.pow(degrees[i, :], -0.5))
  870. # Cholesky decomposition of D
  871. # D = np.linalg.cholesky(D)
  872. # Inverse of the Cholesky decomposition of D
  873. # D = np.linalg.inv(D)
  874. # Create an identity matrix of size x size
  875. # Create A hat
  876. # Return A_hat
  877. A_normal = torch.matmul(torch.matmul(D, A), D)
  878. # print(A_normal)
  879. return A_normal
  880. # a sequential GCN model, GCN with n layers
  881. class GCN_generator(nn.Module):
  882. def __init__(self, hidden_dim):
  883. super(GCN_generator, self).__init__()
  884. # todo: add an linear_input module to map the input feature into 'hidden_dim'
  885. self.conv = GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
  886. self.act = nn.ReLU()
  887. # initialize
  888. for m in self.modules():
  889. if isinstance(m, GraphConv):
  890. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  891. def forward(self, x, teacher_force=False, adj_real=None):
  892. # x: batch * node_num * feature
  893. batch_num = x.size(0)
  894. node_num = x.size(1)
  895. adj = Variable(torch.eye(node_num).view(1, node_num, node_num).repeat(batch_num, 1, 1)).cuda()
  896. adj_output = Variable(torch.eye(node_num).view(1, node_num, node_num).repeat(batch_num, 1, 1)).cuda()
  897. # do GCN n times
  898. # todo: try if residual connections are plausible
  899. # todo: add higher order of adj (adj^2, adj^3, ...)
  900. # todo: try if norm everytim is plausible
  901. # first do GCN 1 time to preprocess the raw features
  902. # x_new = self.conv(x, adj)
  903. # x_new = self.act(x_new)
  904. # x = x + x_new
  905. x = self.conv(x, adj)
  906. x = self.act(x)
  907. # x = x / torch.norm(x, p=2, dim=2, keepdim=True)
  908. # then do GCN rest n-1 times
  909. for i in range(1, node_num):
  910. # 1 calc prob of a new edge, output the result in adj_output
  911. x_last = x[:, i:i + 1, :].clone()
  912. x_prev = x[:, 0:i, :].clone()
  913. x_prev = x_prev
  914. x_last = x_last
  915. prob = x_prev @ x_last.permute(0, 2, 1)
  916. adj_output[:, i, 0:i] = prob.permute(0, 2, 1).clone()
  917. adj_output[:, 0:i, i] = prob.clone()
  918. # 2 update adj
  919. if teacher_force:
  920. adj = Variable(torch.eye(node_num).view(1, node_num, node_num).repeat(batch_num, 1, 1)).cuda()
  921. adj[:, 0:i + 1, 0:i + 1] = adj_real[:, 0:i + 1, 0:i + 1].clone()
  922. else:
  923. adj[:, i, 0:i] = prob.permute(0, 2, 1).clone()
  924. adj[:, 0:i, i] = prob.clone()
  925. adj = preprocess(adj)
  926. # print(adj)
  927. # print(adj.min().data[0],adj.max().data[0])
  928. # print(x.min().data[0],x.max().data[0])
  929. # 3 do graph conv, with residual connection
  930. # x_new = self.conv(x, adj)
  931. # x_new = self.act(x_new)
  932. # x = x + x_new
  933. x = self.conv(x, adj)
  934. x = self.act(x)
  935. # x = x / torch.norm(x, p=2, dim=2, keepdim=True)
  936. # one = Variable(torch.ones(adj_output.size(0), adj_output.size(1), adj_output.size(2)) * 1.00).cuda().float()
  937. # two = Variable(torch.ones(adj_output.size(0), adj_output.size(1), adj_output.size(2)) * 2.01).cuda().float()
  938. # adj_output = (adj_output + one) / two
  939. # print(adj_output.max().data[0], adj_output.min().data[0])
  940. return adj_output
  941. # #### test code ####
  942. # print('teacher forcing')
  943. # # print('no teacher forcing')
  944. #
  945. # start = time.time()
  946. # generator = GCN_generator(hidden_dim=4)
  947. # end = time.time()
  948. # print('model build time', end-start)
  949. # for run in range(10):
  950. # for i in [500]:
  951. # for batch in [1,10,100]:
  952. # start = time.time()
  953. # torch.manual_seed(123)
  954. # x = Variable(torch.rand(batch,i,4)).cuda()
  955. # adj = Variable(torch.eye(i).view(1,i,i).repeat(batch,1,1)).cuda()
  956. # # print('x', x)
  957. # # print('adj', adj)
  958. #
  959. # # y = generator(x)
  960. # y = generator(x,True,adj)
  961. # # print('y',y)
  962. # end = time.time()
  963. # print('node num', i, ' batch size',batch, ' run time', end-start)
  964. class CNN_decoder(nn.Module):
  965. def __init__(self, input_size, output_size, stride=2):
  966. super(CNN_decoder, self).__init__()
  967. self.input_size = input_size
  968. self.output_size = output_size
  969. self.relu = nn.ReLU()
  970. self.deconv1_1 = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size / 2),
  971. kernel_size=3, stride=stride)
  972. self.bn1_1 = nn.BatchNorm1d(int(self.input_size / 2))
  973. self.deconv1_2 = nn.ConvTranspose1d(in_channels=int(self.input_size / 2), out_channels=int(self.input_size / 2),
  974. kernel_size=3, stride=stride)
  975. self.bn1_2 = nn.BatchNorm1d(int(self.input_size / 2))
  976. self.deconv1_3 = nn.ConvTranspose1d(in_channels=int(self.input_size / 2), out_channels=int(self.output_size),
  977. kernel_size=3, stride=1, padding=1)
  978. self.deconv2_1 = nn.ConvTranspose1d(in_channels=int(self.input_size / 2), out_channels=int(self.input_size / 4),
  979. kernel_size=3, stride=stride)
  980. self.bn2_1 = nn.BatchNorm1d(int(self.input_size / 4))
  981. self.deconv2_2 = nn.ConvTranspose1d(in_channels=int(self.input_size / 4), out_channels=int(self.input_size / 4),
  982. kernel_size=3, stride=stride)
  983. self.bn2_2 = nn.BatchNorm1d(int(self.input_size / 4))
  984. self.deconv2_3 = nn.ConvTranspose1d(in_channels=int(self.input_size / 4), out_channels=int(self.output_size),
  985. kernel_size=3, stride=1, padding=1)
  986. self.deconv3_1 = nn.ConvTranspose1d(in_channels=int(self.input_size / 4), out_channels=int(self.input_size / 8),
  987. kernel_size=3, stride=stride)
  988. self.bn3_1 = nn.BatchNorm1d(int(self.input_size / 8))
  989. self.deconv3_2 = nn.ConvTranspose1d(in_channels=int(self.input_size / 8), out_channels=int(self.input_size / 8),
  990. kernel_size=3, stride=stride)
  991. self.bn3_2 = nn.BatchNorm1d(int(self.input_size / 8))
  992. self.deconv3_3 = nn.ConvTranspose1d(in_channels=int(self.input_size / 8), out_channels=int(self.output_size),
  993. kernel_size=3, stride=1, padding=1)
  994. for m in self.modules():
  995. if isinstance(m, nn.ConvTranspose1d):
  996. # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  997. # m.weight.dataset.normal_(0, math.sqrt(2. / n))
  998. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  999. elif isinstance(m, nn.BatchNorm1d):
  1000. m.weight.data.fill_(1)
  1001. m.bias.data.zero_()
  1002. def forward(self, x):
  1003. '''
  1004. :param
  1005. x: batch * channel * length
  1006. :return:
  1007. '''
  1008. # hop1
  1009. x = self.deconv1_1(x)
  1010. x = self.bn1_1(x)
  1011. x = self.relu(x)
  1012. # print(x.size())
  1013. x = self.deconv1_2(x)
  1014. x = self.bn1_2(x)
  1015. x = self.relu(x)
  1016. # print(x.size())
  1017. x_hop1 = self.deconv1_3(x)
  1018. # print(x_hop1.size())
  1019. # hop2
  1020. x = self.deconv2_1(x)
  1021. x = self.bn2_1(x)
  1022. x = self.relu(x)
  1023. # print(x.size())
  1024. x = self.deconv2_2(x)
  1025. x = self.bn2_2(x)
  1026. x = self.relu(x)
  1027. x_hop2 = self.deconv2_3(x)
  1028. # print(x_hop2.size())
  1029. # hop3
  1030. x = self.deconv3_1(x)
  1031. x = self.bn3_1(x)
  1032. x = self.relu(x)
  1033. # print(x.size())
  1034. x = self.deconv3_2(x)
  1035. x = self.bn3_2(x)
  1036. x = self.relu(x)
  1037. # print(x.size())
  1038. x_hop3 = self.deconv3_3(x)
  1039. # print(x_hop3.size())
  1040. return x_hop1, x_hop2, x_hop3
  1041. # # reference code for doing residual connections
  1042. # def _make_layer(self, block, planes, blocks, stride=1):
  1043. # downsample = None
  1044. # if stride != 1 or self.inplanes != planes * block.expansion:
  1045. # downsample = nn.Sequential(
  1046. # nn.Conv2d(self.inplanes, planes * block.expansion,
  1047. # kernel_size=1, stride=stride, bias=False),
  1048. # nn.BatchNorm2d(planes * block.expansion),
  1049. # )
  1050. #
  1051. # layers = []
  1052. # layers.append(block(self.inplanes, planes, stride, downsample))
  1053. # self.inplanes = planes * block.expansion
  1054. # for i in range(1, blocks):
  1055. # layers.append(block(self.inplanes, planes))
  1056. #
  1057. # return nn.Sequential(*layers)
  1058. class CNN_decoder_share(nn.Module):
  1059. def __init__(self, input_size, output_size, stride, hops):
  1060. super(CNN_decoder_share, self).__init__()
  1061. self.input_size = input_size
  1062. self.output_size = output_size
  1063. self.hops = hops
  1064. self.relu = nn.ReLU()
  1065. self.deconv = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size),
  1066. kernel_size=3, stride=stride)
  1067. self.bn = nn.BatchNorm1d(int(self.input_size))
  1068. self.deconv_out = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.output_size),
  1069. kernel_size=3, stride=1, padding=1)
  1070. for m in self.modules():
  1071. if isinstance(m, nn.ConvTranspose1d):
  1072. # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  1073. # m.weight.dataset.normal_(0, math.sqrt(2. / n))
  1074. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  1075. elif isinstance(m, nn.BatchNorm1d):
  1076. m.weight.data.fill_(1)
  1077. m.bias.data.zero_()
  1078. def forward(self, x):
  1079. '''
  1080. :param
  1081. x: batch * channel * length
  1082. :return:
  1083. '''
  1084. # hop1
  1085. x = self.deconv(x)
  1086. x = self.bn(x)
  1087. x = self.relu(x)
  1088. # print(x.size())
  1089. x = self.deconv(x)
  1090. x = self.bn(x)
  1091. x = self.relu(x)
  1092. # print(x.size())
  1093. x_hop1 = self.deconv_out(x)
  1094. # print(x_hop1.size())
  1095. # hop2
  1096. x = self.deconv(x)
  1097. x = self.bn(x)
  1098. x = self.relu(x)
  1099. # print(x.size())
  1100. x = self.deconv(x)
  1101. x = self.bn(x)
  1102. x = self.relu(x)
  1103. x_hop2 = self.deconv_out(x)
  1104. # print(x_hop2.size())
  1105. # hop3
  1106. x = self.deconv(x)
  1107. x = self.bn(x)
  1108. x = self.relu(x)
  1109. # print(x.size())
  1110. x = self.deconv(x)
  1111. x = self.bn(x)
  1112. x = self.relu(x)
  1113. # print(x.size())
  1114. x_hop3 = self.deconv_out(x)
  1115. # print(x_hop3.size())
  1116. return x_hop1, x_hop2, x_hop3
  1117. class CNN_decoder_attention(nn.Module):
  1118. def __init__(self, input_size, output_size, stride=2):
  1119. super(CNN_decoder_attention, self).__init__()
  1120. self.input_size = input_size
  1121. self.output_size = output_size
  1122. self.relu = nn.ReLU()
  1123. self.deconv = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size),
  1124. kernel_size=3, stride=stride)
  1125. self.bn = nn.BatchNorm1d(int(self.input_size))
  1126. self.deconv_out = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.output_size),
  1127. kernel_size=3, stride=1, padding=1)
  1128. self.deconv_attention = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size),
  1129. kernel_size=1, stride=1, padding=0)
  1130. self.bn_attention = nn.BatchNorm1d(int(self.input_size))
  1131. self.relu_leaky = nn.LeakyReLU(0.2)
  1132. for m in self.modules():
  1133. if isinstance(m, nn.ConvTranspose1d):
  1134. # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  1135. # m.weight.dataset.normal_(0, math.sqrt(2. / n))
  1136. m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
  1137. elif isinstance(m, nn.BatchNorm1d):
  1138. m.weight.data.fill_(1)
  1139. m.bias.data.zero_()
  1140. def forward(self, x):
  1141. '''
  1142. :param
  1143. x: batch * channel * length
  1144. :return:
  1145. '''
  1146. # hop1
  1147. x = self.deconv(x)
  1148. x = self.bn(x)
  1149. x = self.relu(x)
  1150. x = self.deconv(x)
  1151. x = self.bn(x)
  1152. x = self.relu(x)
  1153. x_hop1 = self.deconv_out(x)
  1154. x_hop1_attention = self.deconv_attention(x)
  1155. # x_hop1_attention = self.bn_attention(x_hop1_attention)
  1156. x_hop1_attention = self.relu(x_hop1_attention)
        x_hop1_attention = torch.matmul(x_hop1_attention,
                                        x_hop1_attention.view(-1, x_hop1_attention.size(2), x_hop1_attention.size(1)))
        # x_hop1_attention_sum = torch.norm(x_hop1_attention, 2, dim=1, keepdim=True)
        # x_hop1_attention = x_hop1_attention/x_hop1_attention_sum
        # print(x_hop1.size())

        # hop2
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_hop2 = self.deconv_out(x)
        x_hop2_attention = self.deconv_attention(x)
        # x_hop2_attention = self.bn_attention(x_hop2_attention)
        x_hop2_attention = self.relu(x_hop2_attention)
        x_hop2_attention = torch.matmul(x_hop2_attention,
                                        x_hop2_attention.view(-1, x_hop2_attention.size(2), x_hop2_attention.size(1)))
        # x_hop2_attention_sum = torch.norm(x_hop2_attention, 2, dim=1, keepdim=True)
        # x_hop2_attention = x_hop2_attention/x_hop2_attention_sum
        # print(x_hop2.size())

        # hop3
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_hop3 = self.deconv_out(x)
        x_hop3_attention = self.deconv_attention(x)
        # x_hop3_attention = self.bn_attention(x_hop3_attention)
        x_hop3_attention = self.relu(x_hop3_attention)
        x_hop3_attention = torch.matmul(x_hop3_attention,
                                        x_hop3_attention.view(-1, x_hop3_attention.size(2), x_hop3_attention.size(1)))
        # x_hop3_attention_sum = torch.norm(x_hop3_attention, 2, dim=1, keepdim=True)
        # x_hop3_attention = x_hop3_attention / x_hop3_attention_sum
        # print(x_hop3.size())

        return x_hop1, x_hop2, x_hop3, x_hop1_attention, x_hop2_attention, x_hop3_attention


#### test code ####
# x = Variable(torch.randn(1, 256, 1)).cuda()
# decoder = CNN_decoder(256, 16).cuda()
# y = decoder(x)
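# minimal usage sketch for CNN_decoder_attention (illustrative shapes, not part of the original file)
# x = Variable(torch.randn(1, 256, 1)).cuda()
# decoder = CNN_decoder_attention(256, 16).cuda()
# x_hop1, x_hop2, x_hop3, att1, att2, att3 = decoder(x)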


class Graphsage_Encoder(nn.Module):
    def __init__(self, feature_size, input_size, layer_num):
        super(Graphsage_Encoder, self).__init__()

        self.linear_projection = nn.Linear(feature_size, input_size)
        self.input_size = input_size

        # linear for hop 3
        self.linear_3_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))
        self.linear_3_1 = nn.Linear(input_size * (2 ** 1), input_size * (2 ** 2))
        self.linear_3_2 = nn.Linear(input_size * (2 ** 2), input_size * (2 ** 3))
        # linear for hop 2
        self.linear_2_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))
        self.linear_2_1 = nn.Linear(input_size * (2 ** 1), input_size * (2 ** 2))
        # linear for hop 1
        self.linear_1_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))
        # linear for hop 0
        self.linear_0_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))

        self.linear = nn.Linear(input_size * (2 + 2 + 4 + 8), input_size * (16))

        self.bn_3_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn_3_1 = nn.BatchNorm1d(self.input_size * (2 ** 2))
        self.bn_3_2 = nn.BatchNorm1d(self.input_size * (2 ** 3))
        self.bn_2_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn_2_1 = nn.BatchNorm1d(self.input_size * (2 ** 2))
        self.bn_1_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn_0_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn = nn.BatchNorm1d(input_size * (16))

        self.relu = nn.ReLU()

        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, nodes_list, nodes_count_list):
        '''
        :param nodes_list: a list in which element n_i is a tensor holding each node's (k-i)-hop neighbours
                           (the first element holds the furthest neighbours), with shape N * num_neighbours * features
        :param nodes_count_list: a list in which each element records how many neighbours belong to each father node
        :return:
        '''
        # 3-hop feature
        # nodes original features to representations
        nodes_list[0] = Variable(nodes_list[0]).cuda()
        nodes_list[0] = self.linear_projection(nodes_list[0])
        nodes_features = self.linear_3_0(nodes_list[0])
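        # BatchNorm1d expects (batch, channel, length) while the features are (batch, neighbours, channel),
        # so the last two dims are swapped via .view() (a reshape, not a transpose) before and after each BN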
        nodes_features = self.bn_3_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_count = nodes_count_list[0]
        # print(nodes_count, nodes_count.size())
        # aggregated representations placeholder, feature dim * 2
        nodes_features_farther = Variable(
            torch.Tensor(nodes_features.size(0), nodes_count.size(1), nodes_features.size(2))).cuda()
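        # mean-pool each father node's neighbours over contiguous slices; slice lengths come from batch
        # element 0 (nodes_count[:, j][0]), so every sample in the batch is assumed to share the same counts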
        i = 0
        for j in range(nodes_count.size(1)):
            # mean pooling for each father node
            # print(nodes_count[:, j][0], type(nodes_count[:, j][0]))
            nodes_features_farther[:, j, :] = torch.mean(nodes_features[:, i:i + int(nodes_count[:, j][0]), :], 1,
                                                         keepdim=False)
            i += int(nodes_count[:, j][0])
        # assign node_features
        nodes_features = nodes_features_farther
        nodes_features = self.linear_3_1(nodes_features)
        nodes_features = self.bn_3_1(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_count = nodes_count_list[1]
        # aggregated representations placeholder, feature dim * 2
        nodes_features_farther = Variable(
            torch.Tensor(nodes_features.size(0), nodes_count.size(1), nodes_features.size(2))).cuda()
        i = 0
        for j in range(nodes_count.size(1)):
            # mean pooling for each father node
            nodes_features_farther[:, j, :] = torch.mean(nodes_features[:, i:i + int(nodes_count[:, j][0]), :], 1,
                                                         keepdim=False)
            i += int(nodes_count[:, j][0])
        # assign node_features
        nodes_features = nodes_features_farther
        # print('nodes_feature', nodes_features.size())
        nodes_features = self.linear_3_2(nodes_features)
        nodes_features = self.bn_3_2(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_features_hop_3 = torch.mean(nodes_features, 1, keepdim=True)
        # print(nodes_features_hop_3.size())

        # 2-hop feature
        # nodes original features to representations
        nodes_list[1] = Variable(nodes_list[1]).cuda()
        nodes_list[1] = self.linear_projection(nodes_list[1])
        nodes_features = self.linear_2_0(nodes_list[1])
        nodes_features = self.bn_2_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_count = nodes_count_list[1]
        # aggregated representations placeholder, feature dim * 2
        nodes_features_farther = Variable(
            torch.Tensor(nodes_features.size(0), nodes_count.size(1), nodes_features.size(2))).cuda()
        i = 0
        for j in range(nodes_count.size(1)):
            # mean pooling for each father node
            nodes_features_farther[:, j, :] = torch.mean(nodes_features[:, i:i + int(nodes_count[:, j][0]), :], 1,
                                                         keepdim=False)
            i += int(nodes_count[:, j][0])
        # assign node_features
        nodes_features = nodes_features_farther
        nodes_features = self.linear_2_1(nodes_features)
        nodes_features = self.bn_2_1(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_features_hop_2 = torch.mean(nodes_features, 1, keepdim=True)
        # print(nodes_features_hop_2.size())

        # 1-hop feature
        # nodes original features to representations
        nodes_list[2] = Variable(nodes_list[2]).cuda()
        nodes_list[2] = self.linear_projection(nodes_list[2])
        nodes_features = self.linear_1_0(nodes_list[2])
        nodes_features = self.bn_1_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_features_hop_1 = torch.mean(nodes_features, 1, keepdim=True)
        # print(nodes_features_hop_1.size())

        # own feature
        nodes_list[3] = Variable(nodes_list[3]).cuda()
        nodes_list[3] = self.linear_projection(nodes_list[3])
        nodes_features = self.linear_0_0(nodes_list[3])
        nodes_features = self.bn_0_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features_hop_0 = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # print(nodes_features_hop_0.size())

        # concatenate
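        # per-hop channel widths are 2, 2, 4 and 8 times input_size (hops 0-3), so the concatenated
        # feature has input_size * 16 channels, matching self.linear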
        nodes_features = torch.cat(
            (nodes_features_hop_0, nodes_features_hop_1, nodes_features_hop_2, nodes_features_hop_3), dim=2)
        nodes_features = self.linear(nodes_features)
        # nodes_features = self.bn(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # print(nodes_features.size())
        return nodes_features
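

#### test code ####
# minimal usage sketch for Graphsage_Encoder (hypothetical toy shapes, not part of the original file);
# assumes a CUDA device and that neighbour counts are shared across the batch
# encoder = Graphsage_Encoder(feature_size=8, input_size=4, layer_num=3).cuda()
# encoder.eval()  # hop-0 features have length 1, so BatchNorm1d needs eval mode at batch size 1
# nodes_list = [torch.randn(1, 8, 8), torch.randn(1, 4, 8), torch.randn(1, 2, 8), torch.randn(1, 1, 8)]
# nodes_count_list = [torch.Tensor([[2, 2, 2, 2]]), torch.Tensor([[2, 2]])]
# y = encoder(nodes_list, nodes_count_list)  # shape: 1 * (input_size * 16) * 1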