
model.py

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from collections import OrderedDict
import math
import numpy as np
import time
def binary_cross_entropy_weight(y_pred, y, has_weight=False, weight_length=1, weight_max=10):
    '''
    :param y_pred: predicted probabilities
    :param y: target
    :param weight_length: how long until the end of sequence shall we add weight
    :param weight_max: the magnitude by which the weight is enhanced
    :return: (optionally weighted) binary cross entropy loss
    '''
    if has_weight:
        weight = torch.ones(y.size(0), y.size(1), y.size(2))
        weight_linear = torch.arange(1, weight_length + 1) / weight_length * weight_max
        weight_linear = weight_linear.view(1, weight_length, 1).repeat(y.size(0), 1, y.size(2))
        weight[:, -1 * weight_length:, :] = weight_linear
        loss = F.binary_cross_entropy(y_pred, y, weight=weight.cuda())
    else:
        loss = F.binary_cross_entropy(y_pred, y)
    return loss
def sample_tensor(y, sample=True, thresh=0.5):
    # do sampling
    if sample:
        y_thresh = Variable(torch.rand(y.size())).cuda()
        y_result = torch.gt(y, y_thresh).float()
    # do max likelihood based on some threshold
    else:
        y_thresh = Variable(torch.ones(y.size()) * thresh).cuda()
        y_result = torch.gt(y, y_thresh).float()
    return y_result
def gumbel_softmax(logits, temperature, eps=1e-9):
    '''
    :param logits: shape: N*L
    :param temperature: softmax temperature
    :param eps: numerical stability constant
    :return: gumbel-softmax sample of the logits
    '''
    # get gumbel noise
    noise = torch.rand(logits.size())
    noise.add_(eps).log_().neg_()
    noise.add_(eps).log_().neg_()
    noise = Variable(noise).cuda()

    x = (logits + noise) / temperature
    x = F.softmax(x)
    return x

# for i in range(10):
#     x = Variable(torch.randn(1, 10)).cuda()
#     y = gumbel_softmax(x, temperature=0.01)
#     print(x)
#     print(y)
#     _, id = y.topk(1)
#     print(id)
def gumbel_sigmoid(logits, temperature):
    '''
    :param logits: unnormalized scores
    :param temperature: sigmoid temperature
    :return: gumbel-sigmoid sample of the logits
    '''
    # get gumbel noise
    noise = torch.rand(logits.size())  # uniform(0,1)
    noise_logistic = torch.log(noise) - torch.log(1 - noise)  # logistic(0,1)
    noise = Variable(noise_logistic).cuda()

    x = (logits + noise) / temperature
    x = F.sigmoid(x)
    return x

# x = Variable(torch.randn(100)).cuda()
# y = gumbel_sigmoid(x, temperature=0.01)
# print(x)
# print(y)
def sample_sigmoid(y, sample, thresh=0.5, sample_time=2):
    '''
    do sampling over unnormalized scores
    :param y: input
    :param sample: Bool
    :param thresh: if not sampling, the threshold
    :param sample_time: how many times we sample; if =1, do a single sample
    :return: sampled result
    '''
    # do sigmoid first
    y = F.sigmoid(y)
    # do sampling
    if sample:
        if sample_time > 1:
            y_result = Variable(torch.rand(y.size(0), y.size(1), y.size(2))).cuda()
            # loop over all batches
            for i in range(y_result.size(0)):
                # do 'multi_sample' times sampling
                for j in range(sample_time):
                    y_thresh = Variable(torch.rand(y.size(1), y.size(2))).cuda()
                    y_result[i] = torch.gt(y[i], y_thresh).float()
                    if (torch.sum(y_result[i]).data > 0).any():
                        break
                    # else:
                    #     print('all zero', j)
        else:
            y_thresh = Variable(torch.rand(y.size(0), y.size(1), y.size(2))).cuda()
            y_result = torch.gt(y, y_thresh).float()
    # do max likelihood based on some threshold
    else:
        y_thresh = Variable(torch.ones(y.size(0), y.size(1), y.size(2)) * thresh).cuda()
        y_result = torch.gt(y, y_thresh).float()
    return y_result
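
# A minimal usage sketch of sample_sigmoid (shapes are illustrative assumptions;
# CUDA is required by the .cuda() calls above):
# scores = Variable(torch.randn(4, 1, 10)).cuda()   # batch * 1 * max_prev_node
# edges = sample_sigmoid(scores, sample=True, sample_time=2)
# print(edges.size())   # expected: 4 x 1 x 10, with 0/1 entries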
def sample_sigmoid_supervised(y_pred, y, current, y_len, sample_time=2):
    '''
    do sampling over unnormalized scores
    :param y_pred: input
    :param y: supervision
    :param current: current time step
    :param y_len: valid sequence length for each batch entry
    :param sample_time: how many times we sample; if =1, do a single sample
    :return: sampled result
    '''
    # do sigmoid first
    y_pred = F.sigmoid(y_pred)
    # do sampling
    y_result = Variable(torch.rand(y_pred.size(0), y_pred.size(1), y_pred.size(2))).cuda()
    # loop over all batches
    for i in range(y_result.size(0)):
        # using supervision
        if current < y_len[i]:
            while True:
                y_thresh = Variable(torch.rand(y_pred.size(1), y_pred.size(2))).cuda()
                y_result[i] = torch.gt(y_pred[i], y_thresh).float()
                # print('current', current)
                # print('y_result', y_result[i].data)
                # print('y', y[i])
                y_diff = y_result[i].data - y[i]
                if (y_diff >= 0).all():
                    break
        # supervision done
        else:
            # do 'multi_sample' times sampling
            for j in range(sample_time):
                y_thresh = Variable(torch.rand(y_pred.size(1), y_pred.size(2))).cuda()
                y_result[i] = torch.gt(y_pred[i], y_thresh).float()
                if (torch.sum(y_result[i]).data > 0).any():
                    break
    return y_result
def sample_sigmoid_supervised_simple(y_pred, y, current, y_len, sample_time=2):
    '''
    do sampling over unnormalized scores
    :param y_pred: input
    :param y: supervision
    :param current: current time step
    :param y_len: valid sequence length for each batch entry
    :param sample_time: how many times we sample; if =1, do a single sample
    :return: sampled result
    '''
    # do sigmoid first
    y_pred = F.sigmoid(y_pred)
    # do sampling
    y_result = Variable(torch.rand(y_pred.size(0), y_pred.size(1), y_pred.size(2))).cuda()
    # loop over all batches
    for i in range(y_result.size(0)):
        # using supervision
        if current < y_len[i]:
            y_result[i] = y[i]
        # supervision done
        else:
            # do 'multi_sample' times sampling
            for j in range(sample_time):
                y_thresh = Variable(torch.rand(y_pred.size(1), y_pred.size(2))).cuda()
                y_result[i] = torch.gt(y_pred[i], y_thresh).float()
                if (torch.sum(y_result[i]).data > 0).any():
                    break
    return y_result
################### current adopted model, LSTM+MLP || LSTM+VAE || LSTM+LSTM (where LSTM can be GRU as well)
#####
# definition of terms
# h: hidden state of LSTM
# y: edge prediction, model output
# n: noise for generator
# l: whether an output is real or not, binary


# plain LSTM model
class LSTM_plain(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, has_input=True, has_output=False, output_size=None):
        super(LSTM_plain, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output

        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        else:
            self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )

        self.relu = nn.ReLU()
        # initialize
        self.hidden = None  # need to initialize before a forward run

        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.25)
            elif 'weight' in name:
                nn.init.xavier_uniform(param, gain=nn.init.calculate_gain('sigmoid'))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda(),
                Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda())

    def forward(self, input_raw, pack=False, input_len=None):
        if self.has_input:
            input = self.input(input_raw)
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            input = pack_padded_sequence(input, input_len, batch_first=True)
        output_raw, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        if self.has_output:
            output_raw = self.output(output_raw)
        # return hidden state at each time step
        return output_raw
# plain GRU model
class GRU_plain(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, has_input=True, has_output=False, output_size=None):
        super(GRU_plain, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.has_input = has_input
        self.has_output = has_output

        if has_input:
            self.input = nn.Linear(input_size, embedding_size)
            self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                              batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        if has_output:
            self.output = nn.Sequential(
                nn.Linear(hidden_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, output_size)
            )

        self.relu = nn.ReLU()
        # initialize
        self.hidden = None  # need to initialize before a forward run

        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant(param, 0.25)
            elif 'weight' in name:
                nn.init.xavier_uniform(param, gain=nn.init.calculate_gain('sigmoid'))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self, batch_size):
        return Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)).cuda()

    def forward(self, input_raw, pack=False, input_len=None):
        if self.has_input:
            input = self.input(input_raw)
            input = self.relu(input)
        else:
            input = input_raw
        if pack:
            input = pack_padded_sequence(input, input_len, batch_first=True)
        output_raw, self.hidden = self.rnn(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        if self.has_output:
            output_raw = self.output(output_raw)
        # return hidden state at each time step
        return output_raw
# a deterministic linear output
class MLP_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_plain, self).__init__()
        self.deterministic_output = nn.Sequential(
            nn.Linear(h_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, y_size)
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        y = self.deterministic_output(h)
        return y
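
# A minimal sketch of how the RNN and MLP heads above could be wired together for the
# "LSTM+MLP" variant named in the section header (sizes are illustrative assumptions,
# not values taken from any training script; CUDA is required):
# rnn = GRU_plain(input_size=10, embedding_size=16, hidden_size=32, num_layers=2,
#                 has_input=True, has_output=False).cuda()
# output = MLP_plain(h_size=32, embedding_size=16, y_size=10).cuda()
# x = Variable(torch.rand(8, 5, 10)).cuda()        # batch * seq_len * max_prev_node
# rnn.hidden = rnn.init_hidden(batch_size=8)
# h = rnn(x)                                       # batch * seq_len * hidden_size
# y_pred = output(h)                               # unnormalized edge scores
# y_sampled = sample_sigmoid(y_pred, sample=True)  # binary edge decisions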
# a deterministic linear output; an additional output token indicates whether the sequence should continue to grow
class MLP_token_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_token_plain, self).__init__()
        self.deterministic_output = nn.Sequential(
            nn.Linear(h_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, y_size)
        )
        self.token_output = nn.Sequential(
            nn.Linear(h_size, embedding_size),
            nn.ReLU(),
            nn.Linear(embedding_size, 1)
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        y = self.deterministic_output(h)
        t = self.token_output(h)
        return y, t
# a VAE output head: a linear output with added (reparameterized) noise
class MLP_VAE_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_VAE_plain, self).__init__()
        self.encode_11 = nn.Linear(h_size, embedding_size)  # mu
        self.encode_12 = nn.Linear(h_size, embedding_size)  # lsgms
        self.decode_1 = nn.Linear(embedding_size, embedding_size)
        self.decode_2 = nn.Linear(embedding_size, y_size)  # make edge prediction (reconstruct)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        # encoder
        z_mu = self.encode_11(h)
        z_lsgms = self.encode_12(h)
        # reparameterize
        z_sgm = z_lsgms.mul(0.5).exp_()
        eps = Variable(torch.randn(z_sgm.size())).cuda()
        z = eps * z_sgm + z_mu
        # decoder
        y = self.decode_1(z)
        y = self.relu(y)
        y = self.decode_2(y)
        return y, z_mu, z_lsgms
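
# The head above returns (y, z_mu, z_lsgms); a VAE-style loss would typically pair a
# reconstruction term with the KL divergence of N(z_mu, exp(z_lsgms)) from N(0, I).
# A hedged sketch only; the KL weighting/normalization here is an assumption, not taken
# from this file:
# vae_head = MLP_VAE_plain(h_size=32, embedding_size=16, y_size=10).cuda()
# h = Variable(torch.rand(8, 5, 32)).cuda()
# y_target = Variable(torch.rand(8, 5, 10)).cuda()   # stand-in ground truth in [0, 1]
# y, z_mu, z_lsgms = vae_head(h)
# loss_recon = binary_cross_entropy_weight(F.sigmoid(y), y_target)
# loss_kl = -0.5 * torch.sum(1 + z_lsgms - z_mu.pow(2) - z_lsgms.exp())
# loss = loss_recon + loss_kl / (8 * 10)             # illustrative normalization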
# a conditional VAE output head: the decoder is additionally conditioned on h
class MLP_VAE_conditional_plain(nn.Module):
    def __init__(self, h_size, embedding_size, y_size):
        super(MLP_VAE_conditional_plain, self).__init__()
        self.encode_11 = nn.Linear(h_size, embedding_size)  # mu
        self.encode_12 = nn.Linear(h_size, embedding_size)  # lsgms
        self.decode_1 = nn.Linear(embedding_size + h_size, embedding_size)
        self.decode_2 = nn.Linear(embedding_size, y_size)  # make edge prediction (reconstruct)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, h):
        # encoder
        z_mu = self.encode_11(h)
        z_lsgms = self.encode_12(h)
        # reparameterize
        z_sgm = z_lsgms.mul(0.5).exp_()
        eps = Variable(torch.randn(z_sgm.size(0), z_sgm.size(1), z_sgm.size(2))).cuda()
        z = eps * z_sgm + z_mu
        # decoder
        y = self.decode_1(torch.cat((h, z), dim=2))
        y = self.relu(y)
        y = self.decode_2(y)
        return y, z_mu, z_lsgms
########### baseline model 1: Learning Deep Generative Models of Graphs
class DGM_graphs(nn.Module):
    def __init__(self, h_size):
        # h_size: node embedding size
        # h_size*2: graph embedding size
        super(DGM_graphs, self).__init__()
        ### all modules used by the model
        ## 1 message passing, 2 times
        self.m_uv_1 = nn.Linear(h_size * 2, h_size * 2)
        self.f_n_1 = nn.GRUCell(h_size * 2, h_size)  # input_size, hidden_size
        self.m_uv_2 = nn.Linear(h_size * 2, h_size * 2)
        self.f_n_2 = nn.GRUCell(h_size * 2, h_size)  # input_size, hidden_size

        ## 2 graph embedding and new node embedding
        # for graph embedding
        self.f_m = nn.Linear(h_size, h_size * 2)
        self.f_gate = nn.Sequential(
            nn.Linear(h_size, 1),
            nn.Sigmoid()
        )
        # for new node embedding
        self.f_m_init = nn.Linear(h_size, h_size * 2)
        self.f_gate_init = nn.Sequential(
            nn.Linear(h_size, 1),
            nn.Sigmoid()
        )
        self.f_init = nn.Linear(h_size * 2, h_size)

        ## 3 f_addnode
        self.f_an = nn.Sequential(
            nn.Linear(h_size * 2, 1),
            nn.Sigmoid()
        )

        ## 4 f_addedge
        self.f_ae = nn.Sequential(
            nn.Linear(h_size * 2, 1),
            nn.Sigmoid()
        )

        ## 5 f_nodes
        self.f_s = nn.Linear(h_size * 2, 1)


def message_passing(node_neighbor, node_embedding, model):
    node_embedding_new = []
    for i in range(len(node_neighbor)):
        neighbor_num = len(node_neighbor[i])
        if neighbor_num > 0:
            node_self = node_embedding[i].expand(neighbor_num, node_embedding[i].size(1))
            node_self_neighbor = torch.cat([node_embedding[j] for j in node_neighbor[i]], dim=0)
            message = torch.sum(model.m_uv_1(torch.cat((node_self, node_self_neighbor), dim=1)), dim=0, keepdim=True)
            node_embedding_new.append(model.f_n_1(message, node_embedding[i]))
        else:
            message_null = Variable(torch.zeros((node_embedding[i].size(0), node_embedding[i].size(1) * 2))).cuda()
            node_embedding_new.append(model.f_n_1(message_null, node_embedding[i]))
    node_embedding = node_embedding_new

    # second round of message passing (note: m_uv_1/f_n_1 are reused here; m_uv_2/f_n_2 defined above are not applied)
    node_embedding_new = []
    for i in range(len(node_neighbor)):
        neighbor_num = len(node_neighbor[i])
        if neighbor_num > 0:
            node_self = node_embedding[i].expand(neighbor_num, node_embedding[i].size(1))
            node_self_neighbor = torch.cat([node_embedding[j] for j in node_neighbor[i]], dim=0)
            message = torch.sum(model.m_uv_1(torch.cat((node_self, node_self_neighbor), dim=1)), dim=0, keepdim=True)
            node_embedding_new.append(model.f_n_1(message, node_embedding[i]))
        else:
            message_null = Variable(torch.zeros((node_embedding[i].size(0), node_embedding[i].size(1) * 2))).cuda()
            node_embedding_new.append(model.f_n_1(message_null, node_embedding[i]))
    return node_embedding_new


def calc_graph_embedding(node_embedding_cat, model):
    node_embedding_graph = model.f_m(node_embedding_cat)
    node_embedding_graph_gate = model.f_gate(node_embedding_cat)
    graph_embedding = torch.sum(torch.mul(node_embedding_graph, node_embedding_graph_gate), dim=0, keepdim=True)
    return graph_embedding


def calc_init_embedding(node_embedding_cat, model):
    node_embedding_init = model.f_m_init(node_embedding_cat)
    node_embedding_init_gate = model.f_gate_init(node_embedding_cat)
    init_embedding = torch.sum(torch.mul(node_embedding_init, node_embedding_init_gate), dim=0, keepdim=True)
    init_embedding = model.f_init(init_embedding)
    return init_embedding
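
# A minimal usage sketch for the helpers above on a toy 3-node path graph (sizes are
# illustrative assumptions; CUDA is required by the .cuda() calls):
# model = DGM_graphs(h_size=16).cuda()
# node_embedding = [Variable(torch.randn(1, 16)).cuda() for _ in range(3)]
# node_neighbor = [[1], [0, 2], [1]]                          # adjacency lists
# node_embedding = message_passing(node_neighbor, node_embedding, model)
# node_embedding_cat = torch.cat(node_embedding, dim=0)       # 3 * 16
# graph_embedding = calc_graph_embedding(node_embedding_cat, model)  # 1 * 32
# init_embedding = calc_init_embedding(node_embedding_cat, model)    # 1 * 16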
################################################## code that is NOT used in the final version #############

# RNN that updates according to graph structure, new proposed model
class Graph_RNN_structure(nn.Module):
    def __init__(self, hidden_size, batch_size, output_size, num_layers, is_dilation=True, is_bn=True):
        super(Graph_RNN_structure, self).__init__()
        ## model configuration
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.output_size = output_size
        self.num_layers = num_layers  # num_layers of cnn_output
        self.is_bn = is_bn

        ## model
        self.relu = nn.ReLU()
        # self.linear_output = nn.Linear(hidden_size, 1)
        # self.linear_output_simple = nn.Linear(hidden_size, output_size)
        # for state transition use only, input is null
        # self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

        # use CNN to produce output prediction
        # self.cnn_output = nn.Sequential(
        #     nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1),
        #     # nn.BatchNorm1d(hidden_size),
        #     nn.ReLU(),
        #     nn.Conv1d(hidden_size, 1, kernel_size=3, dilation=1, padding=1)
        # )

        if is_dilation:
            self.conv_block = nn.ModuleList([nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=2**i, padding=2**i) for i in range(num_layers - 1)])
        else:
            self.conv_block = nn.ModuleList([nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1) for i in range(num_layers - 1)])
        self.bn_block = nn.ModuleList([nn.BatchNorm1d(hidden_size) for i in range(num_layers - 1)])
        self.conv_out = nn.Conv1d(hidden_size, 1, kernel_size=3, dilation=1, padding=1)

        # # use CNN to do state transition
        # self.cnn_transition = nn.Sequential(
        #     nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1),
        #     # nn.BatchNorm1d(hidden_size),
        #     nn.ReLU(),
        #     nn.Conv1d(hidden_size, hidden_size, kernel_size=3, dilation=1, padding=1)
        # )

        # use linear to do transition, same as GCN mean aggregator
        self.linear_transition = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )

        # GRU based output, output a single edge prediction at a time
        # self.gru_output = nn.GRU(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

        # use a list to keep all generated hidden vectors; each hidden has size batch*hidden_dim*1, and the list size keeps expanding
        # when using convolution to compute attention weight, we need to first concat the list into a pytorch variable: batch*hidden_dim*current_num_nodes
        self.hidden_all = []

        ## initialize
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # print('linear')
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
                # print(m.weight.data.size())
            if isinstance(m, nn.Conv1d):
                # print('conv1d')
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
                # print(m.weight.data.size())
            if isinstance(m, nn.BatchNorm1d):
                # print('batchnorm1d')
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                # print(m.weight.data.size())
            if isinstance(m, nn.GRU):
                # print('gru')
                m.weight_ih_l0.data = init.xavier_uniform(m.weight_ih_l0.data,
                                                          gain=nn.init.calculate_gain('sigmoid'))
                m.weight_hh_l0.data = init.xavier_uniform(m.weight_hh_l0.data,
                                                          gain=nn.init.calculate_gain('sigmoid'))
                m.bias_ih_l0.data = torch.ones(m.bias_ih_l0.data.size(0)) * 0.25
                m.bias_hh_l0.data = torch.ones(m.bias_hh_l0.data.size(0)) * 0.25

    def init_hidden(self, len=None):
        if len is None:
            return Variable(torch.ones(self.batch_size, self.hidden_size, 1)).cuda()
        else:
            hidden_list = []
            for i in range(len):
                hidden_list.append(Variable(torch.ones(self.batch_size, self.hidden_size, 1)).cuda())
            return hidden_list

    # only run a single forward step
    def forward(self, x, teacher_forcing, temperature=0.5, bptt=True, bptt_len=20, flexible=True, max_prev_node=100):
        # x: batch*1*self.output_size, the ground truth
        # todo: currently only looks back at self.output_size nodes; try looking back according to the bfs sequence

        # 1 first compute new state
        # print('hidden_all', self.hidden_all[-1*self.output_size:])
        # hidden_all_cat = torch.cat(self.hidden_all[-1*self.output_size:], dim=2)
        # # # add BPTT, detach the first variable
        # if bptt:
        #     self.hidden_all[0] = Variable(self.hidden_all[0].data).cuda()
        hidden_all_cat = torch.cat(self.hidden_all, dim=2)
        # print(hidden_all_cat.size())
        # print('hidden_all_cat', hidden_all_cat.size())
        # att_weight size: batch*1*current_num_nodes
        for i in range(self.num_layers - 1):
            hidden_all_cat = self.conv_block[i](hidden_all_cat)
            if self.is_bn:
                hidden_all_cat = self.bn_block[i](hidden_all_cat)
            hidden_all_cat = self.relu(hidden_all_cat)
        x_pred = self.conv_out(hidden_all_cat)

        # 2 then compute output, using a gru
        # first try the simple version, directly give the edge prediction
        # x_pred = self.linear_output_simple(hidden_new)
        # x_pred = x_pred.view(x_pred.size(0), 1, x_pred.size(1))

        # todo: use a gru version output
        # if sample == False:
        #     # when training: we know the ground truth, input the sequence at once
        #     y_pred, _ = self.gru_output(x, hidden_new.permute(2, 0, 1))
        #     y_pred = self.linear_output(y_pred)
        # else:
        #     # when validating, we need to sample at each time step
        #     y_pred = Variable(torch.zeros(x.size(0), x.size(1), x.size(2))).cuda()
        #     y_pred_long = Variable(torch.zeros(x.size(0), x.size(1), x.size(2))).cuda()
        #     x_step = x[:, 0:1, :]
        #     for i in range(x.size(1)):
        #         y_step, _ = self.gru_output(x_step)
        #         y_step = self.linear_output(y_step)
        #         y_pred[:, i, :] = y_step
        #         y_step = F.sigmoid(y_step)
        #         x_step = sample(y_step, sample=True, thresh=0.45)
        #         y_pred_long[:, i, :] = x_step
        #     pass

        # 3 then update self.hidden_all list
        # i.e., the model will use ground truth to update the new node
        # x_pred_sample = gumbel_sigmoid(x_pred, temperature=temperature)
        x_pred_sample = sample_tensor(F.sigmoid(x_pred), sample=True)
        thresh = 0.5
        x_thresh = Variable(torch.ones(x_pred_sample.size(0), x_pred_sample.size(1), x_pred_sample.size(2)) * thresh).cuda()
        x_pred_sample_long = torch.gt(x_pred_sample, x_thresh).long()
        if teacher_forcing:
            # first mask previous hidden states
            hidden_all_cat_select = hidden_all_cat * x
            x_sum = torch.sum(x, dim=2, keepdim=True).float()
        # i.e., the model will use its own prediction to attend
        else:
            # first mask previous hidden states
            hidden_all_cat_select = hidden_all_cat * x_pred_sample
            x_sum = torch.sum(x_pred_sample_long, dim=2, keepdim=True).float()

        # update hidden vector for new nodes
        hidden_new = torch.sum(hidden_all_cat_select, dim=2, keepdim=True) / x_sum
        hidden_new = self.linear_transition(hidden_new.permute(0, 2, 1))
        hidden_new = hidden_new.permute(0, 2, 1)

        if flexible:
            # use ground truth to maintain the history state
            if teacher_forcing:
                x_id = torch.min(torch.nonzero(torch.squeeze(x.data)))
                self.hidden_all = self.hidden_all[x_id:]
            # use prediction to maintain the history state
            else:
                x_id = torch.min(torch.nonzero(torch.squeeze(x_pred_sample_long.data)))
                start = max(len(self.hidden_all) - max_prev_node + 1, x_id)
                self.hidden_all = self.hidden_all[start:]
        # maintain a fixed-size history state
        else:
            # self.hidden_all.pop(0)
            self.hidden_all = self.hidden_all[1:]
        self.hidden_all.append(hidden_new)

        # 4 return prediction
        # print('x_pred', x_pred)
        # print('x_pred_mean', torch.mean(x_pred))
        # print('x_pred_sample_mean', torch.mean(x_pred_sample))
        return x_pred, x_pred_sample
# batch_size = 8
# output_size = 4
# generator = Graph_RNN_structure(hidden_size=16, batch_size=batch_size, output_size=output_size, num_layers=1).cuda()
# for i in range(4):
#     generator.hidden_all.append(generator.init_hidden())
#
# x = Variable(torch.rand(batch_size, 1, output_size)).cuda()
# x_pred = generator(x, teacher_forcing=True, sample=True)
# print(x_pred)
# current baseline model, generating a graph by lstm
class Graph_generator_LSTM(nn.Module):
    def __init__(self, feature_size, input_size, hidden_size, output_size, batch_size, num_layers):
        super(Graph_generator_LSTM, self).__init__()
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear_input = nn.Linear(feature_size, input_size)
        self.linear_output = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        # initialize
        # self.hidden, self.cell = self.init_hidden()
        self.hidden = self.init_hidden()

        self.lstm.weight_ih_l0.data = init.xavier_uniform(self.lstm.weight_ih_l0.data, gain=nn.init.calculate_gain('sigmoid'))
        self.lstm.weight_hh_l0.data = init.xavier_uniform(self.lstm.weight_hh_l0.data, gain=nn.init.calculate_gain('sigmoid'))
        self.lstm.bias_ih_l0.data = torch.ones(self.lstm.bias_ih_l0.data.size(0)) * 0.25
        self.lstm.bias_hh_l0.data = torch.ones(self.lstm.bias_hh_l0.data.size(0)) * 0.25
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def init_hidden(self):
        return (Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)).cuda(),
                Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)).cuda())

    def forward(self, input_raw, pack=False, len=None):
        input = self.linear_input(input_raw)
        input = self.relu(input)
        if pack:
            input = pack_padded_sequence(input, len, batch_first=True)
        output_raw, self.hidden = self.lstm(input, self.hidden)
        if pack:
            output_raw = pad_packed_sequence(output_raw, batch_first=True)[0]
        output = self.linear_output(output_raw)
        return output
# a simple MLP generator output
class Graph_generator_LSTM_output_generator(nn.Module):
    def __init__(self, h_size, n_size, y_size):
        super(Graph_generator_LSTM_output_generator, self).__init__()
        # one layer MLP
        self.generator_output = nn.Sequential(
            nn.Linear(h_size + n_size, 64),
            nn.ReLU(),
            nn.Linear(64, y_size),
            nn.Sigmoid()
        )

    def forward(self, h, n, temperature):
        y_cat = torch.cat((h, n), dim=2)
        y = self.generator_output(y_cat)
        # y = gumbel_sigmoid(y, temperature=temperature)
        return y
# a simple MLP discriminator
class Graph_generator_LSTM_output_discriminator(nn.Module):
    def __init__(self, h_size, y_size):
        super(Graph_generator_LSTM_output_discriminator, self).__init__()
        # one layer MLP
        self.discriminator_output = nn.Sequential(
            nn.Linear(h_size + y_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, h, y):
        y_cat = torch.cat((h, y), dim=2)
        l = self.discriminator_output(y_cat)
        return l
# GCN basic operation
class GraphConv(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GraphConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
        # self.relu = nn.ReLU()

    def forward(self, x, adj):
        y = torch.matmul(adj, x)
        y = torch.matmul(y, self.weight)
        return y
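
# GraphConv.forward computes adj @ x @ W, i.e. neighbourhood aggregation followed by a
# linear map. A minimal usage sketch (toy sizes are assumptions; CUDA is required since
# the weight above is allocated on the GPU):
# conv = GraphConv(input_dim=4, output_dim=8)
# x = Variable(torch.rand(2, 5, 4)).cuda()      # batch * node_num * feature
# adj = Variable(torch.eye(5).view(1, 5, 5).repeat(2, 1, 1)).cuda()
# y = conv(x, adj)                              # batch * node_num * output_dim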
# vanilla GCN encoder
class GCN_encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN_encoder, self).__init__()
        self.conv1 = GraphConv(input_dim=input_dim, output_dim=hidden_dim)
        self.conv2 = GraphConv(input_dim=hidden_dim, output_dim=output_dim)
        # self.bn1 = nn.BatchNorm1d(output_dim)
        # self.bn2 = nn.BatchNorm1d(output_dim)
        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, GraphConv):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
                # init_range = np.sqrt(6.0 / (m.input_dim + m.output_dim))
                # m.weight.data = torch.rand([m.input_dim, m.output_dim]).cuda()*init_range
                # print('find!')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x, adj):
        x = self.conv1(x, adj)
        # x = x / torch.sum(x, dim=2, keepdim=True)
        x = self.relu(x)
        # x = self.bn1(x)
        x = self.conv2(x, adj)
        # x = x / torch.sum(x, dim=2, keepdim=True)
        return x
# vanilla GCN decoder
class GCN_decoder(nn.Module):
    def __init__(self):
        super(GCN_decoder, self).__init__()
        # self.act = nn.Sigmoid()

    def forward(self, x):
        # x_t = x.view(-1, x.size(2), x.size(1))
        x_t = x.permute(0, 2, 1)
        # print('x', x)
        # print('x_t', x_t)
        y = torch.matmul(x, x_t)
        return y
# GCN based graph embedding
# allowing for arbitrary num of nodes
class GCN_encoder_graph(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GCN_encoder_graph, self).__init__()
        self.num_layers = num_layers
        self.conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim)
        # self.conv_hidden1 = GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
        # self.conv_hidden2 = GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
        self.conv_block = nn.ModuleList([GraphConv(input_dim=hidden_dim, output_dim=hidden_dim) for i in range(num_layers)])
        self.conv_last = GraphConv(input_dim=hidden_dim, output_dim=output_dim)
        self.act = nn.ReLU()
        for m in self.modules():
            if isinstance(m, GraphConv):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
                # init_range = np.sqrt(6.0 / (m.input_dim + m.output_dim))
                # m.weight.data = torch.rand([m.input_dim, m.output_dim]).cuda()*init_range
                # print('find!')

    def forward(self, x, adj):
        x = self.conv_first(x, adj)
        x = self.act(x)
        out_all = []
        out, _ = torch.max(x, dim=1, keepdim=True)
        out_all.append(out)
        for i in range(self.num_layers - 2):
            x = self.conv_block[i](x, adj)
            x = self.act(x)
            out, _ = torch.max(x, dim=1, keepdim=True)
            out_all.append(out)
        x = self.conv_last(x, adj)
        x = self.act(x)
        out, _ = torch.max(x, dim=1, keepdim=True)
        out_all.append(out)
        output = torch.cat(out_all, dim=1)
        output = output.permute(1, 0, 2)
        # print(out)
        return output

# x = Variable(torch.rand(1, 8, 10)).cuda()
# adj = Variable(torch.rand(1, 8, 8)).cuda()
# model = GCN_encoder_graph(10, 10, 10).cuda()
# y = model(x, adj)
# print(y.size())
def preprocess(A):
    # Get size of the adjacency matrix
    size = A.size(1)
    # Get the degrees for each node
    degrees = torch.sum(A, dim=2)

    # Create diagonal matrix D from the degrees of the nodes
    D = Variable(torch.zeros(A.size(0), A.size(1), A.size(2))).cuda()
    for i in range(D.size(0)):
        D[i, :, :] = torch.diag(torch.pow(degrees[i, :], -0.5))
    # Cholesky decomposition of D
    # D = np.linalg.cholesky(D)
    # Inverse of the Cholesky decomposition of D
    # D = np.linalg.inv(D)
    # Create an identity matrix of size x size
    # Create A hat
    # Return A_hat
    A_normal = torch.matmul(torch.matmul(D, A), D)
    # print(A_normal)
    return A_normal
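
# preprocess returns the symmetrically normalized adjacency D^(-1/2) * A * D^(-1/2).
# A hedged usage sketch (toy sizes are assumptions; A should contain self-loops so no
# node has degree 0, otherwise torch.pow(degrees, -0.5) produces inf):
# A = torch.eye(4)
# A[0, 1] = 1
# A[1, 0] = 1
# A = Variable(A.view(1, 4, 4).repeat(2, 1, 1)).cuda()
# A_hat = preprocess(A)   # batch * node_num * node_num, rescaled by node degrees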
# a sequential GCN model, GCN with n layers
class GCN_generator(nn.Module):
    def __init__(self, hidden_dim):
        super(GCN_generator, self).__init__()
        # todo: add a linear_input module to map the input feature into 'hidden_dim'
        self.conv = GraphConv(input_dim=hidden_dim, output_dim=hidden_dim)
        self.act = nn.ReLU()
        # initialize
        for m in self.modules():
            if isinstance(m, GraphConv):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))

    def forward(self, x, teacher_force=False, adj_real=None):
        # x: batch * node_num * feature
        batch_num = x.size(0)
        node_num = x.size(1)
        adj = Variable(torch.eye(node_num).view(1, node_num, node_num).repeat(batch_num, 1, 1)).cuda()
        adj_output = Variable(torch.eye(node_num).view(1, node_num, node_num).repeat(batch_num, 1, 1)).cuda()

        # do GCN n times
        # todo: try whether residual connections are plausible
        # todo: add higher orders of adj (adj^2, adj^3, ...)
        # todo: try whether normalizing every time is plausible

        # first do GCN 1 time to preprocess the raw features
        # x_new = self.conv(x, adj)
        # x_new = self.act(x_new)
        # x = x + x_new
        x = self.conv(x, adj)
        x = self.act(x)
        # x = x / torch.norm(x, p=2, dim=2, keepdim=True)
        # then do GCN the remaining n-1 times
        for i in range(1, node_num):
            # 1 calc prob of a new edge, output the result in adj_output
            x_last = x[:, i:i + 1, :].clone()
            x_prev = x[:, 0:i, :].clone()
            prob = x_prev @ x_last.permute(0, 2, 1)
            adj_output[:, i, 0:i] = prob.permute(0, 2, 1).clone()
            adj_output[:, 0:i, i] = prob.clone()
            # 2 update adj
            if teacher_force:
                adj = Variable(torch.eye(node_num).view(1, node_num, node_num).repeat(batch_num, 1, 1)).cuda()
                adj[:, 0:i + 1, 0:i + 1] = adj_real[:, 0:i + 1, 0:i + 1].clone()
            else:
                adj[:, i, 0:i] = prob.permute(0, 2, 1).clone()
                adj[:, 0:i, i] = prob.clone()
            adj = preprocess(adj)
            # print(adj)
            # print(adj.min().data[0], adj.max().data[0])
            # print(x.min().data[0], x.max().data[0])
            # 3 do graph conv, with residual connection
            # x_new = self.conv(x, adj)
            # x_new = self.act(x_new)
            # x = x + x_new
            x = self.conv(x, adj)
            x = self.act(x)
            # x = x / torch.norm(x, p=2, dim=2, keepdim=True)
        # one = Variable(torch.ones(adj_output.size(0), adj_output.size(1), adj_output.size(2)) * 1.00).cuda().float()
        # two = Variable(torch.ones(adj_output.size(0), adj_output.size(1), adj_output.size(2)) * 2.01).cuda().float()
        # adj_output = (adj_output + one) / two
        # print(adj_output.max().data[0], adj_output.min().data[0])
        return adj_output
# #### test code ####
# print('teacher forcing')
# # print('no teacher forcing')
#
# start = time.time()
# generator = GCN_generator(hidden_dim=4)
# end = time.time()
# print('model build time', end - start)
# for run in range(10):
#     for i in [500]:
#         for batch in [1, 10, 100]:
#             start = time.time()
#             torch.manual_seed(123)
#             x = Variable(torch.rand(batch, i, 4)).cuda()
#             adj = Variable(torch.eye(i).view(1, i, i).repeat(batch, 1, 1)).cuda()
#             # print('x', x)
#             # print('adj', adj)
#
#             # y = generator(x)
#             y = generator(x, True, adj)
#             # print('y', y)
#             end = time.time()
#             print('node num', i, ' batch size', batch, ' run time', end - start)
class CNN_decoder(nn.Module):
    def __init__(self, input_size, output_size, stride=2):
        super(CNN_decoder, self).__init__()

        self.input_size = input_size
        self.output_size = output_size

        self.relu = nn.ReLU()
        self.deconv1_1 = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size / 2), kernel_size=3, stride=stride)
        self.bn1_1 = nn.BatchNorm1d(int(self.input_size / 2))
        self.deconv1_2 = nn.ConvTranspose1d(in_channels=int(self.input_size / 2), out_channels=int(self.input_size / 2), kernel_size=3, stride=stride)
        self.bn1_2 = nn.BatchNorm1d(int(self.input_size / 2))
        self.deconv1_3 = nn.ConvTranspose1d(in_channels=int(self.input_size / 2), out_channels=int(self.output_size), kernel_size=3, stride=1, padding=1)

        self.deconv2_1 = nn.ConvTranspose1d(in_channels=int(self.input_size / 2), out_channels=int(self.input_size / 4), kernel_size=3, stride=stride)
        self.bn2_1 = nn.BatchNorm1d(int(self.input_size / 4))
        self.deconv2_2 = nn.ConvTranspose1d(in_channels=int(self.input_size / 4), out_channels=int(self.input_size / 4), kernel_size=3, stride=stride)
        self.bn2_2 = nn.BatchNorm1d(int(self.input_size / 4))
        self.deconv2_3 = nn.ConvTranspose1d(in_channels=int(self.input_size / 4), out_channels=int(self.output_size), kernel_size=3, stride=1, padding=1)

        self.deconv3_1 = nn.ConvTranspose1d(in_channels=int(self.input_size / 4), out_channels=int(self.input_size / 8), kernel_size=3, stride=stride)
        self.bn3_1 = nn.BatchNorm1d(int(self.input_size / 8))
        self.deconv3_2 = nn.ConvTranspose1d(in_channels=int(self.input_size / 8), out_channels=int(self.input_size / 8), kernel_size=3, stride=stride)
        self.bn3_2 = nn.BatchNorm1d(int(self.input_size / 8))
        self.deconv3_3 = nn.ConvTranspose1d(in_channels=int(self.input_size / 8), out_channels=int(self.output_size), kernel_size=3, stride=1, padding=1)

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose1d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.dataset.normal_(0, math.sqrt(2. / n))
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        '''
        :param x: batch * channel * length
        :return:
        '''
        # hop1
        x = self.deconv1_1(x)
        x = self.bn1_1(x)
        x = self.relu(x)
        # print(x.size())
        x = self.deconv1_2(x)
        x = self.bn1_2(x)
        x = self.relu(x)
        # print(x.size())
        x_hop1 = self.deconv1_3(x)
        # print(x_hop1.size())

        # hop2
        x = self.deconv2_1(x)
        x = self.bn2_1(x)
        x = self.relu(x)
        # print(x.size())
        x = self.deconv2_2(x)
        x = self.bn2_2(x)
        x = self.relu(x)
        x_hop2 = self.deconv2_3(x)
        # print(x_hop2.size())

        # hop3
        x = self.deconv3_1(x)
        x = self.bn3_1(x)
        x = self.relu(x)
        # print(x.size())
        x = self.deconv3_2(x)
        x = self.bn3_2(x)
        x = self.relu(x)
        # print(x.size())
        x_hop3 = self.deconv3_3(x)
        # print(x_hop3.size())

        return x_hop1, x_hop2, x_hop3

# # reference code for doing residual connections
# def _make_layer(self, block, planes, blocks, stride=1):
#     downsample = None
#     if stride != 1 or self.inplanes != planes * block.expansion:
#         downsample = nn.Sequential(
#             nn.Conv2d(self.inplanes, planes * block.expansion,
#                       kernel_size=1, stride=stride, bias=False),
#             nn.BatchNorm2d(planes * block.expansion),
#         )
#
#     layers = []
#     layers.append(block(self.inplanes, planes, stride, downsample))
#     self.inplanes = planes * block.expansion
#     for i in range(1, blocks):
#         layers.append(block(self.inplanes, planes))
#
#     return nn.Sequential(*layers)
class CNN_decoder_share(nn.Module):
    def __init__(self, input_size, output_size, stride, hops):
        super(CNN_decoder_share, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hops = hops

        self.relu = nn.ReLU()
        self.deconv = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size), kernel_size=3, stride=stride)
        self.bn = nn.BatchNorm1d(int(self.input_size))
        self.deconv_out = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.output_size), kernel_size=3, stride=1, padding=1)

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose1d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.dataset.normal_(0, math.sqrt(2. / n))
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        '''
        :param x: batch * channel * length
        :return:
        '''
        # hop1
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        # print(x.size())
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        # print(x.size())
        x_hop1 = self.deconv_out(x)
        # print(x_hop1.size())

        # hop2
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        # print(x.size())
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_hop2 = self.deconv_out(x)
        # print(x_hop2.size())

        # hop3
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        # print(x.size())
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        # print(x.size())
        x_hop3 = self.deconv_out(x)
        # print(x_hop3.size())

        return x_hop1, x_hop2, x_hop3
class CNN_decoder_attention(nn.Module):
    def __init__(self, input_size, output_size, stride=2):
        super(CNN_decoder_attention, self).__init__()

        self.input_size = input_size
        self.output_size = output_size

        self.relu = nn.ReLU()
        self.deconv = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size),
                                         kernel_size=3, stride=stride)
        self.bn = nn.BatchNorm1d(int(self.input_size))
        self.deconv_out = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.output_size),
                                             kernel_size=3, stride=1, padding=1)
        self.deconv_attention = nn.ConvTranspose1d(in_channels=int(self.input_size), out_channels=int(self.input_size),
                                                   kernel_size=1, stride=1, padding=0)
        self.bn_attention = nn.BatchNorm1d(int(self.input_size))
        self.relu_leaky = nn.LeakyReLU(0.2)

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose1d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.dataset.normal_(0, math.sqrt(2. / n))
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        '''
        :param x: batch * channel * length
        :return:
        '''
        # hop1
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_hop1 = self.deconv_out(x)
        x_hop1_attention = self.deconv_attention(x)
        # x_hop1_attention = self.bn_attention(x_hop1_attention)
        x_hop1_attention = self.relu(x_hop1_attention)
        x_hop1_attention = torch.matmul(x_hop1_attention,
                                        x_hop1_attention.view(-1, x_hop1_attention.size(2), x_hop1_attention.size(1)))
        # x_hop1_attention_sum = torch.norm(x_hop1_attention, 2, dim=1, keepdim=True)
        # x_hop1_attention = x_hop1_attention / x_hop1_attention_sum
        # print(x_hop1.size())

        # hop2
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_hop2 = self.deconv_out(x)
        x_hop2_attention = self.deconv_attention(x)
        # x_hop2_attention = self.bn_attention(x_hop2_attention)
        x_hop2_attention = self.relu(x_hop2_attention)
        x_hop2_attention = torch.matmul(x_hop2_attention,
                                        x_hop2_attention.view(-1, x_hop2_attention.size(2), x_hop2_attention.size(1)))
        # x_hop2_attention_sum = torch.norm(x_hop2_attention, 2, dim=1, keepdim=True)
        # x_hop2_attention = x_hop2_attention / x_hop2_attention_sum
        # print(x_hop2.size())

        # hop3
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_hop3 = self.deconv_out(x)
        x_hop3_attention = self.deconv_attention(x)
        # x_hop3_attention = self.bn_attention(x_hop3_attention)
        x_hop3_attention = self.relu(x_hop3_attention)
        x_hop3_attention = torch.matmul(x_hop3_attention,
                                        x_hop3_attention.view(-1, x_hop3_attention.size(2), x_hop3_attention.size(1)))
        # x_hop3_attention_sum = torch.norm(x_hop3_attention, 2, dim=1, keepdim=True)
        # x_hop3_attention = x_hop3_attention / x_hop3_attention_sum
        # print(x_hop3.size())

        return x_hop1, x_hop2, x_hop3, x_hop1_attention, x_hop2_attention, x_hop3_attention
#### test code ####
# x = Variable(torch.randn(1, 256, 1)).cuda()
# decoder = CNN_decoder(256, 16).cuda()
# y = decoder(x)
class Graphsage_Encoder(nn.Module):
    def __init__(self, feature_size, input_size, layer_num):
        super(Graphsage_Encoder, self).__init__()

        self.linear_projection = nn.Linear(feature_size, input_size)
        self.input_size = input_size

        # linear for hop 3
        self.linear_3_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))
        self.linear_3_1 = nn.Linear(input_size * (2 ** 1), input_size * (2 ** 2))
        self.linear_3_2 = nn.Linear(input_size * (2 ** 2), input_size * (2 ** 3))
        # linear for hop 2
        self.linear_2_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))
        self.linear_2_1 = nn.Linear(input_size * (2 ** 1), input_size * (2 ** 2))
        # linear for hop 1
        self.linear_1_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))
        # linear for hop 0
        self.linear_0_0 = nn.Linear(input_size * (2 ** 0), input_size * (2 ** 1))

        self.linear = nn.Linear(input_size * (2 + 2 + 4 + 8), input_size * (16))

        self.bn_3_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn_3_1 = nn.BatchNorm1d(self.input_size * (2 ** 2))
        self.bn_3_2 = nn.BatchNorm1d(self.input_size * (2 ** 3))
        self.bn_2_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn_2_1 = nn.BatchNorm1d(self.input_size * (2 ** 2))
        self.bn_1_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn_0_0 = nn.BatchNorm1d(self.input_size * (2 ** 1))
        self.bn = nn.BatchNorm1d(input_size * (16))

        self.relu = nn.ReLU()
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data = init.xavier_uniform(m.weight.data, gain=nn.init.calculate_gain('relu'))
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, nodes_list, nodes_count_list):
        '''
        :param nodes_list: a list; each element n_i is a tensor holding a node's (k-i)-hop neighbours
                           (the first element is the furthest hop), where n_i = N * num_neighbours * features
        :param nodes_count_list: a list; each element is a list that shows how many neighbours belong to the father node
        :return:
        '''
        # 3-hop feature
        # nodes original features to representations
        nodes_list[0] = Variable(nodes_list[0]).cuda()
        nodes_list[0] = self.linear_projection(nodes_list[0])
        nodes_features = self.linear_3_0(nodes_list[0])
        nodes_features = self.bn_3_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_count = nodes_count_list[0]
        # print(nodes_count, nodes_count.size())
        # aggregated representations placeholder, feature dim * 2
        nodes_features_farther = Variable(torch.Tensor(nodes_features.size(0), nodes_count.size(1), nodes_features.size(2))).cuda()
        i = 0
        for j in range(nodes_count.size(1)):
            # mean pooling for each father node
            # print(nodes_count[:, j][0], type(nodes_count[:, j][0]))
            nodes_features_farther[:, j, :] = torch.mean(nodes_features[:, i:i + int(nodes_count[:, j][0]), :], 1, keepdim=False)
            i += int(nodes_count[:, j][0])
        # assign node_features
        nodes_features = nodes_features_farther

        nodes_features = self.linear_3_1(nodes_features)
        nodes_features = self.bn_3_1(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_count = nodes_count_list[1]
        # aggregated representations placeholder, feature dim * 2
        nodes_features_farther = Variable(torch.Tensor(nodes_features.size(0), nodes_count.size(1), nodes_features.size(2))).cuda()
        i = 0
        for j in range(nodes_count.size(1)):
            # mean pooling for each father node
            nodes_features_farther[:, j, :] = torch.mean(nodes_features[:, i:i + int(nodes_count[:, j][0]), :], 1, keepdim=False)
            i += int(nodes_count[:, j][0])
        # assign node_features
        nodes_features = nodes_features_farther
        # print('nodes_feature', nodes_features.size())
        nodes_features = self.linear_3_2(nodes_features)
        nodes_features = self.bn_3_2(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_features_hop_3 = torch.mean(nodes_features, 1, keepdim=True)
        # print(nodes_features_hop_3.size())

        # 2-hop feature
        # nodes original features to representations
        nodes_list[1] = Variable(nodes_list[1]).cuda()
        nodes_list[1] = self.linear_projection(nodes_list[1])
        nodes_features = self.linear_2_0(nodes_list[1])
        nodes_features = self.bn_2_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_count = nodes_count_list[1]
        # aggregated representations placeholder, feature dim * 2
        nodes_features_farther = Variable(torch.Tensor(nodes_features.size(0), nodes_count.size(1), nodes_features.size(2))).cuda()
        i = 0
        for j in range(nodes_count.size(1)):
            # mean pooling for each father node
            nodes_features_farther[:, j, :] = torch.mean(nodes_features[:, i:i + int(nodes_count[:, j][0]), :], 1, keepdim=False)
            i += int(nodes_count[:, j][0])
        # assign node_features
        nodes_features = nodes_features_farther
        nodes_features = self.linear_2_1(nodes_features)
        nodes_features = self.bn_2_1(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_features_hop_2 = torch.mean(nodes_features, 1, keepdim=True)
        # print(nodes_features_hop_2.size())

        # 1-hop feature
        # nodes original features to representations
        nodes_list[2] = Variable(nodes_list[2]).cuda()
        nodes_list[2] = self.linear_projection(nodes_list[2])
        nodes_features = self.linear_1_0(nodes_list[2])
        nodes_features = self.bn_1_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # nodes_features = self.relu(nodes_features)
        # nodes count from previous hop
        nodes_features_hop_1 = torch.mean(nodes_features, 1, keepdim=True)
        # print(nodes_features_hop_1.size())

        # own feature
        nodes_list[3] = Variable(nodes_list[3]).cuda()
        nodes_list[3] = self.linear_projection(nodes_list[3])
        nodes_features = self.linear_0_0(nodes_list[3])
        nodes_features = self.bn_0_0(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features_hop_0 = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # print(nodes_features_hop_0.size())

        # concatenate
        nodes_features = torch.cat((nodes_features_hop_0, nodes_features_hop_1, nodes_features_hop_2, nodes_features_hop_3), dim=2)
        nodes_features = self.linear(nodes_features)
        # nodes_features = self.bn(nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1)))
        nodes_features = nodes_features.view(-1, nodes_features.size(2), nodes_features.size(1))
        # print(nodes_features.size())
        return nodes_features