
graphvae_train.py

import argparse
import os
import random  # needed for random.seed() in main()
import shutil
import statistics
from random import shuffle
from statistics import mean
from time import strftime, gmtime

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from numpy import count_nonzero
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from tensorboard_logger import configure, log_value
from torch import optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR

import data
from baselines.graphvae.evaluate import evaluate
from baselines.graphvae.evaluate2 import evaluate as evaluate2
from baselines.graphvae.graphvae_model import GraphVAE
from baselines.graphvae.graphvae_data import GraphAdjSampler
from baselines.graphvae.args import GraphVAE_Args
from baselines.graphvae.util import *

CUDA = 0
vae_args = GraphVAE_Args()
LR_milestones = [100, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000]
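# LR_milestones lists the scheduler steps at which MultiStepLR decays the
# learning rate by the gamma factor (0.1) used in train() below.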


def graph_statistics(graphs):
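    """Print size and sparsity statistics for a list of networkx graphs."""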
    sparsity = 0
    matrix_size_list = []
    counter = 0  # number of complete graphs (every off-diagonal entry nonzero)
    for i in range(len(graphs)):
        numpy_matrix = nx.to_numpy_matrix(graphs[i])
        non_zero = np.count_nonzero(numpy_matrix)
        if non_zero == numpy_matrix.shape[0] * numpy_matrix.shape[1] - numpy_matrix.shape[1]:
            counter += 1
        sparsity += 1.0 - (count_nonzero(numpy_matrix) / float(numpy_matrix.size))
        matrix_size_list.append(numpy_matrix.shape[0])
    smallest_graph_size = min(matrix_size_list)
    largest_graph_size = max(matrix_size_list)
    graph_size_std = statistics.stdev(matrix_size_list)
    sparsity /= len(graphs)
    print("*** smallest_graph_size = " + str(smallest_graph_size) +
          " *** largest_graph_size = " + str(largest_graph_size) +
          " *** mean_graph_size = " + str(statistics.mean(matrix_size_list)) +
          " *** graph_size_std = " + str(graph_size_std) +
          " *** average_graph_sparsity = " + str(sparsity))
    print("*** number of complete graphs = " + str(counter))
    return


def build_model(args, max_num_nodes):
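    """Construct a GraphVAE whose input/output sizes are derived from max_num_nodes."""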
    # Output dimension: upper triangle (including diagonal) of the adjacency matrix.
    out_dim = max_num_nodes * (max_num_nodes + 1) // 2
    if args.feature_type == 'id':
        if vae_args.completion_mode_small_parameter_size:
            input_dim = max_num_nodes - vae_args.number_of_missing_nodes
        else:
            input_dim = max_num_nodes
    elif args.feature_type == 'deg':
        input_dim = 1
    elif args.feature_type == 'struct':
        input_dim = 2
    else:
        raise ValueError('Unknown feature type: ' + args.feature_type)
    model = GraphVAE(input_dim, 16, 256, max_num_nodes, vae_args)
    return model


def train(args, dataloader, test_dataset_loader, graphs_test, model):
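    """Train the GraphVAE and periodically evaluate it on the held-out test graphs."""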
    optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=LR_milestones, gamma=0.1)
    model.train()
    for epoch in range(211):
        test_mae_list = []
        test_roc_score_list = []
        test_ap_score_list = []
        test_precision_list = []
        test_recall_list = []
        for batch_idx, data in enumerate(dataloader):
            model.zero_grad()
            features = data['features'].float()
            batch_num_nodes = data['num_nodes'].int().numpy()
            adj_input = data['adj'].float()
            features = Variable(features).cuda()
            adj_input = Variable(adj_input).cuda()
            loss = model(features, adj_input, batch_num_nodes)
            print('Epoch: ', epoch, ', Iter: ', batch_idx, ', Loss: ', loss)
            loss.backward()
            optimizer.step()
            scheduler.step()
            if epoch % 10 == 0:
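                # Every 10 epochs: evaluate on the held-out graphs and log the metrics.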
                fname = vae_args.model_save_path + "GraphVAE" + str(epoch) + '.dat'
                log_value("Training Loss, dataset: " + args.dataset, loss.data, epoch * 32 + batch_idx)
                # test_mae, test_tp_div_pos, test_roc_score, test_ap_score = evaluate(test_dataset_loader, model)
                test_mae, test_roc_score, test_ap_score, test_precision, test_recall = evaluate2(
                    graphs_test, model)
                test_mae_list.append(test_mae)
                test_roc_score_list.append(test_roc_score)
                test_ap_score_list.append(test_ap_score)
                test_precision_list.append(test_precision)
                test_recall_list.append(test_recall)
                log_value("Test MAE, dataset: " + args.dataset, test_mae, epoch * 32 + batch_idx)
                log_value("Test test_roc_score, dataset: " + args.dataset, test_roc_score,
                          epoch * 32 + batch_idx)
                log_value("Test test_ap_score, dataset: " + args.dataset, test_ap_score,
                          epoch * 32 + batch_idx)
                if epoch % 50 == 0 and epoch != 0:
                    torch.save(model.state_dict(), fname)
        if len(test_mae_list) > 0:
            precision = mean(test_precision_list)
            recall = mean(test_recall_list)
            test_F_Measure = 2 * precision * recall / (precision + recall)
            print("In Train: *** MAE - roc_score - ap_score - precision - recall - F_Measure : "
                  + str(mean(test_mae_list)) + " _ "
                  + str(mean(test_roc_score_list)) + " _ " + str(mean(test_ap_score_list)) + " _ "
                  + str(precision) + " _ " + str(recall) + " _ "
                  + str(test_F_Measure))


def arg_parse():
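    """Parse command-line arguments; defaults are set for the REDDITMULTI5K dataset."""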
    parser = argparse.ArgumentParser(description='GraphVAE arguments.')
    io_parser = parser.add_mutually_exclusive_group(required=False)
    io_parser.add_argument('--dataset', dest='dataset',
                           help='Input dataset.')
    parser.add_argument('--lr', dest='lr', type=float,
                        help='Learning rate.')
    parser.add_argument('--batch_size', dest='batch_size', type=int,
                        help='Batch size.')
    parser.add_argument('--batch_ratio', dest='batch_ratio', type=int,
                        help='Batch ratio.')
    parser.add_argument('--num_workers', dest='num_workers', type=int,
                        help='Number of workers to load data.')
    parser.add_argument('--max_num_nodes', dest='max_num_nodes', type=int,
                        help='Predefined maximum number of nodes in train/test graphs. '
                             '-1 if determined by training data.')
    parser.add_argument('--feature', dest='feature_type',
                        help='Feature used for encoder. Can be: id, deg, struct.')
    parser.set_defaults(dataset='REDDITMULTI5K',
                        feature_type='id',
                        lr=0.01,
                        batch_size=32,
                        batch_ratio=10,
                        num_workers=4,
                        max_num_nodes=-1)
    return parser.parse_args()


def main():
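    """Load the requested dataset, build the train/test splits and loaders, and train."""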
    prog_args = arg_parse()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(CUDA)
    print('CUDA', CUDA)
    torch.manual_seed(1234)
    if prog_args.dataset == 'enzymes':
        graphs = data.Graph_load_batch(min_num_nodes=10, name='ENZYMES')
    elif prog_args.dataset == 'dd':
        graphs = data.Graph_load_batch(min_num_nodes=10, name='DD')
    elif prog_args.dataset == 'ladder':
        graphs = []
        for i in range(100, 201):
            graphs.append(nx.ladder_graph(i))
    elif prog_args.dataset == 'barabasi':
        graphs = []
        for i in range(100, 200):
            for j in range(4, 5):
                for _ in range(5):
                    graphs.append(nx.barabasi_albert_graph(i, j))
    elif prog_args.dataset == 'citeseer':
        _, _, G = data.Graph_load(dataset='citeseer')
        G = max(nx.connected_component_subgraphs(G), key=len)
        G = nx.convert_node_labels_to_integers(G)
        graphs = []
        for i in range(G.number_of_nodes()):
            G_ego = nx.ego_graph(G, i, radius=3)
            if G_ego.number_of_nodes() >= 50 and (G_ego.number_of_nodes() <= 400):
                graphs.append(G_ego)
    elif prog_args.dataset == 'grid':
        # A fixed set of small 2-D grids (6, 12, and 24 nodes).
        graphs = []
        graphs.append(nx.grid_2d_graph(2, 3))
        graphs.append(nx.grid_2d_graph(3, 4))
        graphs.append(nx.grid_2d_graph(6, 2))
        graphs.append(nx.grid_2d_graph(4, 6))
        graphs.append(nx.grid_2d_graph(6, 4))
        graphs.append(nx.grid_2d_graph(8, 3))
        graphs.append(nx.grid_2d_graph(12, 2))
    elif prog_args.dataset == 'grid_big':
        graphs = []
        for i in range(36, 46):
            for j in range(36, 46):
                graphs.append(nx.grid_2d_graph(i, j))
    elif prog_args.dataset == 'grid_small':
        graphs = []
        for i in range(2, 5):
            for j in range(2, 5):
                graphs.append(nx.grid_2d_graph(i, j))
    else:
        graphs, num_classes = load_data(prog_args.dataset, True)
    num_graphs_raw = len(graphs)
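    # Decide the maximum graph size: either derive it from the (optionally
    # filtered) data, or use the user-supplied --max_num_nodes and drop larger graphs.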
    if prog_args.max_num_nodes == -1:
        min_num_nodes = min([graphs[i].number_of_nodes() for i in range(len(graphs))])
        small_graphs_size = 0
        # For every dataset except the grids, keep only graphs with fewer than 41 nodes.
        if prog_args.dataset != 'grid_small' and prog_args.dataset != 'grid':
            small_graphs = []
            for i in range(len(graphs)):
                if graphs[i].number_of_nodes() < 41:
                    small_graphs_size += 1
                    small_graphs.append(graphs[i])
            graphs = small_graphs
            print(len(graphs))
        max_num_nodes = max([graphs[i].number_of_nodes() for i in range(len(graphs))])
        graph_statistics(graphs)
    else:
        max_num_nodes = prog_args.max_num_nodes
        # Remove graphs with more nodes than max_num_nodes.
        graphs = [g for g in graphs if g.number_of_nodes() <= max_num_nodes]
    graphs_len = len(graphs)
    # Prepare train and test data: shuffle with a fixed seed, then split 80/20.
    random.seed(123)
    shuffle(graphs)
    graphs_test = graphs[int(0.8 * graphs_len):]
    graphs_train = graphs[0:int(0.8 * graphs_len)]
    # Collect the test graphs with 8/16/32/64/128 nodes and export them for KronEM.
    kronEM_graphs = []
    for i in range(len(graphs_test)):
        if graphs_test[i].number_of_nodes() in (8, 16, 32, 64, 128):
            kronEM_graphs.append(graphs_test[i])
    prepare_kronEM_data(kronEM_graphs, prog_args.dataset, True)
    save_graphs_as_mat(graphs_test)
    print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train)))
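    # The samplers below use uniform weights with replacement: each epoch draws
    # batch_size * batch_ratio samples, so an "epoch" is a fixed number of
    # batches (batch_ratio) rather than one full pass over the dataset.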
    dataset = GraphAdjSampler(graphs_train, max_num_nodes, vae_args.permutation_mode, vae_args.bfs_mode,
                              vae_args.bfs_mode_with_arbitrary_node_deleted,
                              features=prog_args.feature_type)
    test_dataset = GraphAdjSampler(graphs_test, max_num_nodes, vae_args.permutation_mode, vae_args.bfs_mode,
                                   vae_args.bfs_mode_with_arbitrary_node_deleted,
                                   features=prog_args.feature_type)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(dataset) for i in range(len(dataset))],
        num_samples=prog_args.batch_size * prog_args.batch_ratio,
        replacement=True)
    test_sample_strategy = torch.utils.data.sampler.WeightedRandomSampler(
        [1.0 / len(test_dataset) for i in range(len(test_dataset))],
        num_samples=prog_args.batch_size * prog_args.batch_ratio,
        replacement=True)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers,
        sampler=sample_strategy)
    test_dataset_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=prog_args.batch_size,
        num_workers=prog_args.num_workers,
        sampler=test_sample_strategy)
    model = build_model(prog_args, max_num_nodes).cuda()
    train(prog_args, dataset_loader, test_dataset_loader, graphs_test, model)


if __name__ == '__main__':
    if not os.path.isdir(vae_args.model_save_path):
        os.makedirs(vae_args.model_save_path)
    time = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    if vae_args.clean_tensorboard:
        if os.path.isdir("tensorboard"):
            shutil.rmtree("tensorboard")
    configure("tensorboard/run" + time, flush_secs=5)
    main()
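# Example invocation (a sketch, not tested here: the exact entry point depends on
# this repo's layout, since the baselines.* imports assume the repo root is on
# PYTHONPATH; paths such as model_save_path come from GraphVAE_Args defaults):
#   python baselines/graphvae/graphvae_train.py --dataset grid --lr 0.01 \
#       --batch_size 32 --batch_ratio 10 --feature id --max_num_nodes -1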