You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

load_data.py 2.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import numpy as np
  2. import networkx as nx
  3. G = nx.Graph()
  4. # load data
  5. data_adj = np.loadtxt('ENZYMES_A.txt', delimiter=',').astype(int)
  6. data_node_att = np.loadtxt('ENZYMES_node_attributes.txt', delimiter=',')
  7. data_node_label = np.loadtxt('ENZYMES_node_labels.txt', delimiter=',').astype(int)
  8. data_graph_indicator = np.loadtxt('ENZYMES_graph_indicator.txt', delimiter=',').astype(int)
  9. data_graph_labels = np.loadtxt('ENZYMES_graph_labels.txt', delimiter=',').astype(int)
  10. data_tuple = list(map(tuple, data_adj))
  11. print(len(data_tuple))
  12. print(data_tuple[0])
  13. # add edges
  14. G.add_edges_from(data_tuple)
  15. # add node attributes
  16. for i in range(data_node_att.shape[0]):
  17. G.add_node(i+1, feature = data_node_att[i])
  18. G.add_node(i+1, label = data_node_label[i])
  19. G.remove_nodes_from(nx.isolates(G))
  20. print(G.number_of_nodes())
  21. print(G.number_of_edges())
  22. # split into graphs
  23. graph_num = 600
  24. node_list = np.arange(data_graph_indicator.shape[0])+1
  25. graphs = []
  26. node_num_list = []
  27. for i in range(graph_num):
  28. # find the nodes for each graph
  29. nodes = node_list[data_graph_indicator==i+1]
  30. G_sub = G.subgraph(nodes)
  31. graphs.append(G_sub)
  32. G_sub.graph['label'] = data_graph_labels[i]
  33. # print('nodes', G_sub.number_of_nodes())
  34. # print('edges', G_sub.number_of_edges())
  35. # print('label', G_sub.graph)
  36. node_num_list.append(G_sub.number_of_nodes())
  37. print('average', sum(node_num_list)/len(node_num_list))
  38. print('all', len(node_num_list))
  39. node_num_list = np.array(node_num_list)
  40. print('selected', len(node_num_list[node_num_list>10]))
  41. # print(graphs[0].nodes(data=True)[0][1]['feature'])
  42. # print(graphs[0].nodes())
  43. keys = tuple(graphs[0].nodes())
  44. # print(nx.get_node_attributes(graphs[0], 'feature'))
  45. dictionary = nx.get_node_attributes(graphs[0], 'feature')
  46. # print('keys', keys)
  47. # print('keys from dict', list(dictionary.keys()))
  48. # print('valuse from dict', list(dictionary.values()))
  49. features = np.zeros((len(dictionary), list(dictionary.values())[0].shape[0]))
  50. for i in range(len(dictionary)):
  51. features[i,:] = list(dictionary.values())[i]
  52. # print(features)
  53. # print(features.shape)