You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long. 6.9KB

import os
import random
import shutil

import networkx as nx
import numpy as np
import scipy.io
import torch
from sklearn.model_selection import StratifiedKFold
  9. class S2VGraph(object):
  10. def __init__(self, g, label, node_tags=None, node_features=None):
  11. '''
  12. g: a networkx graph
  13. label: an integer graph label
  14. node_tags: a list of integer node tags
  15. node_features: a torch float tensor, one-hot representation of the tag that is used as input to neural nets
  16. edge_mat: a torch long tensor, contain edge list, will be used to create torch sparse tensor
  17. neighbors: list of neighbors (without self-loop)
  18. '''
  19. self.label = label
  20. self.g = g
  21. self.node_tags = node_tags
  22. self.neighbors = []
  23. self.node_features = 0
  24. self.edge_mat = 0
  25. self.max_neighbor = 0
  26. def load_data(dataset, degree_as_tag):
  27. '''
  28. dataset: name of dataset
  29. test_proportion: ratio of test train split
  30. seed: random seed for random splitting of dataset
  31. '''
  32. print('loading data')
  33. g_list = []
  34. label_dict = {}
  35. feat_dict = {}
  36. with open('dataset/%s/%s.txt' % (dataset, dataset), 'r') as f:
  37. n_g = int(f.readline().strip())
  38. for i in range(n_g):
  39. row = f.readline().strip().split()
  40. n, l = [int(w) for w in row]
  41. if not l in label_dict:
  42. mapped = len(label_dict)
  43. label_dict[l] = mapped
  44. g = nx.Graph()
  45. node_tags = []
  46. node_features = []
  47. n_edges = 0
  48. for j in range(n):
  49. g.add_node(j)
  50. row = f.readline().strip().split()
  51. tmp = int(row[1]) + 2
  52. if tmp == len(row):
  53. # no node attributes
  54. row = [int(w) for w in row]
  55. attr = None
  56. else:
  57. row, attr = [int(w) for w in row[:tmp]], np.array([float(w) for w in row[tmp:]])
  58. if not row[0] in feat_dict:
  59. mapped = len(feat_dict)
  60. feat_dict[row[0]] = mapped
  61. node_tags.append(feat_dict[row[0]])
  62. if tmp > len(row):
  63. node_features.append(attr)
  64. n_edges += row[1]
  65. for k in range(2, len(row)):
  66. g.add_edge(j, row[k])
  67. if node_features != []:
  68. node_features = np.stack(node_features)
  69. node_feature_flag = True
  70. else:
  71. node_features = None
  72. node_feature_flag = False
  73. assert len(g) == n
  74. if n < 21:
  75. g_list.append(g)
  76. return g_list, len(label_dict)
  77. def separate_data(graph_list, seed, fold_idx):
  78. assert 0 <= fold_idx and fold_idx < 10, "fold_idx must be from 0 to 9."
  79. skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
  80. labels = [graph.label for graph in graph_list]
  81. idx_list = []
  82. for idx in skf.split(np.zeros(len(labels)), labels):
  83. idx_list.append(idx)
  84. train_idx, test_idx = idx_list[fold_idx]
  85. train_graph_list = [graph_list[i] for i in train_idx]
  86. test_graph_list = [graph_list[i] for i in test_idx]
  87. return train_graph_list, test_graph_list
  88. def save_graphs_as_mat(graphs_list):
  89. if os.path.isdir("test_graphs"):
  90. shutil.rmtree("test_graphs")
  91. if not os.path.exists('test_graphs'):
  92. os.makedirs('test_graphs')
  93. counter = 0
  94. for g in graphs_list:
  95. counter += 1
  96. curr_graph = nx.to_numpy_array(g)
  97. numpy_matrix = nx.to_numpy_matrix(g)
  98. # print(1.0 - (np.count_nonzero(numpy_matrix) / float(numpy_matrix.size)))
  99. if counter == 101:
  100. print(1.0 - (np.count_nonzero(numpy_matrix) / float(numpy_matrix.size)))
  101. print("###########################################################")
  102. print(curr_graph[0])
  103. curr_graph_att = np.ones((len(curr_graph), 60))
  104.'test_graphs/testgraph_{}_{}__.txt.mat'.format(curr_graph.shape[0], counter, g),
  105. {'data': curr_graph})
  106.'test_graphs/testgraph_{}_{}__.usr.mat'.format(curr_graph.shape[0], counter, g),
  107. {'attributes': curr_graph_att})
  108. def move_random_node_to_the_last_index(adj):
  109. # selecting a random node and moving it to the last node position
  110. random_idx_for_delete = np.random.randint(adj.shape[0])
  111. deleted_node = adj[:, random_idx_for_delete].copy()
  112. for i in range(deleted_node.__len__()):
  113. if i >= random_idx_for_delete and i < deleted_node.__len__() - 1:
  114. deleted_node[i] = deleted_node[i + 1]
  115. elif i == deleted_node.__len__() - 1:
  116. deleted_node[i] = 0
  117. adj[:, random_idx_for_delete:adj.shape[0] - 1] = adj[:, random_idx_for_delete + 1:adj.shape[0]]
  118. adj[random_idx_for_delete:adj.shape[0] - 1, :] = adj[random_idx_for_delete + 1:adj.shape[0], :]
  119. adj = np.delete(adj, -1, axis=1)
  120. adj = np.delete(adj, -1, axis=0)
  121. adj = np.concatenate((adj, deleted_node[:deleted_node.shape[0] - 1]), axis=1)
  122. adj = np.concatenate((adj, np.transpose(deleted_node)), axis=0)
  123. return adj
  124. def prepare_kronEM_data(graphs_list, data_name, random_node_permutation_flag):
  125. if os.path.isdir("kronEM_main_graphs_" + data_name):
  126. shutil.rmtree("kronEM_main_graphs_" + data_name)
  127. if not os.path.exists("kronEM_main_graphs_" + data_name):
  128. os.makedirs("kronEM_main_graphs_" + data_name)
  129. if os.path.isdir("kronEM_graphs_with_missing_node_" + data_name):
  130. shutil.rmtree("kronEM_graphs_with_missing_node_" + data_name)
  131. if not os.path.exists("kronEM_graphs_with_missing_node_" + data_name):
  132. os.makedirs("kronEM_graphs_with_missing_node_" + data_name)
  133. counter = 0
  134. if random_node_permutation_flag:
  135. number_of_random_node_permutation_per_graph = 3
  136. else:
  137. number_of_random_node_permutation_per_graph = 1
  138. for g in graphs_list:
  139. for i in range(number_of_random_node_permutation_per_graph):
  140. counter += 1
  141. numpy_matrix = nx.to_numpy_matrix(g)
  142. if random_node_permutation_flag:
  143. numpy_matrix = move_random_node_to_the_last_index(numpy_matrix)
  144. file_main = open("kronEM_main_graphs_" + data_name + "/" + str(counter) + ".txt", "w")
  145. file_missing = open("kronEM_graphs_with_missing_node_" + data_name + "/" + str(counter) + ".txt", "w")
  146. with file_main as f:
  147. for i in range(numpy_matrix.shape[0]):
  148. for j in range(numpy_matrix.shape[0]):
  149. if numpy_matrix[i, j] == 1:
  150. f.write(str(i + 1) + "\t" + str(j + 1) + "\n")
  151. if i != numpy_matrix.shape[0] - 1 and j != numpy_matrix.shape[0] - 1:
  152. file_missing.write(str(i + 1) + "\t" + str(j + 1) + "\n")
  153. file_main.close()
  154. return