
GraphTransformer.py

import math
import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import pandas as pd
import networkx as nx
import scipy as sp
import seaborn as sns
# from node2vec import Node2Vec  # optional: only needed for embedder() below
from sklearn.decomposition import PCA
import copy
import time

# Use the GPU when available; the rest of the script moves modules and tensors
# over with .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  16. """Utils:
  17. Data Loader / Attention / Clones / Embedder"""
def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):
    '''
    Load many graphs, e.g. ENZYMES.
    :return: a list of graphs
    '''
    print('Loading graph dataset: ' + str(name))
    G = nx.Graph()
    # load data
    # path = '../dataset/' + name + '/'
    path = '/content/gdrive/My Drive/' + name + '/'
    data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
    if node_attributes:
        data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
    data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
    data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
    if graph_labels:
        data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
    data_tuple = list(map(tuple, data_adj))
    G.add_edges_from(data_tuple)
    for i in range(data_node_label.shape[0]):
        if node_attributes:
            G.add_node(i + 1, feature=data_node_att[i])
        G.add_node(i + 1, label=data_node_label[i])
    G.remove_nodes_from(list(nx.isolates(G)))
    graph_num = data_graph_indicator.max()
    node_list = np.arange(data_graph_indicator.shape[0]) + 1
    graphs = []
    max_nodes = 0
    for i in range(graph_num):
        nodes = node_list[data_graph_indicator == i + 1]
        # copy() so each subgraph gets its own graph-attribute dict
        # (subgraph views share attributes with the parent graph)
        G_sub = G.subgraph(nodes).copy()
        if graph_labels:
            G_sub.graph['label'] = data_graph_labels[i]
        if G_sub.number_of_nodes() >= min_num_nodes and G_sub.number_of_nodes() <= max_num_nodes:
            graphs.append(G_sub)
            if G_sub.number_of_nodes() > max_nodes:
                max_nodes = G_sub.number_of_nodes()
    print('Loaded')
    return graphs
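# Expected on-disk layout (standard TU benchmark format, assuming the usual
# ENZYMES release): <name>_A.txt is a comma-separated edge list with 1-based
# node ids, <name>_graph_indicator.txt maps each node to its graph id,
# <name>_node_labels.txt / <name>_node_attributes.txt hold per-node labels and
# features, and <name>_graph_labels.txt holds one class label per graph.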
def attention(query, key, value, d_key):
    # scaled dot-product attention: softmax over the key dimension,
    # then weight the values
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_key)
    scores = nn.functional.softmax(scores, dim=-1)
    output = torch.matmul(scores, value)
    return output
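# Illustrative shape check for the helper above (not part of the original
# script): with q, k, v of shape (batch, heads, seq_len, d_k) the output keeps
# the same shape, e.g.
#   _q = torch.rand(2, 8, 5, 4)
#   attention(_q, _q, _q, 4).shape  # -> torch.Size([2, 8, 5, 4])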
def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def embedder(graph, dimensions=32, walk_length=8, num_walks=200, workers=4):
    # requires the node2vec package; uncomment the Node2Vec import above to use this
    node2vec = Node2Vec(graph, dimensions=dimensions, walk_length=walk_length, num_walks=num_walks, workers=workers)  # use temp_folder for big graphs
    model = node2vec.fit(window=10, min_count=1, batch_words=4)
    return model.wv.vectors
graphs = Graph_load_batch(min_num_nodes=10, name='ENZYMES')
# G = graphs[1]
# vecs = embedder(G)
# pca = PCA(n_components=2)
# principalComponents = pca.fit_transform(vecs)
# principalDf = pd.DataFrame(data=principalComponents,
#                            columns=['principal component 1', 'principal component 2'])
# principalDf.index = list(G.nodes())
# sns.scatterplot(principalDf['principal component 1'], principalDf['principal component 2'])
  77. """Sublayers"""
  78. class MultiHeadAttention(nn.Module):
  79. def __init__(self, heads, d_model, dropout = 0.1):
  80. super().__init__()
  81. self.d_model = d_model
  82. self.d_k = d_model // heads
  83. self.h = heads
  84. self.q_linear = nn.Linear(d_model, d_model).cuda()
  85. self.v_linear = nn.Linear(d_model, d_model).cuda()
  86. self.k_linear = nn.Linear(d_model, d_model).cuda()
  87. self.dropout = nn.Dropout(dropout)
  88. self.out = nn.Linear(d_model, d_model)
  89. def forward(self, q, k, v):
  90. # print(q, k, v)
  91. bs = q.size(0)
  92. # perform linear operation and split into h heads
  93. k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
  94. q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
  95. v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
  96. # transpose to get dimensions bs * h * sl * d_model
  97. k = k.transpose(1,2)
  98. q = q.transpose(1,2)
  99. v = v.transpose(1,2)
  100. scores = attention(q, k, v, self.d_k)
  101. # concatenate heads and put through final linear layer
  102. concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)
  103. output = self.out(concat)
  104. return output
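# Example usage (illustrative only; d_model must be divisible by heads, and the
# layer expects a CUDA device since its linears are created with .cuda()):
#   mha = MultiHeadAttention(heads=8, d_model=40).cuda()
#   x = torch.rand(3, 40, 40).cuda()   # (batch, seq_len, d_model)
#   mha(x, x, x).shape                 # -> torch.Size([3, 40, 40])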
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff).cuda()
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model).cuda()
    def forward(self, x):
        x = self.dropout(nn.functional.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x
class Norm(nn.Module):
    # layer normalisation with a learnable scale (alpha) and bias
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps
    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm
  125. """Layers"""
  126. class EncoderLayer(nn.Module):
  127. def __init__(self, d_model, heads, dropout = 0.1):
  128. super().__init__()
  129. self.norm_1 = Norm(d_model)
  130. self.norm_2 = Norm(d_model)
  131. self.attn = MultiHeadAttention(heads, d_model)
  132. self.ff = FeedForward(d_model)
  133. self.dropout_1 = nn.Dropout(dropout)
  134. self.dropout_2 = nn.Dropout(dropout)
  135. def forward(self, x):
  136. # x2 = self.norm_1(x)
  137. x = x + self.dropout_1(self.attn(x,x,x))
  138. # x2 = self.norm_2(x)
  139. x = x + self.dropout_2(self.ff(x))
  140. return x
class DecoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)
        self.attn_1 = MultiHeadAttention(heads, d_model)
        self.attn_2 = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model).cuda()
    def forward(self, x, e_outputs):
        # x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x, x, x))
        # x2 = self.norm_2(x)
        x = x + self.dropout_2(self.attn_2(x, e_outputs, e_outputs))
        # x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x))
        return x
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, heads):
        super().__init__()
        self.N = N
        self.layers = get_clones(EncoderLayer(d_model, heads), N)
        self.norm = Norm(d_model)
    def forward(self, src):
        x = src
        for i in range(self.N):
            x = self.layers[i](x)
        return self.norm(x)
class Decoder(nn.Module):
    def __init__(self, data_size, d_model, N, heads):
        super().__init__()
        self.N = N
        self.layers = get_clones(DecoderLayer(d_model, heads), N)
        self.norm = Norm(d_model)
    def forward(self, trg, e_outputs):
        x = trg
        for i in range(self.N):
            x = self.layers[i](x, e_outputs)
        return self.norm(x)
  184. """The Mighty Transformer"""
  185. class Transformer(nn.Module):
  186. def __init__(self, src_graph, trg_graph, d_model, N, heads):
  187. super().__init__()
  188. self.encoder = Encoder(src_graph, d_model, N, heads)
  189. self.decoder = Decoder(trg_graph, d_model, N, heads)
  190. self.out = nn.Linear(d_model, trg_graph)
  191. def forward(self, src, trg):
  192. e_outputs = self.encoder(src)
  193. d_output = self.decoder(trg, e_outputs)
  194. output = self.out(d_output)
  195. return output
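# Rough data flow (a sketch, assuming the padded-adjacency setup used below):
# src and trg are batches of max_size x max_size adjacency matrices, so each
# row acts as a d_model-dimensional token. The encoder self-attends over src,
# the decoder attends over trg and the encoder output, and self.out projects
# each decoder row from d_model to trg_graph logits.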
def remove_random_node(graph, max_size=40, min_size=10):
    if len(graph.nodes) >= max_size or len(graph.nodes) < min_size:
        return None
    relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
    choice = np.random.choice(list(relabeled_graph.nodes))
    remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes))))
    removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
    graph_length = len(remaining_graph)
    source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
    target_graph = np.copy(source_graph)
    removed_node_row = np.asarray(removed_node)[0]
    target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
    return source_graph, target_graph
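# What remove_random_node returns (a sketch): for a graph with n nodes
# (min_size <= n < max_size), one node is dropped, the remaining (n-1)x(n-1)
# adjacency is zero-padded to max_size x max_size (the source), and the target
# is the same matrix with the removed node's original adjacency row (padded)
# written into row n-1. E.g. a 12-node graph yields a pair of 40x40 matrices.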
converted_graphs = list(filter(lambda x: x is not None, [remove_random_node(graph) for graph in graphs]))
source_graphs = torch.Tensor(np.stack([graph[0] for graph in converted_graphs]))
target_graphs = torch.Tensor(np.stack([graph[1] for graph in converted_graphs]))

d_model = 40
heads = 8
N = 6
src_size = len(source_graphs)
trg_size = len(target_graphs)

model = Transformer(src_size, trg_size, d_model, N, heads).cuda()
# print(model)
optim = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
def train_model(epoch, print_every=100):
    model.train()
    start = time.time()
    temp = start
    total_loss = 0
    for i in range(epoch):
        src = source_graphs.cuda()
        trg = target_graphs.cuda()
        preds = model(src.float(), trg.float())
        optim.zero_grad()
        loss = torch.nn.functional.cross_entropy(preds.view(preds.size(-1), -1), trg.view(trg.size(0), -1))
        loss.backward()
        optim.step()
        total_loss += loss.item()
        if (i + 1) % print_every == 0:
            loss_avg = total_loss / print_every
            print("time = %dm, epoch %d, iter = %d, loss = %.3f, %ds per %d iters" %
                  ((time.time() - start) // 60, epoch + 1, i + 1, loss_avg,
                   time.time() - temp, print_every))
            total_loss = 0
            temp = time.time()

train_model(1, 1)
# preds = model(source_graphs[0].cuda(), target_graphs[0].cuda())
# loss = torch.nn.functional.cross_entropy(preds.view(preds.size(-1), -1), target_graphs.view(target_graphs.size(0), -1))