You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

analysis.py 11KB

4 years ago

  1. # this file is used to plot images
  2. from main import *
  3. args = Args()
  4. print(args.graph_type, args.note)
  5. # epoch = 16000
  6. epoch = 3000
  7. sample_time = 3
  8. def find_nearest_idx(array,value):
  9. idx = (np.abs(array-value)).argmin()
  10. return idx
  11. # for baseline model
  12. for num_layers in range(4,5):
  13. # give file name and figure name
  14. fname_real = args.graph_save_path + args.fname_real + str(0)
  15. fname_pred = args.graph_save_path + args.fname_pred + str(epoch) +'_'+str(sample_time)
  16. figname = args.figure_save_path + args.fname + str(epoch) +'_'+str(sample_time)
  17. # fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  18. # str(epoch) + '_real_' + str(True) + '_' + str(num_layers)
  19. # fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  20. # str(epoch) + '_pred_' + str(True) + '_' + str(num_layers)
  21. # figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  22. # str(epoch) + '_' + str(num_layers)
  23. print(fname_real)
  24. print(fname_pred)
  25. # load data
  26. graph_real_list = load_graph_list(fname_real + '.dat')
  27. shuffle(graph_real_list)
  28. graph_pred_list_raw = load_graph_list(fname_pred + '.dat')
  29. graph_real_len_list = np.array([len(graph_real_list[i]) for i in range(len(graph_real_list))])
  30. graph_pred_len_list_raw = np.array([len(graph_pred_list_raw[i]) for i in range(len(graph_pred_list_raw))])
  31. graph_pred_list = graph_pred_list_raw
  32. graph_pred_len_list = graph_pred_len_list_raw
  33. # # select samples
  34. # graph_pred_list = []
  35. # graph_pred_len_list = []
  36. # for value in graph_real_len_list:
  37. # pred_idx = find_nearest_idx(graph_pred_len_list_raw, value)
  38. # graph_pred_list.append(graph_pred_list_raw[pred_idx])
  39. # graph_pred_len_list.append(graph_pred_len_list_raw[pred_idx])
  40. # # delete
  41. # graph_pred_len_list_raw=np.delete(graph_pred_len_list_raw, pred_idx)
  42. # del graph_pred_list_raw[pred_idx]
  43. # if len(graph_pred_list)==200:
  44. # break
  45. # graph_pred_len_list = np.array(graph_pred_len_list)
  46. # # select pred data within certain range
  47. # len_min = np.amin(graph_real_len_list)
  48. # len_max = np.amax(graph_real_len_list)
  49. # pred_index = np.where((graph_pred_len_list>=len_min)&(graph_pred_len_list<=len_max))
  50. # # print(pred_index[0])
  51. # graph_pred_list = [graph_pred_list[i] for i in pred_index[0]]
  52. # graph_pred_len_list = graph_pred_len_list[pred_index[0]]
  53. # real_order = np.argsort(graph_real_len_list)
  54. # pred_order = np.argsort(graph_pred_len_list)
  55. real_order = np.argsort(graph_real_len_list)[::-1]
  56. pred_order = np.argsort(graph_pred_len_list)[::-1]
  57. # print(real_order)
  58. # print(pred_order)
  59. graph_real_list = [graph_real_list[i] for i in real_order]
  60. graph_pred_list = [graph_pred_list[i] for i in pred_order]
  61. # shuffle(graph_real_list)
  62. # shuffle(graph_pred_list)
  63. print('real average nodes', sum([graph_real_list[i].number_of_nodes() for i in range(len(graph_real_list))])/len(graph_real_list))
  64. print('pred average nodes', sum([graph_pred_list[i].number_of_nodes() for i in range(len(graph_pred_list))])/len(graph_pred_list))
  65. print('num of real graphs', len(graph_real_list))
  66. print('num of pred graphs', len(graph_pred_list))
  67. # # draw all graphs
  68. # for iter in range(8):
  69. # print('iter', iter)
  70. # graph_list = []
  71. # for i in range(8):
  72. # index = 8 * iter + i
  73. # # graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index])))
  74. # # graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index])))
  75. # graph_list.append(graph_real_list[index])
  76. # graph_list.append(graph_pred_list[index])
  77. # print('real', graph_real_list[index].number_of_nodes())
  78. # print('pred', graph_pred_list[index].number_of_nodes())
  79. #
  80. # draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter))
  81. # draw all graphs
  82. for iter in range(8):
  83. print('iter', iter)
  84. graph_list = []
  85. for i in range(8):
  86. index = 32 * iter + i
  87. # graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index])))
  88. # graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index])))
  89. # graph_list.append(graph_real_list[index])
  90. graph_list.append(graph_pred_list[index])
  91. # print('real', graph_real_list[index].number_of_nodes())
  92. print('pred', graph_pred_list[index].number_of_nodes())
  93. draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)+'_pred')
  94. # draw all graphs
  95. for iter in range(8):
  96. print('iter', iter)
  97. graph_list = []
  98. for i in range(8):
  99. index = 16 * iter + i
  100. # graph_real_list[index].remove_nodes_from(list(nx.isolates(graph_real_list[index])))
  101. # graph_pred_list[index].remove_nodes_from(list(nx.isolates(graph_pred_list[index])))
  102. graph_list.append(graph_real_list[index])
  103. # graph_list.append(graph_pred_list[index])
  104. print('real', graph_real_list[index].number_of_nodes())
  105. # print('pred', graph_pred_list[index].number_of_nodes())
  106. draw_graph_list(graph_list, row=4, col=4, fname=figname + '_' + str(iter)+'_real')
  107. #
  108. # # for new model
  109. # elif args.note == 'GraphRNN_structure' and args.is_flexible==False:
  110. # for num_layers in range(4,5):
  111. # # give file name and figure name
  112. # # fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  113. # # str(epoch) + '_real_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
  114. # # fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  115. # # str(epoch) + '_pred_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
  116. #
  117. # fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
  118. # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size)
  119. # fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + \
  120. # str(epoch) + '_real_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size)
  121. # figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + \
  122. # str(epoch) + '_pred_' + str(args.num_layers) + '_' + str(args.bptt)+ '_' + str(args.bptt_len) + '_' + str(args.hidden_size)
  123. # print(fname_real)
  124. # # load data
  125. # graph_real_list = load_graph_list(fname_real+'.dat')
  126. # graph_pred_list = load_graph_list(fname_pred+'.dat')
  127. #
  128. # graph_real_len_list = np.array([len(graph_real_list[i]) for i in range(len(graph_real_list))])
  129. # graph_pred_len_list = np.array([len(graph_pred_list[i]) for i in range(len(graph_pred_list))])
  130. # real_order = np.argsort(graph_real_len_list)[::-1]
  131. # pred_order = np.argsort(graph_pred_len_list)[::-1]
  132. # # print(real_order)
  133. # # print(pred_order)
  134. # graph_real_list = [graph_real_list[i] for i in real_order]
  135. # graph_pred_list = [graph_pred_list[i] for i in pred_order]
  136. #
  137. # shuffle(graph_pred_list)
  138. #
  139. #
  140. # print('real average nodes',
  141. # sum([graph_real_list[i].number_of_nodes() for i in range(len(graph_real_list))]) / len(graph_real_list))
  142. # print('pred average nodes',
  143. # sum([graph_pred_list[i].number_of_nodes() for i in range(len(graph_pred_list))]) / len(graph_pred_list))
  144. # print('num of graphs', len(graph_real_list))
  145. #
  146. # # draw all graphs
  147. # for iter in range(2):
  148. # print('iter', iter)
  149. # graph_list = []
  150. # for i in range(8):
  151. # index = 8*iter + i
  152. # graph_real_list[index].remove_nodes_from(nx.isolates(graph_real_list[index]))
  153. # graph_pred_list[index].remove_nodes_from(nx.isolates(graph_pred_list[index]))
  154. # graph_list.append(graph_real_list[index])
  155. # graph_list.append(graph_pred_list[index])
  156. # print('real', graph_real_list[index].number_of_nodes())
  157. # print('pred', graph_pred_list[index].number_of_nodes())
  158. # draw_graph_list(graph_list, row=4, col=4, fname=figname+'_'+str(iter))
  159. #
  160. #
  161. # # for new model
  162. # elif args.note == 'GraphRNN_structure' and args.is_flexible==True:
  163. # for num_layers in range(4,5):
  164. # graph_real_list = []
  165. # graph_pred_list = []
  166. # epoch_end = 30000
  167. # for epoch in [epoch_end-500*(8-i) for i in range(8)]:
  168. # # give file name and figure name
  169. # fname_real = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  170. # str(epoch) + '_real_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
  171. # fname_pred = args.graph_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  172. # str(epoch) + '_pred_bptt_' + str(args.bptt)+'_'+str(num_layers)+'_dilation_'+str(args.is_dilation)+'_flexible_'+str(args.is_flexible)+'_bn_'+str(args.is_bn)+'_lr_'+str(args.lr)
  173. #
  174. # # load data
  175. # graph_real_list += load_graph_list(fname_real+'.dat')
  176. # graph_pred_list += load_graph_list(fname_pred+'.dat')
  177. # print('num of graphs', len(graph_real_list))
  178. #
  179. # figname = args.figure_save_path + args.note + '_' + args.graph_type + '_' + str(args.graph_node_num) + '_' + \
  180. # str(epoch) + str(args.sample_when_validate) + '_' + str(num_layers) + '_dilation_' + str(args.is_dilation) + '_flexible_' + str(args.is_flexible) + '_bn_' + str(args.is_bn) + '_lr_' + str(args.lr)
  181. #
  182. # # draw all graphs
  183. # for iter in range(1):
  184. # print('iter', iter)
  185. # graph_list = []
  186. # for i in range(8):
  187. # index = 8*iter + i
  188. # graph_real_list[index].remove_nodes_from(nx.isolates(graph_real_list[index]))
  189. # graph_pred_list[index].remove_nodes_from(nx.isolates(graph_pred_list[index]))
  190. # graph_list.append(graph_real_list[index])
  191. # graph_list.append(graph_pred_list[index])
  192. # draw_graph_list(graph_list, row=4, col=4, fname=figname+'_'+str(iter))