You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462
  1. import torch
  2. import torchvision as tv
  3. import torch.nn as nn
  4. from torch.autograd import Variable
  5. import matplotlib.pyplot as plt
  6. from random import shuffle
  7. import networkx as nx
  8. import pickle as pkl
  9. import scipy.sparse as sp
  10. import logging
  11. import random
  12. import shutil
  13. import os
  14. import time
  15. from model import *
  16. from utils import *
  17. # load ENZYMES and PROTEIN and DD dataset
  18. def Graph_load_batch(min_num_nodes = 20, max_num_nodes = 1000, name = 'ENZYMES',node_attributes = True,graph_labels=True):
  19. '''
  20. load many graphs, e.g. enzymes
  21. :return: a list of graphs
  22. '''
  23. print('Loading graph dataset: '+str(name))
  24. G = nx.Graph()
  25. # load data
  26. path = 'dataset/'+name+'/'
  27. data_adj = np.loadtxt(path+name+'_A.txt', delimiter=',').astype(int)
  28. if node_attributes:
  29. data_node_att = np.loadtxt(path+name+'_node_attributes.txt', delimiter=',')
  30. data_node_label = np.loadtxt(path+name+'_node_labels.txt', delimiter=',').astype(int)
  31. data_graph_indicator = np.loadtxt(path+name+'_graph_indicator.txt', delimiter=',').astype(int)
  32. if graph_labels:
  33. data_graph_labels = np.loadtxt(path+name+'_graph_labels.txt', delimiter=',').astype(int)
  34. data_tuple = list(map(tuple, data_adj))
  35. # print(len(data_tuple))
  36. # print(data_tuple[0])
  37. # add edges
  38. G.add_edges_from(data_tuple)
  39. # add node attributes
  40. for i in range(data_node_label.shape[0]):
  41. if node_attributes:
  42. G.add_node(i+1, feature = data_node_att[i])
  43. G.add_node(i+1, label = data_node_label[i])
  44. G.remove_nodes_from(list(nx.isolates(G)))
  45. # print(G.number_of_nodes())
  46. # print(G.number_of_edges())
  47. # split into graphs
  48. graph_num = data_graph_indicator.max()
  49. node_list = np.arange(data_graph_indicator.shape[0])+1
  50. graphs = []
  51. max_nodes = 0
  52. for i in range(graph_num):
  53. # find the nodes for each graph
  54. nodes = node_list[data_graph_indicator==i+1]
  55. G_sub = G.subgraph(nodes)
  56. if graph_labels:
  57. G_sub.graph['label'] = data_graph_labels[i]
  58. # print('nodes', G_sub.number_of_nodes())
  59. # print('edges', G_sub.number_of_edges())
  60. # print('label', G_sub.graph)
  61. if G_sub.number_of_nodes()>=min_num_nodes and G_sub.number_of_nodes()<=max_num_nodes:
  62. graphs.append(G_sub)
  63. if G_sub.number_of_nodes() > max_nodes:
  64. max_nodes = G_sub.number_of_nodes()
  65. # print(G_sub.number_of_nodes(), 'i', i)
  66. # print('Graph dataset name: {}, total graph num: {}'.format(name, len(graphs)))
  67. # logging.warning('Graphs loaded, total num: {}'.format(len(graphs)))
  68. print('Loaded')
  69. return graphs
  70. def test_graph_load_DD():
  71. graphs, max_num_nodes = Graph_load_batch(min_num_nodes=10,name='DD',node_attributes=False,graph_labels=True)
  72. shuffle(graphs)
  73. plt.switch_backend('agg')
  74. plt.hist([len(graphs[i]) for i in range(len(graphs))], bins=100)
  75. plt.savefig('figures/test.png')
  76. plt.close()
  77. row = 4
  78. col = 4
  79. draw_graph_list(graphs[0:row*col], row=row,col=col, fname='figures/test')
  80. print('max num nodes',max_num_nodes)
  81. def parse_index_file(filename):
  82. index = []
  83. for line in open(filename):
  84. index.append(int(line.strip()))
  85. return index
  86. # load cora, citeseer and pubmed dataset
  87. def Graph_load(dataset = 'cora'):
  88. '''
  89. Load a single graph dataset
  90. :param dataset: dataset name
  91. :return:
  92. '''
  93. names = ['x', 'tx', 'allx', 'graph']
  94. objects = []
  95. for i in range(len(names)):
  96. load = pkl.load(open("dataset/ind.{}.{}".format(dataset, names[i]), 'rb'), encoding='latin1')
  97. # print('loaded')
  98. objects.append(load)
  99. # print(load)
  100. x, tx, allx, graph = tuple(objects)
  101. test_idx_reorder = parse_index_file("dataset/ind.{}.test.index".format(dataset))
  102. test_idx_range = np.sort(test_idx_reorder)
  103. if dataset == 'citeseer':
  104. # Fix citeseer dataset (there are some isolated nodes in the graph)
  105. # Find isolated nodes, add them as zero-vecs into the right position
  106. test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
  107. tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
  108. tx_extended[test_idx_range - min(test_idx_range), :] = tx
  109. tx = tx_extended
  110. features = sp.vstack((allx, tx)).tolil()
  111. features[test_idx_reorder, :] = features[test_idx_range, :]
  112. G = nx.from_dict_of_lists(graph)
  113. adj = nx.adjacency_matrix(G)
  114. return adj, features, G
  115. ######### code test ########
  116. # adj, features,G = Graph_load()
  117. # print(adj)
  118. # print(G.number_of_nodes(), G.number_of_edges())
  119. # _,_,G = Graph_load(dataset='citeseer')
  120. # G = max(nx.connected_component_subgraphs(G), key=len)
  121. # G = nx.convert_node_labels_to_integers(G)
  122. #
  123. # count = 0
  124. # max_node = 0
  125. # for i in range(G.number_of_nodes()):
  126. # G_ego = nx.ego_graph(G, i, radius=3)
  127. # # draw_graph(G_ego,prefix='test'+str(i))
  128. # m = G_ego.number_of_nodes()
  129. # if m>max_node:
  130. # max_node = m
  131. # if m>=50:
  132. # print(i, G_ego.number_of_nodes(), G_ego.number_of_edges())
  133. # count += 1
  134. # print('count', count)
  135. # print('max_node', max_node)
  136. def bfs_seq(G, start_id):
  137. '''
  138. get a bfs node sequence
  139. :param G:
  140. :param start_id:
  141. :return:
  142. '''
  143. dictionary = dict(nx.bfs_successors(G, start_id))
  144. start = [start_id]
  145. output = [start_id]
  146. while len(start) > 0:
  147. next = []
  148. while len(start) > 0:
  149. current = start.pop(0)
  150. neighbor = dictionary.get(current)
  151. if neighbor is not None:
  152. #### a wrong example, should not permute here!
  153. # shuffle(neighbor)
  154. next = next + neighbor
  155. output = output + next
  156. start = next
  157. return output
  158. def encode_adj(adj, max_prev_node=10, is_full = False):
  159. '''
  160. :param adj: n*n, rows means time step, while columns are input dimension
  161. :param max_degree: we want to keep row number, but truncate column numbers
  162. :return:
  163. '''
  164. if is_full:
  165. max_prev_node = adj.shape[0]-1
  166. # pick up lower tri
  167. adj = np.tril(adj, k=-1)
  168. n = adj.shape[0]
  169. adj = adj[1:n, 0:n-1]
  170. # use max_prev_node to truncate
  171. # note: now adj is a (n-1)*(n-1) matrix
  172. adj_output = np.zeros((adj.shape[0], max_prev_node))
  173. for i in range(adj.shape[0]):
  174. input_start = max(0, i - max_prev_node + 1)
  175. input_end = i + 1
  176. output_start = max_prev_node + input_start - input_end
  177. output_end = max_prev_node
  178. adj_output[i, output_start:output_end] = adj[i, input_start:input_end]
  179. adj_output[i,:] = adj_output[i,:][::-1] # reverse order
  180. return adj_output
  181. def decode_adj(adj_output):
  182. '''
  183. recover to adj from adj_output
  184. note: here adj_output have shape (n-1)*m
  185. '''
  186. max_prev_node = adj_output.shape[1]
  187. adj = np.zeros((adj_output.shape[0], adj_output.shape[0]))
  188. for i in range(adj_output.shape[0]):
  189. input_start = max(0, i - max_prev_node + 1)
  190. input_end = i + 1
  191. output_start = max_prev_node + max(0, i - max_prev_node + 1) - (i + 1)
  192. output_end = max_prev_node
  193. adj[i, input_start:input_end] = adj_output[i,::-1][output_start:output_end] # reverse order
  194. adj_full = np.zeros((adj_output.shape[0]+1, adj_output.shape[0]+1))
  195. n = adj_full.shape[0]
  196. adj_full[1:n, 0:n-1] = np.tril(adj, 0)
  197. adj_full = adj_full + adj_full.T
  198. return adj_full
  199. def encode_adj_flexible(adj):
  200. '''
  201. return a flexible length of output
  202. note that here there is no loss when encoding/decoding an adj matrix
  203. :param adj: adj matrix
  204. :return:
  205. '''
  206. # pick up lower tri
  207. adj = np.tril(adj, k=-1)
  208. n = adj.shape[0]
  209. adj = adj[1:n, 0:n-1]
  210. adj_output = []
  211. input_start = 0
  212. for i in range(adj.shape[0]):
  213. input_end = i + 1
  214. adj_slice = adj[i, input_start:input_end]
  215. adj_output.append(adj_slice)
  216. non_zero = np.nonzero(adj_slice)[0]
  217. input_start = input_end-len(adj_slice)+np.amin(non_zero)
  218. return adj_output
  219. def decode_adj_flexible(adj_output):
  220. '''
  221. return a flexible length of output
  222. note that here there is no loss when encoding/decoding an adj matrix
  223. :param adj: adj matrix
  224. :return:
  225. '''
  226. adj = np.zeros((len(adj_output), len(adj_output)))
  227. for i in range(len(adj_output)):
  228. output_start = i+1-len(adj_output[i])
  229. output_end = i+1
  230. adj[i, output_start:output_end] = adj_output[i]
  231. adj_full = np.zeros((len(adj_output)+1, len(adj_output)+1))
  232. n = adj_full.shape[0]
  233. adj_full[1:n, 0:n-1] = np.tril(adj, 0)
  234. adj_full = adj_full + adj_full.T
  235. return adj_full
  236. def test_encode_decode_adj():
  237. ######## code test ###########
  238. G = nx.ladder_graph(5)
  239. G = nx.grid_2d_graph(20,20)
  240. G = nx.ladder_graph(200)
  241. G = nx.karate_club_graph()
  242. G = nx.connected_caveman_graph(2,3)
  243. print(G.number_of_nodes())
  244. adj = np.asarray(nx.to_numpy_matrix(G))
  245. G = nx.from_numpy_matrix(adj)
  246. #
  247. start_idx = np.random.randint(adj.shape[0])
  248. x_idx = np.array(bfs_seq(G, start_idx))
  249. adj = adj[np.ix_(x_idx, x_idx)]
  250. print('adj\n',adj)
  251. adj_output = encode_adj(adj,max_prev_node=5)
  252. print('adj_output\n',adj_output)
  253. adj_recover = decode_adj(adj_output,max_prev_node=5)
  254. print('adj_recover\n',adj_recover)
  255. print('error\n',np.amin(adj_recover-adj),np.amax(adj_recover-adj))
  256. adj_output = encode_adj_flexible(adj)
  257. for i in range(len(adj_output)):
  258. print(len(adj_output[i]))
  259. adj_recover = decode_adj_flexible(adj_output)
  260. print(adj_recover)
  261. print(np.amin(adj_recover-adj),np.amax(adj_recover-adj))
  262. def encode_adj_full(adj):
  263. '''
  264. return a n-1*n-1*2 tensor, the first dimension is an adj matrix, the second show if each entry is valid
  265. :param adj: adj matrix
  266. :return:
  267. '''
  268. # pick up lower tri
  269. adj = np.tril(adj, k=-1)
  270. n = adj.shape[0]
  271. adj = adj[1:n, 0:n-1]
  272. adj_output = np.zeros((adj.shape[0],adj.shape[1],2))
  273. adj_len = np.zeros(adj.shape[0])
  274. for i in range(adj.shape[0]):
  275. non_zero = np.nonzero(adj[i,:])[0]
  276. input_start = np.amin(non_zero)
  277. input_end = i + 1
  278. adj_slice = adj[i, input_start:input_end]
  279. # write adj
  280. adj_output[i,0:adj_slice.shape[0],0] = adj_slice[::-1] # put in reverse order
  281. # write stop token (if token is 0, stop)
  282. adj_output[i,0:adj_slice.shape[0],1] = 1 # put in reverse order
  283. # write sequence length
  284. adj_len[i] = adj_slice.shape[0]
  285. return adj_output,adj_len
  286. def decode_adj_full(adj_output):
  287. '''
  288. return an adj according to adj_output
  289. :param
  290. :return:
  291. '''
  292. # pick up lower tri
  293. adj = np.zeros((adj_output.shape[0]+1,adj_output.shape[1]+1))
  294. for i in range(adj_output.shape[0]):
  295. non_zero = np.nonzero(adj_output[i,:,1])[0] # get valid sequence
  296. input_end = np.amax(non_zero)
  297. adj_slice = adj_output[i, 0:input_end+1, 0] # get adj slice
  298. # write adj
  299. output_end = i+1
  300. output_start = i+1-input_end-1
  301. adj[i+1,output_start:output_end] = adj_slice[::-1] # put in reverse order
  302. adj = adj + adj.T
  303. return adj
  304. def test_encode_decode_adj_full():
  305. ########### code test #############
  306. # G = nx.ladder_graph(10)
  307. G = nx.karate_club_graph()
  308. # get bfs adj
  309. adj = np.asarray(nx.to_numpy_matrix(G))
  310. G = nx.from_numpy_matrix(adj)
  311. start_idx = np.random.randint(adj.shape[0])
  312. x_idx = np.array(bfs_seq(G, start_idx))
  313. adj = adj[np.ix_(x_idx, x_idx)]
  314. adj_output, adj_len = encode_adj_full(adj)
  315. print('adj\n',adj)
  316. print('adj_output[0]\n',adj_output[:,:,0])
  317. print('adj_output[1]\n',adj_output[:,:,1])
  318. # print('adj_len\n',adj_len)
  319. adj_recover = decode_adj_full(adj_output)
  320. print('adj_recover\n', adj_recover)
  321. print('error\n',adj_recover-adj)
  322. print('error_sum\n',np.amax(adj_recover-adj), np.amin(adj_recover-adj))
  323. ########## use pytorch dataloader
  324. class Graph_sequence_sampler_pytorch(torch.utils.data.Dataset):
  325. def __init__(self, G_list, max_num_node=None, max_prev_node=None, iteration=20000):
  326. self.adj_all = []
  327. self.len_all = []
  328. for G in G_list:
  329. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  330. self.len_all.append(G.number_of_nodes())
  331. if max_num_node is None:
  332. self.n = max(self.len_all)
  333. else:
  334. self.n = max_num_node
  335. if max_prev_node is None:
  336. print('calculating max previous node, total iteration: {}'.format(iteration))
  337. self.max_prev_node = max(self.calc_max_prev_node(iter=iteration))
  338. print('max previous node: {}'.format(self.max_prev_node))
  339. else:
  340. self.max_prev_node = max_prev_node
  341. # self.max_prev_node = max_prev_node
  342. # # sort Graph in descending order
  343. # len_batch_order = np.argsort(np.array(self.len_all))[::-1]
  344. # self.len_all = [self.len_all[i] for i in len_batch_order]
  345. # self.adj_all = [self.adj_all[i] for i in len_batch_order]
  346. def __len__(self):
  347. return len(self.adj_all)
  348. def __getitem__(self, idx):
  349. adj_copy = self.adj_all[idx].copy()
  350. x_batch = np.zeros((self.n, self.max_prev_node)) # here zeros are padded for small graph
  351. x_batch[0,:] = 1 # the first input token is all ones
  352. y_batch = np.zeros((self.n, self.max_prev_node)) # here zeros are padded for small graph
  353. # generate input x, y pairs
  354. len_batch = adj_copy.shape[0]
  355. x_idx = np.random.permutation(adj_copy.shape[0])
  356. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  357. adj_copy_matrix = np.asmatrix(adj_copy)
  358. G = nx.from_numpy_matrix(adj_copy_matrix)
  359. # then do bfs in the permuted G
  360. start_idx = np.random.randint(adj_copy.shape[0])
  361. x_idx = np.array(bfs_seq(G, start_idx))
  362. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  363. adj_encoded = encode_adj(adj_copy.copy(), max_prev_node=self.max_prev_node)
  364. # get x and y and adj
  365. # for small graph the rest are zero padded
  366. y_batch[0:adj_encoded.shape[0], :] = adj_encoded
  367. x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded
  368. return {'x':x_batch,'y':y_batch, 'len':len_batch}
  369. def calc_max_prev_node(self, iter=20000,topk=10):
  370. max_prev_node = []
  371. for i in range(iter):
  372. if i % (iter / 5) == 0:
  373. print('iter {} times'.format(i))
  374. adj_idx = np.random.randint(len(self.adj_all))
  375. adj_copy = self.adj_all[adj_idx].copy()
  376. # print('Graph size', adj_copy.shape[0])
  377. x_idx = np.random.permutation(adj_copy.shape[0])
  378. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  379. adj_copy_matrix = np.asmatrix(adj_copy)
  380. G = nx.from_numpy_matrix(adj_copy_matrix)
  381. # then do bfs in the permuted G
  382. start_idx = np.random.randint(adj_copy.shape[0])
  383. x_idx = np.array(bfs_seq(G, start_idx))
  384. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  385. # encode adj
  386. adj_encoded = encode_adj_flexible(adj_copy.copy())
  387. max_encoded_len = max([len(adj_encoded[i]) for i in range(len(adj_encoded))])
  388. max_prev_node.append(max_encoded_len)
  389. max_prev_node = sorted(max_prev_node)[-1*topk:]
  390. return max_prev_node
  391. ########## use pytorch dataloader
  392. class Graph_sequence_sampler_pytorch_nobfs(torch.utils.data.Dataset):
  393. def __init__(self, G_list, max_num_node=None):
  394. self.adj_all = []
  395. self.len_all = []
  396. for G in G_list:
  397. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  398. self.len_all.append(G.number_of_nodes())
  399. if max_num_node is None:
  400. self.n = max(self.len_all)
  401. else:
  402. self.n = max_num_node
  403. def __len__(self):
  404. return len(self.adj_all)
  405. def __getitem__(self, idx):
  406. adj_copy = self.adj_all[idx].copy()
  407. x_batch = np.zeros((self.n, self.n-1)) # here zeros are padded for small graph
  408. x_batch[0,:] = 1 # the first input token is all ones
  409. y_batch = np.zeros((self.n, self.n-1)) # here zeros are padded for small graph
  410. # generate input x, y pairs
  411. len_batch = adj_copy.shape[0]
  412. x_idx = np.random.permutation(adj_copy.shape[0])
  413. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  414. adj_encoded = encode_adj(adj_copy.copy(), max_prev_node=self.n-1)
  415. # get x and y and adj
  416. # for small graph the rest are zero padded
  417. y_batch[0:adj_encoded.shape[0], :] = adj_encoded
  418. x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded
  419. return {'x':x_batch,'y':y_batch, 'len':len_batch}
  420. # dataset = Graph_sequence_sampler_pytorch_nobfs(graphs)
  421. # print(dataset[1]['x'])
  422. # print(dataset[1]['y'])
  423. # print(dataset[1]['len'])
  424. ########## use pytorch dataloader
  425. ########## for completion task
  426. class Graph_sequence_sampler_pytorch_nobfs_for_completion(torch.utils.data.Dataset):
  427. def __init__(self, G_list, max_num_node=None):
  428. self.adj_all = []
  429. self.len_all = []
  430. for G in G_list:
  431. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  432. self.len_all.append(G.number_of_nodes())
  433. if max_num_node is None:
  434. self.n = max(self.len_all)
  435. else:
  436. self.n = max_num_node
  437. def __len__(self):
  438. return len(self.adj_all)
  439. def shift_left(self, vector, idx):
  440. shifted_vector = vector
  441. length = max(np.shape(vector))
  442. for i in range(length):
  443. if i >= idx and i < length - 1:
  444. shifted_vector[:, i:i + 1] = vector[:, i + 1:i + 2]
  445. elif i == length - 1:
  446. shifted_vector[:, i:i + 1] = 0
  447. return shifted_vector
  448. def get_graph(self, adj):
  449. '''
  450. get a graph from zero-padded adj
  451. :param adj:
  452. :return:
  453. '''
  454. # remove all zeros rows and columns
  455. adj = adj[~np.all(adj == 0, axis=1)]
  456. adj = adj[:, ~np.all(adj == 0, axis=0)]
  457. adj = np.asmatrix(adj)
  458. G = nx.from_numpy_matrix(adj)
  459. return G
  460. def get_matrix_permutation_for_node_completion(self, input_matrix):
  461. G = self.get_graph(input_matrix)
  462. # graph_show(G)
  463. nodes = G.nodes()
  464. idx = random.randint(0, len(nodes) - 1)
  465. G.remove_node(nodes[idx])
  466. # graph_show(G)
  467. incomplete_matrix = np.asarray(nx.to_numpy_matrix(G))
  468. removed_vector = self.shift_left(input_matrix[idx:idx + 1, :], idx)
  469. incomplete_matrix = np.append(incomplete_matrix, removed_vector[:, :-1], axis=0)
  470. permuted_complete_matrix = np.append(incomplete_matrix, removed_vector.T, axis=1)
  471. # graph_show(get_graph(permuted_complete_matrix))
  472. return permuted_complete_matrix
  473. def __getitem__(self, idx):
  474. adj_copy = self.adj_all[idx].copy()
  475. x_batch = np.zeros((self.n, self.n-1)) # here zeros are padded for small graph
  476. x_batch[0,:] = 1 # the first input token is all ones
  477. y_batch = np.zeros((self.n, self.n-1)) # here zeros are padded for small graph
  478. # generate input x, y pairs
  479. len_batch = adj_copy.shape[0]
  480. x_idx = np.random.permutation(adj_copy.shape[0])
  481. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  482. adj_copy = self.get_matrix_permutation_for_node_completion(adj_copy)
  483. adj_encoded = encode_adj(adj_copy.copy(), max_prev_node=self.n-1)
  484. # get x and y and adj
  485. # for small graph the rest are zero padded
  486. y_batch[0:adj_encoded.shape[0], :] = adj_encoded
  487. x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded
  488. return {'x':x_batch,'y':y_batch, 'len':len_batch}
  489. ########## use pytorch dataloader
  490. class Graph_sequence_sampler_pytorch_canonical(torch.utils.data.Dataset):
  491. def __init__(self, G_list, max_num_node=None, max_prev_node=None, iteration=20000):
  492. self.adj_all = []
  493. self.len_all = []
  494. for G in G_list:
  495. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  496. self.len_all.append(G.number_of_nodes())
  497. if max_num_node is None:
  498. self.n = max(self.len_all)
  499. else:
  500. self.n = max_num_node
  501. if max_prev_node is None:
  502. # print('calculating max previous node, total iteration: {}'.format(iteration))
  503. # self.max_prev_node = max(self.calc_max_prev_node(iter=iteration))
  504. # print('max previous node: {}'.format(self.max_prev_node))
  505. self.max_prev_node = self.n-1
  506. else:
  507. self.max_prev_node = max_prev_node
  508. # self.max_prev_node = max_prev_node
  509. # # sort Graph in descending order
  510. # len_batch_order = np.argsort(np.array(self.len_all))[::-1]
  511. # self.len_all = [self.len_all[i] for i in len_batch_order]
  512. # self.adj_all = [self.adj_all[i] for i in len_batch_order]
  513. def __len__(self):
  514. return len(self.adj_all)
  515. def __getitem__(self, idx):
  516. adj_copy = self.adj_all[idx].copy()
  517. x_batch = np.zeros((self.n, self.max_prev_node)) # here zeros are padded for small graph
  518. x_batch[0,:] = 1 # the first input token is all ones
  519. y_batch = np.zeros((self.n, self.max_prev_node)) # here zeros are padded for small graph
  520. # generate input x, y pairs
  521. len_batch = adj_copy.shape[0]
  522. # adj_copy_matrix = np.asmatrix(adj_copy)
  523. # G = nx.from_numpy_matrix(adj_copy_matrix)
  524. # then do bfs in the permuted G
  525. # start_idx = G.number_of_nodes()-1
  526. # x_idx = np.array(bfs_seq(G, start_idx))
  527. # adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  528. adj_encoded = encode_adj(adj_copy, max_prev_node=self.max_prev_node)
  529. # get x and y and adj
  530. # for small graph the rest are zero padded
  531. y_batch[0:adj_encoded.shape[0], :] = adj_encoded
  532. x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded
  533. return {'x':x_batch,'y':y_batch, 'len':len_batch}
  534. def calc_max_prev_node(self, iter=20000,topk=10):
  535. max_prev_node = []
  536. for i in range(iter):
  537. if i % (iter / 5) == 0:
  538. print('iter {} times'.format(i))
  539. adj_idx = np.random.randint(len(self.adj_all))
  540. adj_copy = self.adj_all[adj_idx].copy()
  541. # print('Graph size', adj_copy.shape[0])
  542. x_idx = np.random.permutation(adj_copy.shape[0])
  543. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  544. adj_copy_matrix = np.asmatrix(adj_copy)
  545. G = nx.from_numpy_matrix(adj_copy_matrix)
  546. # then do bfs in the permuted G
  547. start_idx = np.random.randint(adj_copy.shape[0])
  548. x_idx = np.array(bfs_seq(G, start_idx))
  549. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  550. # encode adj
  551. adj_encoded = encode_adj_flexible(adj_copy.copy())
  552. max_encoded_len = max([len(adj_encoded[i]) for i in range(len(adj_encoded))])
  553. max_prev_node.append(max_encoded_len)
  554. max_prev_node = sorted(max_prev_node)[-1*topk:]
  555. return max_prev_node
  556. ########## use pytorch dataloader
  557. class Graph_sequence_sampler_pytorch_nll(torch.utils.data.Dataset):
  558. def __init__(self, G_list, max_num_node=None, max_prev_node=None, iteration=20000):
  559. self.adj_all = []
  560. self.len_all = []
  561. for G in G_list:
  562. adj = np.asarray(nx.to_numpy_matrix(G))
  563. adj_temp = self.calc_adj(adj)
  564. self.adj_all.extend(adj_temp)
  565. self.len_all.append(G.number_of_nodes())
  566. if max_num_node is None:
  567. self.n = max(self.len_all)
  568. else:
  569. self.n = max_num_node
  570. if max_prev_node is None:
  571. # print('calculating max previous node, total iteration: {}'.format(iteration))
  572. # self.max_prev_node = max(self.calc_max_prev_node(iter=iteration))
  573. # print('max previous node: {}'.format(self.max_prev_node))
  574. self.max_prev_node = self.n-1
  575. else:
  576. self.max_prev_node = max_prev_node
  577. # self.max_prev_node = max_prev_node
  578. # # sort Graph in descending order
  579. # len_batch_order = np.argsort(np.array(self.len_all))[::-1]
  580. # self.len_all = [self.len_all[i] for i in len_batch_order]
  581. # self.adj_all = [self.adj_all[i] for i in len_batch_order]
  582. def __len__(self):
  583. return len(self.adj_all)
  584. def __getitem__(self, idx):
  585. adj_copy = self.adj_all[idx].copy()
  586. x_batch = np.zeros((self.n, self.max_prev_node)) # here zeros are padded for small graph
  587. x_batch[0,:] = 1 # the first input token is all ones
  588. y_batch = np.zeros((self.n, self.max_prev_node)) # here zeros are padded for small graph
  589. # generate input x, y pairs
  590. len_batch = adj_copy.shape[0]
  591. # adj_copy_matrix = np.asmatrix(adj_copy)
  592. # G = nx.from_numpy_matrix(adj_copy_matrix)
  593. # then do bfs in the permuted G
  594. # start_idx = G.number_of_nodes()-1
  595. # x_idx = np.array(bfs_seq(G, start_idx))
  596. # adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  597. adj_encoded = encode_adj(adj_copy, max_prev_node=self.max_prev_node)
  598. # get x and y and adj
  599. # for small graph the rest are zero padded
  600. y_batch[0:adj_encoded.shape[0], :] = adj_encoded
  601. x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded
  602. return {'x':x_batch,'y':y_batch, 'len':len_batch}
  603. def calc_adj(self,adj):
  604. max_iter = 10000
  605. adj_all = [adj]
  606. adj_all_len = 1
  607. i_old = 0
  608. for i in range(max_iter):
  609. adj_copy = adj.copy()
  610. x_idx = np.random.permutation(adj_copy.shape[0])
  611. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  612. adj_copy_matrix = np.asmatrix(adj_copy)
  613. G = nx.from_numpy_matrix(adj_copy_matrix)
  614. # then do bfs in the permuted G
  615. start_idx = np.random.randint(adj_copy.shape[0])
  616. x_idx = np.array(bfs_seq(G, start_idx))
  617. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  618. add_flag = True
  619. for adj_exist in adj_all:
  620. if np.array_equal(adj_exist, adj_copy):
  621. add_flag = False
  622. break
  623. if add_flag:
  624. adj_all.append(adj_copy)
  625. adj_all_len += 1
  626. if adj_all_len % 10 ==0:
  627. print('adj found:',adj_all_len,'iter used',i)
  628. return adj_all
  629. # graphs = [nx.barabasi_albert_graph(20,3)]
  630. # graphs = [nx.grid_2d_graph(4,4)]
  631. # dataset = Graph_sequence_sampler_pytorch_nll(graphs)
  632. ############## below are codes not used in current version
  633. ############## they are based on pytorch default data loader, we should consider reimplement them in current datasets, since they are more efficient
  634. # normal version
  635. class Graph_sequence_sampler_truncate():
  636. '''
  637. the output will truncate according to the max_prev_node
  638. '''
  639. def __init__(self, G_list, max_node_num=25, batch_size=4, max_prev_node = 25):
  640. self.batch_size = batch_size
  641. self.n = max_node_num
  642. self.max_prev_node = max_prev_node
  643. self.adj_all = []
  644. for G in G_list:
  645. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  646. def sample(self):
  647. # batch, length, feature
  648. x_batch = np.zeros((self.batch_size, self.n, self.max_prev_node)) # here zeros are padded for small graph
  649. y_batch = np.zeros((self.batch_size, self.n, self.max_prev_node)) # here zeros are padded for small graph
  650. len_batch = np.zeros(self.batch_size)
  651. # generate input x, y pairs
  652. for i in range(self.batch_size):
  653. # first sample and get a permuted adj
  654. adj_idx = np.random.randint(len(self.adj_all))
  655. adj_copy = self.adj_all[adj_idx].copy()
  656. len_batch[i] = adj_copy.shape[0]
  657. x_idx = np.random.permutation(adj_copy.shape[0])
  658. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  659. adj_copy_matrix = np.asmatrix(adj_copy)
  660. G = nx.from_numpy_matrix(adj_copy_matrix)
  661. # then do bfs in the permuted G
  662. start_idx = np.random.randint(adj_copy.shape[0])
  663. x_idx = np.array(bfs_seq(G, start_idx))
  664. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  665. adj_encoded = encode_adj(adj_copy.copy(), max_prev_node=self.max_prev_node)
  666. # get x and y and adj
  667. # for small graph the rest are zero padded
  668. y_batch[i, 0:adj_encoded.shape[0], :] = adj_encoded
  669. x_batch[i, 1:adj_encoded.shape[0]+1, :] = adj_encoded
  670. # sort in descending order
  671. len_batch_order = np.argsort(len_batch)[::-1]
  672. len_batch = len_batch[len_batch_order]
  673. x_batch = x_batch[len_batch_order,:,:]
  674. y_batch = y_batch[len_batch_order,:,:]
  675. return torch.from_numpy(x_batch).float(), torch.from_numpy(y_batch).float(), len_batch.astype('int').tolist()
  676. def calc_max_prev_node(self,iter):
  677. max_prev_node = []
  678. for i in range(iter):
  679. if i%(iter/10)==0:
  680. print(i)
  681. adj_idx = np.random.randint(len(self.adj_all))
  682. adj_copy = self.adj_all[adj_idx].copy()
  683. # print('Graph size', adj_copy.shape[0])
  684. x_idx = np.random.permutation(adj_copy.shape[0])
  685. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  686. adj_copy_matrix = np.asmatrix(adj_copy)
  687. G = nx.from_numpy_matrix(adj_copy_matrix)
  688. time1 = time.time()
  689. # then do bfs in the permuted G
  690. start_idx = np.random.randint(adj_copy.shape[0])
  691. x_idx = np.array(bfs_seq(G, start_idx))
  692. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  693. # encode adj
  694. adj_encoded = encode_adj_flexible(adj_copy.copy())
  695. max_encoded_len = max([len(adj_encoded[i]) for i in range(len(adj_encoded))])
  696. max_prev_node.append(max_encoded_len)
  697. max_prev_node = sorted(max_prev_node)[-100:]
  698. return max_prev_node
  699. # graphs, max_num_nodes = Graph_load_batch(min_num_nodes=6, name='DD',node_attributes=False)
  700. # dataset = Graph_sequence_sampler_truncate([nx.karate_club_graph()])
  701. # max_prev_nodes = dataset.calc_max_prev_node(iter=10000)
  702. # print(max_prev_nodes)
  703. # x,y,len = dataset.sample()
  704. # print('x',x)
  705. # print('y',y)
  706. # print(len)
  707. # only output y_batch (which is needed in batch version of new model)
  708. class Graph_sequence_sampler_fast():
  709. def __init__(self, G_list, max_node_num=25, batch_size=4, max_prev_node = 25):
  710. self.batch_size = batch_size
  711. self.G_list = G_list
  712. self.n = max_node_num
  713. self.max_prev_node = max_prev_node
  714. self.adj_all = []
  715. for G in G_list:
  716. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  717. def sample(self):
  718. # batch, length, feature
  719. y_batch = np.zeros((self.batch_size, self.n, self.max_prev_node)) # here zeros are padded for small graph
  720. # generate input x, y pairs
  721. for i in range(self.batch_size):
  722. # first sample and get a permuted adj
  723. adj_idx = np.random.randint(len(self.adj_all))
  724. adj_copy = self.adj_all[adj_idx].copy()
  725. # print('graph size',adj_copy.shape[0])
  726. x_idx = np.random.permutation(adj_copy.shape[0])
  727. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  728. adj_copy_matrix = np.asmatrix(adj_copy)
  729. G = nx.from_numpy_matrix(adj_copy_matrix)
  730. # then do bfs in the permuted G
  731. start_idx = np.random.randint(adj_copy.shape[0])
  732. x_idx = np.array(bfs_seq(G, start_idx))
  733. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  734. # get the feature for the permuted G
  735. # dict = nx.bfs_successors(G, start_idx)
  736. # print('dict', dict, 'node num', self.G.number_of_nodes())
  737. # print('x idx', x_idx, 'len', len(x_idx))
  738. # print('adj')
  739. # np.set_printoptions(linewidth=200)
  740. # for print_i in range(adj_copy.shape[0]):
  741. # print(adj_copy[print_i].astype(int))
  742. # adj_before = adj_copy.copy()
  743. # encode adj
  744. adj_encoded = encode_adj(adj_copy.copy(), max_prev_node=self.max_prev_node)
  745. # print('adj encoded')
  746. # np.set_printoptions(linewidth=200)
  747. # for print_i in range(adj_copy.shape[0]):
  748. # print(adj_copy[print_i].astype(int))
  749. # decode adj
  750. # print('adj recover error')
  751. # adj_decode = decode_adj(adj_encoded.copy(), max_prev_node=self.max_prev_node)
  752. # adj_err = adj_decode-adj_copy
  753. # print(np.sum(adj_err))
  754. # if np.sum(adj_err)!=0:
  755. # print(adj_err)
  756. # np.set_printoptions(linewidth=200)
  757. # for print_i in range(adj_err.shape[0]):
  758. # print(adj_err[print_i].astype(int))
  759. # get x and y and adj
  760. # for small graph the rest are zero padded
  761. y_batch[i, 0:adj_encoded.shape[0], :] = adj_encoded
  762. # np.set_printoptions(linewidth=200,precision=3)
  763. # print('y\n')
  764. # for print_i in range(self.y_batch[i,:,:].shape[0]):
  765. # print(self.y_batch[i,:,:][print_i].astype(int))
  766. # print('x\n')
  767. # for print_i in range(self.x_batch[i, :, :].shape[0]):
  768. # print(self.x_batch[i, :, :][print_i].astype(int))
  769. # print('adj\n')
  770. # for print_i in range(self.adj_batch[i, :, :].shape[0]):
  771. # print(self.adj_batch[i, :, :][print_i].astype(int))
  772. # print('adj_norm\n')
  773. # for print_i in range(self.adj_norm_batch[i, :, :].shape[0]):
  774. # print(self.adj_norm_batch[i, :, :][print_i].astype(float))
  775. # print('feature\n')
  776. # for print_i in range(self.feature_batch[i, :, :].shape[0]):
  777. # print(self.feature_batch[i, :, :][print_i].astype(float))
  778. # print('x_batch\n',self.x_batch)
  779. # print('y_batch\n',self.y_batch)
  780. return torch.from_numpy(y_batch).float()
  781. # graphs, max_num_nodes = Graph_load_batch(min_num_nodes=6, name='PROTEINS_full')
  782. # print(max_num_nodes)
  783. # G = nx.ladder_graph(100)
  784. # # G1 = nx.karate_club_graph()
  785. # # G2 = nx.connected_caveman_graph(4,5)
  786. # G_list = [G]
  787. # dataset = Graph_sequence_sampler_fast(graphs, batch_size=128, max_node_num=max_num_nodes, max_prev_node=30)
  788. # for i in range(5):
  789. # time0 = time.time()
  790. # y = dataset.sample()
  791. # time1 = time.time()
  792. # print(i,'time', time1 - time0)
  793. # output size is flexible (using list to represent), batch size is 1
  794. class Graph_sequence_sampler_flexible():
  795. def __init__(self, G_list):
  796. self.G_list = G_list
  797. self.adj_all = []
  798. for G in G_list:
  799. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  800. self.y_batch = []
  801. def sample(self):
  802. # generate input x, y pairs
  803. # first sample and get a permuted adj
  804. adj_idx = np.random.randint(len(self.adj_all))
  805. adj_copy = self.adj_all[adj_idx].copy()
  806. # print('graph size',adj_copy.shape[0])
  807. x_idx = np.random.permutation(adj_copy.shape[0])
  808. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  809. adj_copy_matrix = np.asmatrix(adj_copy)
  810. G = nx.from_numpy_matrix(adj_copy_matrix)
  811. # then do bfs in the permuted G
  812. start_idx = np.random.randint(adj_copy.shape[0])
  813. x_idx = np.array(bfs_seq(G, start_idx))
  814. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  815. # get the feature for the permuted G
  816. # dict = nx.bfs_successors(G, start_idx)
  817. # print('dict', dict, 'node num', self.G.number_of_nodes())
  818. # print('x idx', x_idx, 'len', len(x_idx))
  819. # print('adj')
  820. # np.set_printoptions(linewidth=200)
  821. # for print_i in range(adj_copy.shape[0]):
  822. # print(adj_copy[print_i].astype(int))
  823. # adj_before = adj_copy.copy()
  824. # encode adj
  825. adj_encoded = encode_adj_flexible(adj_copy.copy())
  826. # print('adj encoded')
  827. # np.set_printoptions(linewidth=200)
  828. # for print_i in range(adj_copy.shape[0]):
  829. # print(adj_copy[print_i].astype(int))
  830. # decode adj
  831. # print('adj recover error')
  832. # adj_decode = decode_adj(adj_encoded.copy(), max_prev_node=self.max_prev_node)
  833. # adj_err = adj_decode-adj_copy
  834. # print(np.sum(adj_err))
  835. # if np.sum(adj_err)!=0:
  836. # print(adj_err)
  837. # np.set_printoptions(linewidth=200)
  838. # for print_i in range(adj_err.shape[0]):
  839. # print(adj_err[print_i].astype(int))
  840. # get x and y and adj
  841. # for small graph the rest are zero padded
  842. self.y_batch=adj_encoded
  843. # np.set_printoptions(linewidth=200,precision=3)
  844. # print('y\n')
  845. # for print_i in range(self.y_batch[i,:,:].shape[0]):
  846. # print(self.y_batch[i,:,:][print_i].astype(int))
  847. # print('x\n')
  848. # for print_i in range(self.x_batch[i, :, :].shape[0]):
  849. # print(self.x_batch[i, :, :][print_i].astype(int))
  850. # print('adj\n')
  851. # for print_i in range(self.adj_batch[i, :, :].shape[0]):
  852. # print(self.adj_batch[i, :, :][print_i].astype(int))
  853. # print('adj_norm\n')
  854. # for print_i in range(self.adj_norm_batch[i, :, :].shape[0]):
  855. # print(self.adj_norm_batch[i, :, :][print_i].astype(float))
  856. # print('feature\n')
  857. # for print_i in range(self.feature_batch[i, :, :].shape[0]):
  858. # print(self.feature_batch[i, :, :][print_i].astype(float))
  859. return self.y_batch,adj_copy
  860. # G = nx.ladder_graph(5)
  861. # # G = nx.grid_2d_graph(20,20)
  862. # # G = nx.ladder_graph(200)
  863. # graphs = [G]
  864. #
  865. # graphs, max_num_nodes = Graph_load_batch(min_num_nodes=6, name='ENZYMES')
  866. # sampler = Graph_sequence_sampler_flexible(graphs)
  867. #
  868. # y_max_all = []
  869. # for i in range(10000):
  870. # y_raw,adj_copy = sampler.sample()
  871. # y_max = max(len(y_raw[i]) for i in range(len(y_raw)))
  872. # y_max_all.append(y_max)
  873. # # print('max bfs node',y_max)
  874. # print('max', max(y_max_all))
  875. # print(y[1])
  876. # print(Variable(torch.FloatTensor(y[1])).cuda(CUDA))
  877. ########### potential use: an encoder along with the GraphRNN decoder
  878. # preprocess the adjacency matrix
  879. def preprocess(A):
  880. # Get size of the adjacency matrix
  881. size = len(A)
  882. # Get the degrees for each node
  883. degrees = np.sum(A, axis=1)+1
  884. # Create diagonal matrix D from the degrees of the nodes
  885. D = np.diag(np.power(degrees, -0.5).flatten())
  886. # Cholesky decomposition of D
  887. # D = np.linalg.cholesky(D)
  888. # Inverse of the Cholesky decomposition of D
  889. # D = np.linalg.inv(D)
  890. # Create an identity matrix of size x size
  891. I = np.eye(size)
  892. # Create A hat
  893. A_hat = A + I
  894. # Return A_hat
  895. A_normal = np.dot(np.dot(D,A_hat),D)
  896. return A_normal
# truncate the output sequence to save representation and allow for infinite generation
  898. # now having a list of graphs
  899. class Graph_sequence_sampler_bfs_permute_truncate_multigraph():
  900. def __init__(self, G_list, max_node_num=25, batch_size=4, max_prev_node = 25, feature = None):
  901. self.batch_size = batch_size
  902. self.G_list = G_list
  903. self.n = max_node_num
  904. self.max_prev_node = max_prev_node
  905. self.adj_all = []
  906. for G in G_list:
  907. self.adj_all.append(np.asarray(nx.to_numpy_matrix(G)))
  908. self.has_feature = feature
  909. def sample(self):
  910. # batch, length, feature
  911. # self.x_batch = np.ones((self.batch_size, self.n - 1, self.max_prev_node))
  912. x_batch = np.zeros((self.batch_size, self.n, self.max_prev_node)) # here zeros are padded for small graph
  913. # self.x_batch[:,0,:] = np.ones((self.batch_size, self.max_prev_node)) # first input is all ones
  914. # batch, length, feature
  915. y_batch = np.zeros((self.batch_size, self.n, self.max_prev_node)) # here zeros are padded for small graph
  916. # batch, length, length
  917. adj_batch = np.zeros((self.batch_size, self.n, self.n)) # here zeros are padded for small graph
  918. # batch, size, size
  919. adj_norm_batch = np.zeros((self.batch_size, self.n, self.n)) # here zeros are padded for small graph
  920. # batch, size, feature_len: degree and clustering coefficient
  921. if self.has_feature is None:
  922. feature_batch = np.zeros((self.batch_size, self.n, self.n)) # use one hot feature
  923. else:
  924. feature_batch = np.zeros((self.batch_size, self.n, 2))
  925. # generate input x, y pairs
  926. for i in range(self.batch_size):
  927. time0 = time.time()
  928. # first sample and get a permuted adj
  929. adj_idx = np.random.randint(len(self.adj_all))
  930. adj_copy = self.adj_all[adj_idx].copy()
  931. # print('Graph size', adj_copy.shape[0])
  932. x_idx = np.random.permutation(adj_copy.shape[0])
  933. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  934. adj_copy_matrix = np.asmatrix(adj_copy)
  935. G = nx.from_numpy_matrix(adj_copy_matrix)
  936. time1 = time.time()
  937. # then do bfs in the permuted G
  938. start_idx = np.random.randint(adj_copy.shape[0])
  939. x_idx = np.array(bfs_seq(G, start_idx))
  940. adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
  941. # get the feature for the permuted G
  942. node_list = [G.nodes()[i] for i in x_idx]
  943. feature_degree = np.array(list(G.degree(node_list).values()))[:,np.newaxis]
  944. feature_clustering = np.array(list(nx.clustering(G,nodes=node_list).values()))[:,np.newaxis]
  945. time2 = time.time()
  946. # dict = nx.bfs_successors(G, start_idx)
  947. # print('dict', dict, 'node num', self.G.number_of_nodes())
  948. # print('x idx', x_idx, 'len', len(x_idx))
  949. # print('adj')
  950. # np.set_printoptions(linewidth=200)
  951. # for print_i in range(adj_copy.shape[0]):
  952. # print(adj_copy[print_i].astype(int))
  953. # adj_before = adj_copy.copy()
  954. # encode adj
  955. adj_encoded = encode_adj(adj_copy.copy(), max_prev_node=self.max_prev_node)
  956. # print('adj encoded')
  957. # np.set_printoptions(linewidth=200)
  958. # for print_i in range(adj_copy.shape[0]):
  959. # print(adj_copy[print_i].astype(int))
  960. # decode adj
  961. # print('adj recover error')
  962. # adj_decode = decode_adj(adj_encoded.copy(), max_prev_node=self.max_prev_node)
  963. # adj_err = adj_decode-adj_copy
  964. # print(np.sum(adj_err))
  965. # if np.sum(adj_err)!=0:
  966. # print(adj_err)
  967. # np.set_printoptions(linewidth=200)
  968. # for print_i in range(adj_err.shape[0]):
  969. # print(adj_err[print_i].astype(int))
  970. # get x and y and adj
  971. # for small graph the rest are zero padded
  972. y_batch[i, 0:adj_encoded.shape[0], :] = adj_encoded
  973. x_batch[i, 1:adj_encoded.shape[0]+1, :] = adj_encoded
  974. adj_batch[i, 0:adj_copy.shape[0], 0:adj_copy.shape[0]] = adj_copy
  975. adj_copy_norm = preprocess(adj_copy)
  976. time3 = time.time()
  977. adj_norm_batch[i, 0:adj_copy.shape[0], 0:adj_copy.shape[0]] = adj_copy_norm
  978. if self.has_feature is None:
  979. feature_batch[i, 0:adj_copy.shape[0], 0:adj_copy.shape[0]] = np.eye(adj_copy.shape[0])
  980. else:
  981. feature_batch[i,0:adj_copy.shape[0],:] = np.concatenate((feature_degree,feature_clustering),axis=1)
  982. # np.set_printoptions(linewidth=200,precision=3)
  983. # print('y\n')
  984. # for print_i in range(self.y_batch[i,:,:].shape[0]):
  985. # print(self.y_batch[i,:,:][print_i].astype(int))
  986. # print('x\n')
  987. # for print_i in range(self.x_batch[i, :, :].shape[0]):
  988. # print(self.x_batch[i, :, :][print_i].astype(int))
  989. # print('adj\n')
  990. # for print_i in range(self.adj_batch[i, :, :].shape[0]):
  991. # print(self.adj_batch[i, :, :][print_i].astype(int))
  992. # print('adj_norm\n')
  993. # for print_i in range(self.adj_norm_batch[i, :, :].shape[0]):
  994. # print(self.adj_norm_batch[i, :, :][print_i].astype(float))
  995. # print('feature\n')
  996. # for print_i in range(self.feature_batch[i, :, :].shape[0]):
  997. # print(self.feature_batch[i, :, :][print_i].astype(float))
  998. time4 = time.time()
  999. # print('1 ',time1-time0)
  1000. # print('2 ',time2-time1)
  1001. # print('3 ',time3-time2)
  1002. # print('4 ',time4-time3)
  1003. # print('x_batch\n',self.x_batch)
  1004. # print('y_batch\n',self.y_batch)
  1005. return torch.from_numpy(x_batch).float(), torch.from_numpy(y_batch).float(),\
  1006. torch.from_numpy(adj_batch).float(), torch.from_numpy(adj_norm_batch).float(), torch.from_numpy(feature_batch).float()
  1007. # generate own synthetic dataset
  1008. def Graph_synthetic(seed):
  1009. G = nx.Graph()
  1010. np.random.seed(seed)
  1011. base = np.repeat(np.eye(5), 20, axis=0)
  1012. rand = np.random.randn(100, 5) * 0.05
  1013. node_features = base + rand
  1014. # # print('node features')
  1015. # for i in range(node_features.shape[0]):
  1016. # print(np.around(node_features[i], decimals=4))
  1017. node_distance_l1 = np.ones((node_features.shape[0], node_features.shape[0]))
  1018. node_distance_np = np.zeros((node_features.shape[0], node_features.shape[0]))
  1019. for i in range(node_features.shape[0]):
  1020. for j in range(node_features.shape[0]):
  1021. if i != j:
  1022. node_distance_l1[i,j] = np.sum(np.abs(node_features[i] - node_features[j]))
  1023. # print('node distance', node_distance_l1[i,j])
  1024. node_distance_np[i, j] = 1 / np.sum(np.abs(node_features[i] - node_features[j]) ** 2)
  1025. print('node distance max', np.max(node_distance_l1))
  1026. print('node distance min', np.min(node_distance_l1))
  1027. node_distance_np_sum = np.sum(node_distance_np, axis=1, keepdims=True)
  1028. embedding_dist = node_distance_np / node_distance_np_sum
  1029. # generate the graph
  1030. average_degree = 9
  1031. for i in range(node_features.shape[0]):
  1032. for j in range(i + 1, embedding_dist.shape[0]):
  1033. p = np.random.rand()
  1034. if p < embedding_dist[i, j] * average_degree:
  1035. G.add_edge(i, j)
  1036. G.remove_nodes_from(nx.isolates(G))
  1037. print('num of nodes', G.number_of_nodes())
  1038. print('num of edges', G.number_of_edges())
  1039. G_deg = nx.degree_histogram(G)
  1040. G_deg_sum = [a * b for a, b in zip(G_deg, range(0, len(G_deg)))]
  1041. print('average degree', sum(G_deg_sum) / G.number_of_nodes())
  1042. print('average path length', nx.average_shortest_path_length(G))
  1043. print('diameter', nx.diameter(G))
  1044. G_cluster = sorted(list(nx.clustering(G).values()))
  1045. print('average clustering coefficient', sum(G_cluster) / len(G_cluster))
  1046. print('Graph generation complete!')
  1047. # node_features = np.concatenate((node_features, np.zeros((1,node_features.shape[1]))),axis=0)
  1048. return G, node_features
  1049. # G = Graph_synthetic(10)
  1050. # return adj and features from a single graph
  1051. class GraphDataset_adj(torch.utils.data.Dataset):
  1052. """Graph Dataset"""
  1053. def __init__(self, G, features=None):
  1054. self.G = G
  1055. self.n = G.number_of_nodes()
  1056. adj = np.asarray(nx.to_numpy_matrix(self.G))
  1057. # permute adj
  1058. subgraph_idx = np.random.permutation(self.n)
  1059. # subgraph_idx = np.arange(self.n)
  1060. adj = adj[np.ix_(subgraph_idx, subgraph_idx)]
  1061. self.adj = torch.from_numpy(adj+np.eye(len(adj))).float()
  1062. self.adj_norm = torch.from_numpy(preprocess(adj)).float()
  1063. if features is None:
  1064. self.features = torch.Tensor(self.n, self.n)
  1065. self.features = nn.init.eye(self.features)
  1066. else:
  1067. features = features[subgraph_idx,:]
  1068. self.features = torch.from_numpy(features).float()
  1069. print('embedding size', self.features.size())
  1070. def __len__(self):
  1071. return 1
  1072. def __getitem__(self, idx):
  1073. sample = {'adj':self.adj,'adj_norm':self.adj_norm, 'features':self.features}
  1074. return sample
  1075. # G = nx.karate_club_graph()
  1076. # dataset = GraphDataset_adj(G)
  1077. # train_loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=1)
  1078. # for data in train_loader:
  1079. # print(data)
  1080. # return adj and features from a list of graphs
  1081. class GraphDataset_adj_batch(torch.utils.data.Dataset):
  1082. """Graph Dataset"""
  1083. def __init__(self, graphs, has_feature = True, num_nodes = 20):
  1084. self.graphs = graphs
  1085. self.has_feature = has_feature
  1086. self.num_nodes = num_nodes
  1087. def __len__(self):
  1088. return len(self.graphs)
  1089. def __getitem__(self, idx):
  1090. adj_raw = np.asarray(nx.to_numpy_matrix(self.graphs[idx]))
  1091. np.fill_diagonal(adj_raw,0) # in case the self connection already exists
  1092. # sample num_nodes size subgraph
  1093. subgraph_idx = np.random.permutation(adj_raw.shape[0])[0:self.num_nodes]
  1094. adj_raw = adj_raw[np.ix_(subgraph_idx,subgraph_idx)]
  1095. adj = torch.from_numpy(adj_raw+np.eye(len(adj_raw))).float()
  1096. adj_norm = torch.from_numpy(preprocess(adj_raw)).float()
  1097. adj_raw = torch.from_numpy(adj_raw).float()
  1098. if self.has_feature:
  1099. dictionary = nx.get_node_attributes(self.graphs[idx], 'feature')
  1100. features = np.zeros((self.num_nodes, list(dictionary.values())[0].shape[0]))
  1101. for i in range(self.num_nodes):
  1102. features[i, :] = list(dictionary.values())[subgraph_idx[i]]
  1103. # normalize
  1104. features -= np.mean(features, axis=0)
  1105. epsilon = 1e-6
  1106. features /= (np.std(features, axis=0)+epsilon)
  1107. features = torch.from_numpy(features).float()
  1108. else:
  1109. n = self.num_nodes
  1110. features = torch.Tensor(n, n)
  1111. features = nn.init.eye(features)
  1112. sample = {'adj':adj,'adj_norm':adj_norm, 'features':features, 'adj_raw':adj_raw}
  1113. return sample
  1114. # return adj and features from a list of graphs, batch size = 1, so that graphs can have various size each time
  1115. class GraphDataset_adj_batch_1(torch.utils.data.Dataset):
  1116. """Graph Dataset"""
  1117. def __init__(self, graphs, has_feature=True):
  1118. self.graphs = graphs
  1119. self.has_feature = has_feature
  1120. def __len__(self):
  1121. return len(self.graphs)
  1122. def __getitem__(self, idx):
  1123. adj_raw = np.asarray(nx.to_numpy_matrix(self.graphs[idx]))
  1124. np.fill_diagonal(adj_raw, 0) # in case the self connection already exists
  1125. n = adj_raw.shape[0]
  1126. # give a permutation
  1127. subgraph_idx = np.random.permutation(n)
  1128. # subgraph_idx = np.arange(n)
  1129. adj_raw = adj_raw[np.ix_(subgraph_idx, subgraph_idx)]
  1130. adj = torch.from_numpy(adj_raw + np.eye(len(adj_raw))).float()
  1131. adj_norm = torch.from_numpy(preprocess(adj_raw)).float()
  1132. if self.has_feature:
  1133. dictionary = nx.get_node_attributes(self.graphs[idx], 'feature')
  1134. features = np.zeros((n, list(dictionary.values())[0].shape[0]))
  1135. for i in range(n):
  1136. features[i, :] = list(dictionary.values())[i]
  1137. features = features[subgraph_idx, :]
  1138. # normalize
  1139. features -= np.mean(features, axis=0)
  1140. epsilon = 1e-6
  1141. features /= (np.std(features, axis=0) + epsilon)
  1142. features = torch.from_numpy(features).float()
  1143. else:
  1144. features = torch.Tensor(n, n)
  1145. features = nn.init.eye(features)
  1146. sample = {'adj': adj, 'adj_norm': adj_norm, 'features': features}
  1147. return sample
  1148. # get one node at a time, for a single graph
  1149. class GraphDataset(torch.utils.data.Dataset):
  1150. """Graph Dataset"""
  1151. def __init__(self, G, hops = 1, max_degree = 5, vocab_size = 35, embedding_dim = 35, embedding = None, shuffle_neighbour = True):
  1152. self.G = G
  1153. self.shuffle_neighbour = shuffle_neighbour
  1154. self.hops = hops
  1155. self.max_degree = max_degree
  1156. if embedding is None:
  1157. self.embedding = torch.Tensor(vocab_size, embedding_dim)
  1158. self.embedding = nn.init.eye(self.embedding)
  1159. else:
  1160. self.embedding = torch.from_numpy(embedding).float()
  1161. print('embedding size', self.embedding.size())
  1162. def __len__(self):
  1163. return len(self.G.nodes())
  1164. def __getitem__(self, idx):
  1165. idx = idx+1
  1166. idx_list = [idx]
  1167. node_list = [self.embedding[idx].view(-1, self.embedding.size(1))]
  1168. node_count_list = []
  1169. for i in range(self.hops):
  1170. # sample this hop
  1171. adj_list = np.array([])
  1172. adj_count_list = np.array([])
  1173. for idx in idx_list:
  1174. if self.shuffle_neighbour:
  1175. adj_list_new = list(self.G.adj[idx - 1])
  1176. random.shuffle(adj_list_new)
  1177. adj_list_new = np.array(adj_list_new) + 1
  1178. else:
  1179. adj_list_new = np.array(list(self.G.adj[idx-1]))+1
  1180. adj_count_list_new = np.array([len(adj_list_new)])
  1181. adj_list = np.concatenate((adj_list, adj_list_new), axis=0)
  1182. adj_count_list = np.concatenate((adj_count_list, adj_count_list_new), axis=0)
  1183. # print(i, adj_list)
  1184. # print(i, embedding(Variable(torch.from_numpy(adj_list)).long()))
  1185. index = torch.from_numpy(adj_list).long()
  1186. adj_list_emb = self.embedding[index]
  1187. node_list.append(adj_list_emb)
  1188. node_count_list.append(adj_count_list)
  1189. idx_list = adj_list
  1190. # padding, used as target
  1191. idx_list = [idx]
  1192. node_list_pad = [self.embedding[idx].view(-1, self.embedding.size(1))]
  1193. node_count_list_pad = []
  1194. node_adj_list = []
  1195. for i in range(self.hops):
  1196. adj_list = np.zeros(self.max_degree ** (i + 1))
  1197. adj_count_list = np.ones(self.max_degree ** (i)) * self.max_degree
  1198. for j, idx in enumerate(idx_list):
  1199. if idx == 0:
  1200. adj_list_new = np.zeros(self.max_degree)
  1201. else:
  1202. if self.shuffle_neighbour:
  1203. adj_list_new = list(self.G.adj[idx - 1])
  1204. # random.shuffle(adj_list_new)
  1205. adj_list_new = np.array(adj_list_new) + 1
  1206. else:
  1207. adj_list_new = np.array(list(self.G.adj[idx-1]))+1
  1208. start_idx = j * self.max_degree
  1209. incre_idx = min(self.max_degree, adj_list_new.shape[0])
  1210. adj_list[start_idx:start_idx + incre_idx] = adj_list_new[:incre_idx]
  1211. index = torch.from_numpy(adj_list).long()
  1212. adj_list_emb = self.embedding[index]
  1213. node_list_pad.append(adj_list_emb)
  1214. node_count_list_pad.append(adj_count_list)
  1215. idx_list = adj_list
  1216. # calc adj matrix
  1217. node_adj = torch.zeros(index.size(0),index.size(0))
  1218. for first in range(index.size(0)):
  1219. for second in range(first, index.size(0)):
  1220. if index[first]==index[second]:
  1221. node_adj[first,second] = 1
  1222. node_adj[second,first] = 1
  1223. elif self.G.has_edge(index[first],index[second]):
  1224. node_adj[first, second] = 0.5
  1225. node_adj[second, first] = 0.5
  1226. node_adj_list.append(node_adj)
  1227. node_list = list(reversed(node_list))
  1228. node_count_list = list(reversed(node_count_list))
  1229. node_list_pad = list(reversed(node_list_pad))
  1230. node_count_list_pad = list(reversed(node_count_list_pad))
  1231. node_adj_list = list(reversed(node_adj_list))
  1232. sample = {'node_list':node_list, 'node_count_list':node_count_list,
  1233. 'node_list_pad':node_list_pad, 'node_count_list_pad':node_count_list_pad, 'node_adj_list':node_adj_list}
  1234. return sample