You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

GraphTransformer.py 9.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. import math
  2. import time
  3. import torch
  4. import numpy as np
  5. import scipy as sp
  6. import pandas as pd
  7. import torch.nn as nn
  8. import networkx as nx
  9. import seaborn as sns
  10. from torch import Tensor
  11. import torch.nn.init as init
  12. import torch.nn.functional as F
  13. from torch.nn.parameter import Parameter
  14. from torch.nn.modules.module import Module
  15. if torch.cuda.is_available():
  16. torch.device('cuda')
  17. """
  18. Utils:
  19. Data Loader
  20. Feature Matrix Constructor
  21. Random Node Remover
  22. """
  23. def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):
  24. '''
  25. load many graphs, e.g. enzymes
  26. :return: a list of graphs
  27. '''
  28. print('Loading graph dataset: ' + str(name))
  29. G = nx.Graph()
  30. # load data
  31. # path = '../dataset/' + name + '/'
  32. path = '/content/gdrive/My Drive/' + name + '/'
  33. data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
  34. if node_attributes:
  35. data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
  36. data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
  37. data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
  38. if graph_labels:
  39. data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
  40. data_tuple = list(map(tuple, data_adj))
  41. G.add_edges_from(data_tuple)
  42. for i in range(data_node_label.shape[0]):
  43. if node_attributes:
  44. G.add_node(i + 1, feature=data_node_att[i])
  45. G.add_node(i + 1, label=data_node_label[i])
  46. G.remove_nodes_from(list(nx.isolates(G)))
  47. graph_num = data_graph_indicator.max()
  48. node_list = np.arange(data_graph_indicator.shape[0]) + 1
  49. graphs = []
  50. max_nodes = 0
  51. for i in range(graph_num):
  52. nodes = node_list[data_graph_indicator == i + 1]
  53. G_sub = G.subgraph(nodes)
  54. if graph_labels:
  55. G_sub.graph['label'] = data_graph_labels[i]
  56. if G_sub.number_of_nodes() >= min_num_nodes and G_sub.number_of_nodes() <= max_num_nodes:
  57. graphs.append(G_sub)
  58. if G_sub.number_of_nodes() > max_nodes:
  59. max_nodes = G_sub.number_of_nodes()
  60. print('Loaded')
  61. return graphs
  62. def feature_matrix(g):
  63. '''
  64. constructs the feautre matrix (N x 3) for the enzymes datasets
  65. '''
  66. esm = nx.get_node_attributes(g, 'label')
  67. piazche = np.zeros((len(esm), 3))
  68. for i, (k, v) in enumerate(esm.items()):
  69. piazche[i][v-1] = 1
  70. return piazche
  71. # def remove_random_node(graph, max_size=40, min_size=10):
  72. # '''
  73. # removes a random node from the gragh
  74. # returns the remaining graph matrix and the removed node links
  75. # '''
  76. # if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
  77. # return None
  78. # relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
  79. # choice = np.random.choice(list(relabeled_graph.nodes()))
  80. # remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
  81. # removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
  82. # graph_length = len(remaining_graph)
  83. # # source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
  84. # # target_graph = np.copy(source_graph)
  85. # removed_node_row = np.asarray(removed_node)[0]
  86. # # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
  87. # return remaining_graph, removed_node_row
  88. def prepare_graph_data(graph, max_size=40, min_size=10):
  89. '''
  90. gets a graph as an input
  91. returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2]
  92. '''
  93. if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
  94. return None
  95. relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
  96. choice = np.random.choice(list(relabeled_graph.nodes()))
  97. remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
  98. remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
  99. graph_length = len(remaining_graph)
  100. remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])
  101. removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
  102. removed_node_row = np.asarray(removed_node)[0]
  103. removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
  104. return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row
  105. """"
  106. Layers:
  107. Graph Convolution
  108. Graph Multihead Attention
  109. Feed-Forward (as a MLP)
  110. """
  111. class GraphConv(nn.Module):
  112. def __init__(self, input_dim, output_dim):
  113. super().__init__()
  114. self.input_dim = input_dim
  115. self.output_dim = output_dim
  116. self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
  117. self.relu = nn.ReLU()
  118. def forward(self, x, adj):
  119. '''
  120. x is the feature matrix constructed in feature_matrix function
  121. adj ham ke is adjacency matrix of the graph
  122. '''
  123. y = torch.matmul(adj, x)
  124. # print(y.shape)
  125. # print(self.weight.shape)
  126. y = torch.matmul(y, self.weight.double())
  127. return y
  128. class GraphAttn(nn.Module):
  129. def __init__(self, heads, model_dim, dropout=0.1):
  130. super().__init__()
  131. self.model_dim = model_dim
  132. self.key_dim = model_dim // heads
  133. self.heads = heads
  134. self.q_linear = nn.Linear(model_dim, model_dim).cuda()
  135. self.v_linear = nn.Linear(model_dim, model_dim).cuda()
  136. self.k_linear = nn.Linear(model_dim, model_dim).cuda()
  137. self.dropout = nn.Dropout(dropout)
  138. self.out = nn.Linear(model_dim, model_dim).cuda()
  139. def forward(self, query, key, value):
  140. # print(q, k, v)
  141. bs = query.size(0)
  142. key = self.k_linear(key.float()).view(bs, -1, self.heads, self.key_dim)
  143. query = self.q_linear(query.float()).view(bs, -1, self.heads, self.key_dim)
  144. value = self.v_linear(value.float()).view(bs, -1, self.heads, self.key_dim)
  145. key = key.transpose(1,2)
  146. query = query.transpose(1,2)
  147. value = value.transpose(1,2)
  148. scores = attention(query, key, value, self.key_dim)
  149. concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)
  150. output = self.out(concat)
  151. output = output.view(bs, self.model_dim)
  152. return output
  153. class FeedForward(nn.Module):
  154. def __init__(self, input_size, hidden_size):
  155. super().__init__()
  156. self.input_size = input_size
  157. self.hidden_size = hidden_size
  158. self.fully_connected1 = nn.Linear(self.input_size, self.hidden_size).cuda()
  159. self.relu = nn.ReLU()
  160. self.fully_connected2 = nn.Linear(self.hidden_size, 1).cuda()
  161. self.sigmoid = nn.Sigmoid()
  162. def forward(self, x):
  163. hidden = self.fully_connected1(x.float())
  164. relu = self.relu(hidden)
  165. output = self.fully_connected2(relu)
  166. output = self.sigmoid(output)
  167. return output
  168. class Hydra(nn.Module):
  169. def __init__(self, gcn_input, model_dim, head):
  170. super().__init__()
  171. self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()
  172. self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()
  173. self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()
  174. def forward(self, x, adj):
  175. gcn_outputs = self.GCN(x, adj)
  176. gat_output = self.GAT(gcn_outputs)
  177. mlp_output = self.MLP(gat_output).reshape(1,-1)
  178. return mlp_output
  179. """"
  180. Train the Model
  181. Prepare data using DataLoader
  182. (data can't be batched)
  183. """
  184. def build_model(gcn_input, model_dim, head):
  185. model = Hydra(gcn_input, model_dim, head).cuda()
  186. return model
  187. def fn(batch):
  188. return batch[0]
  189. def train_model(model, trainloader, epoch, print_every=100):
  190. optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
  191. model.train()
  192. start = time.time()
  193. temp = start
  194. total_loss = 0
  195. for i in range(epoch):
  196. for batch, data in enumerate(trainloader, 0):
  197. adj, features, true_links = data
  198. adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()
  199. # print(adj.shape)
  200. # print(features.shape)
  201. # print(true_links.shape)
  202. preds = model(features, adj)
  203. optim.zero_grad()
  204. loss = F.binary_cross_entropy(preds.double(), true_links.double())
  205. loss.backward()
  206. optim.step()
  207. total_loss += loss.item()
  208. if (i + 1) % print_every == 0:
  209. loss_avg = total_loss / print_every
  210. print("time = %dm, epoch %d, iter = %d, loss = %.3f,\
  211. %ds per %d iters" % ((time.time() - start) // 60,\
  212. epoch + 1, i + 1, loss_avg, time.time() - temp,\
  213. print_every))
  214. total_loss = 0
  215. temp = time.time()
  216. # prepare data
  217. # coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], [])
  218. coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs]))
  219. trainloader = torch.utils.data.DataLoader(coop, collate_fn=fn, batch_size=1)
  220. model = build_model(3, 243, 9)
  221. train_model(model, trainloader, 100, 10)