
GraphTransformer.py

import math
import numpy as np
import pandas as pd
import networkx as nx
import scipy as sp
import seaborn as sns
import time
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch import Tensor

# Use the GPU when available, otherwise fall back to the CPU.
# (The original bare `torch.device('cuda')` call discarded its result.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  17. """
  18. Utils:
  19. Data Loader
  20. Feature Matrix Constructor
  21. Random Node Remover
  22. """

def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES',
                     node_attributes=True, graph_labels=True):
    '''
    Load many graphs (e.g. ENZYMES) stored in the TU-dataset text format.
    :return: a list of graphs
    '''
    print('Loading graph dataset: ' + str(name))
    G = nx.Graph()
    # load data
    # path = '../dataset/' + name + '/'
    path = '/content/gdrive/My Drive/' + name + '/'
    data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
    if node_attributes:
        data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
    data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
    data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
    if graph_labels:
        data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
    data_tuple = list(map(tuple, data_adj))
    # build one big graph, then attach per-node attributes and labels
    G.add_edges_from(data_tuple)
    for i in range(data_node_label.shape[0]):
        if node_attributes:
            G.add_node(i + 1, feature=data_node_att[i])
        G.add_node(i + 1, label=data_node_label[i])
    G.remove_nodes_from(list(nx.isolates(G)))
    # split the big graph into individual graphs using the graph indicator
    graph_num = data_graph_indicator.max()
    node_list = np.arange(data_graph_indicator.shape[0]) + 1
    graphs = []
    max_nodes = 0
    for i in range(graph_num):
        nodes = node_list[data_graph_indicator == i + 1]
        # copy() so the graph-level label is not written onto the shared
        # attribute dict of the subgraph view
        G_sub = G.subgraph(nodes).copy()
        if graph_labels:
            G_sub.graph['label'] = data_graph_labels[i]
        if min_num_nodes <= G_sub.number_of_nodes() <= max_num_nodes:
            graphs.append(G_sub)
            if G_sub.number_of_nodes() > max_nodes:
                max_nodes = G_sub.number_of_nodes()
    print('Loaded')
    return graphs

def feature_matrix(g):
    '''
    Constructs the one-hot feature matrix (N x 3) for the ENZYMES dataset.
    '''
    esm = nx.get_node_attributes(g, 'label')
    piazche = np.zeros((len(esm), 3))
    for i, (k, v) in enumerate(esm.items()):
        # node labels are 1-indexed, columns are 0-indexed
        piazche[i][v-1] = 1
    return piazche

# def remove_random_node(graph, max_size=40, min_size=10):
#     '''
#     removes a random node from the graph
#     returns the remaining graph matrix and the removed node links
#     '''
#     if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
#         return None
#     relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
#     choice = np.random.choice(list(relabeled_graph.nodes()))
#     remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
#     removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
#     graph_length = len(remaining_graph)
#     # source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
#     # target_graph = np.copy(source_graph)
#     removed_node_row = np.asarray(removed_node)[0]
#     # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
#     return remaining_graph, removed_node_row

def prepare_graph_data(graph, max_size=40, min_size=10):
    '''
    Gets a graph as input.
    Returns the adjacency matrix of the graph with one randomly removed node [0],
    its feature matrix [1], and the removed node's true links [2].
    '''
    if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
        return None
    relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
    choice = np.random.choice(list(relabeled_graph.nodes()))
    remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
    # nx.to_numpy_matrix was removed in networkx 3.x; to_numpy_array is the replacement
    remaining_graph_adj = nx.to_numpy_array(remaining_graph)
    graph_length = len(remaining_graph)
    # zero-pad the adjacency matrix up to max_size x max_size
    remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])
    # row of the removed node in the full adjacency matrix = its true links
    removed_node_row = nx.to_numpy_array(relabeled_graph)[choice]
    removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
    return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row
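
# Illustration (not from the original file): with the default max_size=40 and a
# qualifying ENZYMES graph g, the shapes returned above are
#   adj, feats, target = prepare_graph_data(g)
#   adj.shape    -> (40, 40)   # padded adjacency of the remaining graph
#   feats.shape  -> (n, 3)     # n = remaining node count; note: NOT padded
#   target.shape -> (40,)      # true links of the removed node, zero-padded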
  105. """"
  106. Layers:
  107. Graph Convolution
  108. Graph Multihead Attention
  109. Feed-Forward (as a MLP)
  110. """

class GraphConv(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # learnable weight, Xavier-initialized instead of left as
        # uninitialized memory as in the original
        self.weight = nn.Parameter(torch.empty(input_dim, output_dim, device=device))
        init.xavier_uniform_(self.weight)
        self.relu = nn.ReLU()

    def forward(self, x, adj):
        '''
        x is the feature matrix constructed by the feature_matrix function;
        adj is the adjacency matrix of the graph.
        '''
        # aggregate neighbor features, then project: (adj @ x) @ W;
        # cast inputs to float so they match the parameter dtype
        y = torch.matmul(adj.float(), x.float())
        y = torch.matmul(y, self.weight)
        return y
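
# GraphAttn.forward below calls an `attention` helper that the original file
# never defines. This is a minimal scaled dot-product attention sketch filling
# that gap; only the signature is dictated by the call site, the body is a
# standard implementation and therefore an assumption.
def attention(query, key, value, key_dim):
    # (bs, heads, seq, key_dim) @ (bs, heads, key_dim, seq) -> (bs, heads, seq, seq)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key_dim)
    scores = F.softmax(scores, dim=-1)
    # weight the values by the normalized attention scores
    return torch.matmul(scores, value)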

class GraphAttn(nn.Module):
    def __init__(self, heads, model_dim, dropout=0.1):
        super().__init__()
        self.model_dim = model_dim
        self.key_dim = model_dim // heads
        self.heads = heads
        self.q_linear = nn.Linear(model_dim, model_dim).to(device)
        self.v_linear = nn.Linear(model_dim, model_dim).to(device)
        self.k_linear = nn.Linear(model_dim, model_dim).to(device)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(model_dim, model_dim).to(device)

    def forward(self, query, key, value):
        bs = query.size(0)
        # project, then split the model dimension across the heads
        key = self.k_linear(key.float()).view(bs, -1, self.heads, self.key_dim)
        query = self.q_linear(query.float()).view(bs, -1, self.heads, self.key_dim)
        value = self.v_linear(value.float()).view(bs, -1, self.heads, self.key_dim)
        # move the head dimension in front of the sequence dimension
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)
        scores = attention(query, key, value, self.key_dim)
        # concatenate the heads back into a single model_dim vector per node
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.model_dim)
        output = self.out(concat)
        output = output.view(bs, self.model_dim)
        return output

class FeedForward(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.fully_connected1 = nn.Linear(self.input_size, self.hidden_size).to(device)
        self.relu = nn.ReLU()
        # single-logit head squashed to (0, 1): a per-node link probability
        self.fully_connected2 = nn.Linear(self.hidden_size, 1).to(device)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.fully_connected1(x.float())
        relu = self.relu(hidden)
        output = self.fully_connected2(relu)
        output = self.sigmoid(output)
        return output
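
# A minimal end-to-end sketch of how these pieces appear to fit together:
# embed node features with GraphConv, self-attend over the nodes, and score
# each slot's link to the removed node with the FeedForward head. The
# hyper-parameters (output_dim=16, heads=4, hidden_size=32) are assumptions,
# not values taken from the original file.
if __name__ == '__main__':
    graphs = Graph_load_batch(name='ENZYMES')
    sample = None
    while sample is None:
        # prepare_graph_data returns None for graphs outside [min_size, max_size)
        sample = prepare_graph_data(np.random.choice(graphs))
    adj, feats, target = sample
    # prepare_graph_data pads only the adjacency matrix, so pad the feature
    # matrix up to the same max_size before multiplying
    feats = np.pad(feats, [(0, adj.shape[0] - feats.shape[0]), (0, 0)])
    adj = torch.from_numpy(adj).to(device)
    feats = torch.from_numpy(feats).to(device)
    gcn = GraphConv(input_dim=3, output_dim=16)
    attn = GraphAttn(heads=4, model_dim=16)
    mlp = FeedForward(input_size=16, hidden_size=32)
    h = gcn(feats, adj)              # (max_size, 16) node embeddings
    h = attn(h, h, h)                # self-attention over the nodes
    link_probs = mlp(h).squeeze(-1)  # (max_size,) predicted links of the removed node
    print(link_probs.shape, target.shape)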