You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

util.py 6.9KB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import networkx as nx
  2. import numpy as np
  3. import scipy.io
  4. import os
  5. import shutil
  6. import random
  7. import torch
  8. from sklearn.model_selection import StratifiedKFold
  9. class S2VGraph(object):
  10. def __init__(self, g, label, node_tags=None, node_features=None):
  11. '''
  12. g: a networkx graph
  13. label: an integer graph label
  14. node_tags: a list of integer node tags
  15. node_features: a torch float tensor, one-hot representation of the tag that is used as input to neural nets
  16. edge_mat: a torch long tensor, contain edge list, will be used to create torch sparse tensor
  17. neighbors: list of neighbors (without self-loop)
  18. '''
  19. self.label = label
  20. self.g = g
  21. self.node_tags = node_tags
  22. self.neighbors = []
  23. self.node_features = 0
  24. self.edge_mat = 0
  25. self.max_neighbor = 0
  26. def load_data(dataset, degree_as_tag):
  27. '''
  28. dataset: name of dataset
  29. test_proportion: ratio of test train split
  30. seed: random seed for random splitting of dataset
  31. '''
  32. print('loading data')
  33. g_list = []
  34. label_dict = {}
  35. feat_dict = {}
  36. with open('dataset/%s/%s.txt' % (dataset, dataset), 'r') as f:
  37. n_g = int(f.readline().strip())
  38. for i in range(n_g):
  39. row = f.readline().strip().split()
  40. n, l = [int(w) for w in row]
  41. if not l in label_dict:
  42. mapped = len(label_dict)
  43. label_dict[l] = mapped
  44. g = nx.Graph()
  45. node_tags = []
  46. node_features = []
  47. n_edges = 0
  48. for j in range(n):
  49. g.add_node(j)
  50. row = f.readline().strip().split()
  51. tmp = int(row[1]) + 2
  52. if tmp == len(row):
  53. # no node attributes
  54. row = [int(w) for w in row]
  55. attr = None
  56. else:
  57. row, attr = [int(w) for w in row[:tmp]], np.array([float(w) for w in row[tmp:]])
  58. if not row[0] in feat_dict:
  59. mapped = len(feat_dict)
  60. feat_dict[row[0]] = mapped
  61. node_tags.append(feat_dict[row[0]])
  62. if tmp > len(row):
  63. node_features.append(attr)
  64. n_edges += row[1]
  65. for k in range(2, len(row)):
  66. g.add_edge(j, row[k])
  67. if node_features != []:
  68. node_features = np.stack(node_features)
  69. node_feature_flag = True
  70. else:
  71. node_features = None
  72. node_feature_flag = False
  73. assert len(g) == n
  74. if n < 21:
  75. g_list.append(g)
  76. return g_list, len(label_dict)
  77. def separate_data(graph_list, seed, fold_idx):
  78. assert 0 <= fold_idx and fold_idx < 10, "fold_idx must be from 0 to 9."
  79. skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
  80. labels = [graph.label for graph in graph_list]
  81. idx_list = []
  82. for idx in skf.split(np.zeros(len(labels)), labels):
  83. idx_list.append(idx)
  84. train_idx, test_idx = idx_list[fold_idx]
  85. train_graph_list = [graph_list[i] for i in train_idx]
  86. test_graph_list = [graph_list[i] for i in test_idx]
  87. return train_graph_list, test_graph_list
  88. def save_graphs_as_mat(graphs_list):
  89. if os.path.isdir("test_graphs"):
  90. shutil.rmtree("test_graphs")
  91. if not os.path.exists('test_graphs'):
  92. os.makedirs('test_graphs')
  93. counter = 0
  94. for g in graphs_list:
  95. counter += 1
  96. curr_graph = nx.to_numpy_array(g)
  97. numpy_matrix = nx.to_numpy_matrix(g)
  98. # print(1.0 - (np.count_nonzero(numpy_matrix) / float(numpy_matrix.size)))
  99. if counter == 101:
  100. print(1.0 - (np.count_nonzero(numpy_matrix) / float(numpy_matrix.size)))
  101. print("###########################################################")
  102. print(curr_graph[0])
  103. curr_graph_att = np.ones((len(curr_graph), 60))
  104. scipy.io.savemat('test_graphs/testgraph_{}_{}__.txt.mat'.format(curr_graph.shape[0], counter, g),
  105. {'data': curr_graph})
  106. scipy.io.savemat('test_graphs/testgraph_{}_{}__.usr.mat'.format(curr_graph.shape[0], counter, g),
  107. {'attributes': curr_graph_att})
  108. def move_random_node_to_the_last_index(adj):
  109. # selecting a random node and moving it to the last node position
  110. random_idx_for_delete = np.random.randint(adj.shape[0])
  111. deleted_node = adj[:, random_idx_for_delete].copy()
  112. for i in range(deleted_node.__len__()):
  113. if i >= random_idx_for_delete and i < deleted_node.__len__() - 1:
  114. deleted_node[i] = deleted_node[i + 1]
  115. elif i == deleted_node.__len__() - 1:
  116. deleted_node[i] = 0
  117. adj[:, random_idx_for_delete:adj.shape[0] - 1] = adj[:, random_idx_for_delete + 1:adj.shape[0]]
  118. adj[random_idx_for_delete:adj.shape[0] - 1, :] = adj[random_idx_for_delete + 1:adj.shape[0], :]
  119. adj = np.delete(adj, -1, axis=1)
  120. adj = np.delete(adj, -1, axis=0)
  121. adj = np.concatenate((adj, deleted_node[:deleted_node.shape[0] - 1]), axis=1)
  122. adj = np.concatenate((adj, np.transpose(deleted_node)), axis=0)
  123. return adj
  124. def prepare_kronEM_data(graphs_list, data_name, random_node_permutation_flag):
  125. if os.path.isdir("kronEM_main_graphs_" + data_name):
  126. shutil.rmtree("kronEM_main_graphs_" + data_name)
  127. if not os.path.exists("kronEM_main_graphs_" + data_name):
  128. os.makedirs("kronEM_main_graphs_" + data_name)
  129. if os.path.isdir("kronEM_graphs_with_missing_node_" + data_name):
  130. shutil.rmtree("kronEM_graphs_with_missing_node_" + data_name)
  131. if not os.path.exists("kronEM_graphs_with_missing_node_" + data_name):
  132. os.makedirs("kronEM_graphs_with_missing_node_" + data_name)
  133. counter = 0
  134. if random_node_permutation_flag:
  135. number_of_random_node_permutation_per_graph = 3
  136. else:
  137. number_of_random_node_permutation_per_graph = 1
  138. for g in graphs_list:
  139. for i in range(number_of_random_node_permutation_per_graph):
  140. counter += 1
  141. numpy_matrix = nx.to_numpy_matrix(g)
  142. if random_node_permutation_flag:
  143. numpy_matrix = move_random_node_to_the_last_index(numpy_matrix)
  144. file_main = open("kronEM_main_graphs_" + data_name + "/" + str(counter) + ".txt", "w")
  145. file_missing = open("kronEM_graphs_with_missing_node_" + data_name + "/" + str(counter) + ".txt", "w")
  146. with file_main as f:
  147. for i in range(numpy_matrix.shape[0]):
  148. for j in range(numpy_matrix.shape[0]):
  149. if numpy_matrix[i, j] == 1:
  150. f.write(str(i + 1) + "\t" + str(j + 1) + "\n")
  151. if i != numpy_matrix.shape[0] - 1 and j != numpy_matrix.shape[0] - 1:
  152. file_missing.write(str(i + 1) + "\t" + str(j + 1) + "\n")
  153. file_main.close()
  154. return