You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

GraphTransformer.py 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import math
  2. import numpy as np
  3. import pandas as pd
  4. import networkx as nx
  5. import scipy as sp
  6. import seaborn as sns
  7. import time
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.init as init
  11. import torch.nn.functional as F
  12. from torch.nn.parameter import Parameter
  13. from torch.nn.modules.module import Module
  14. from torch import Tensor
  15. if torch.cuda.is_available():
  16. torch.device('cuda')
  17. """
  18. Utils:
  19. Data Loader
  20. Feature Matrix Constructor
  21. Random Node Remover
  22. """
  23. def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):
  24. '''
  25. load many graphs, e.g. enzymes
  26. :return: a list of graphs
  27. '''
  28. print('Loading graph dataset: ' + str(name))
  29. G = nx.Graph()
  30. # load data
  31. # path = '../dataset/' + name + '/'
  32. path = '/content/gdrive/My Drive/' + name + '/'
  33. data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
  34. if node_attributes:
  35. data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
  36. data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
  37. data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
  38. if graph_labels:
  39. data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
  40. data_tuple = list(map(tuple, data_adj))
  41. G.add_edges_from(data_tuple)
  42. for i in range(data_node_label.shape[0]):
  43. if node_attributes:
  44. G.add_node(i + 1, feature=data_node_att[i])
  45. G.add_node(i + 1, label=data_node_label[i])
  46. G.remove_nodes_from(list(nx.isolates(G)))
  47. graph_num = data_graph_indicator.max()
  48. node_list = np.arange(data_graph_indicator.shape[0]) + 1
  49. graphs = []
  50. max_nodes = 0
  51. for i in range(graph_num):
  52. nodes = node_list[data_graph_indicator == i + 1]
  53. G_sub = G.subgraph(nodes)
  54. if graph_labels:
  55. G_sub.graph['label'] = data_graph_labels[i]
  56. if G_sub.number_of_nodes() >= min_num_nodes and G_sub.number_of_nodes() <= max_num_nodes:
  57. graphs.append(G_sub)
  58. if G_sub.number_of_nodes() > max_nodes:
  59. max_nodes = G_sub.number_of_nodes()
  60. print('Loaded')
  61. return graphs
  62. def feature_matrix(g):
  63. '''
  64. constructs the feautre matrix (N x 3) for the enzymes datasets
  65. '''
  66. esm = nx.get_node_attributes(g, 'label')
  67. piazche = np.zeros((len(esm), 3))
  68. for i, (k, v) in enumerate(esm.items()):
  69. piazche[i][v-1] = 1
  70. return piazche
  71. def remove_random_node(graph, max_size=40, min_size=10):
  72. '''
  73. removes a random node from the gragh
  74. returns the remaining graph matrix and the removed node links
  75. '''
  76. if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
  77. return None
  78. relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
  79. choice = np.random.choice(list(relabeled_graph.nodes()))
  80. remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
  81. removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
  82. graph_length = len(remaining_graph)
  83. # source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
  84. # target_graph = np.copy(source_graph)
  85. removed_node_row = np.asarray(removed_node)[0]
  86. # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
  87. return remaining_graph, removed_node_row
"""
Layers:
GCN
"""
  92. class GraphConv(nn.Module):
  93. def __init__(self, input_dim, output_dim):
  94. super().__init__()
  95. self.input_dim = input_dim
  96. self.output_dim = output_dim
  97. self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
  98. self.relu = nn.ReLU()
  99. def forward(self, x, adj):
  100. '''
  101. x is hamun feature matrix
  102. adj ham ke is adjacency matrix of the graph
  103. '''
  104. y = torch.matmul(adj, x)
  105. print(y.shape)
  106. print(self.weight.shape)
  107. y = torch.matmul(y, self.weight.double())
  108. return y