You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

GraphTransformer.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import math
  2. import numpy as np
  3. import pandas as pd
  4. import networkx as nx
  5. import scipy as sp
  6. import seaborn as sns
  7. import time
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.init as init
  11. import torch.nn.functional as F
  12. from torch.nn.parameter import Parameter
  13. from torch.nn.modules.module import Module
  14. from torch import Tensor
  15. if torch.cuda.is_available():
  16. torch.device('cuda')
  17. """
  18. Utils:
  19. Data Loader
  20. Feature Matrix Constructor
  21. Random Node Remover
  22. """
  23. def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):
  24. '''
  25. load many graphs, e.g. enzymes
  26. :return: a list of graphs
  27. '''
  28. print('Loading graph dataset: ' + str(name))
  29. G = nx.Graph()
  30. # load data
  31. # path = '../dataset/' + name + '/'
  32. path = '/content/gdrive/My Drive/' + name + '/'
  33. data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
  34. if node_attributes:
  35. data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
  36. data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
  37. data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
  38. if graph_labels:
  39. data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
  40. data_tuple = list(map(tuple, data_adj))
  41. G.add_edges_from(data_tuple)
  42. for i in range(data_node_label.shape[0]):
  43. if node_attributes:
  44. G.add_node(i + 1, feature=data_node_att[i])
  45. G.add_node(i + 1, label=data_node_label[i])
  46. G.remove_nodes_from(list(nx.isolates(G)))
  47. graph_num = data_graph_indicator.max()
  48. node_list = np.arange(data_graph_indicator.shape[0]) + 1
  49. graphs = []
  50. max_nodes = 0
  51. for i in range(graph_num):
  52. nodes = node_list[data_graph_indicator == i + 1]
  53. G_sub = G.subgraph(nodes)
  54. if graph_labels:
  55. G_sub.graph['label'] = data_graph_labels[i]
  56. if G_sub.number_of_nodes() >= min_num_nodes and G_sub.number_of_nodes() <= max_num_nodes:
  57. graphs.append(G_sub)
  58. if G_sub.number_of_nodes() > max_nodes:
  59. max_nodes = G_sub.number_of_nodes()
  60. print('Loaded')
  61. return graphs
  62. def feature_matrix(g):
  63. '''
  64. constructs the feautre matrix (N x 3) for the enzymes datasets
  65. '''
  66. esm = nx.get_node_attributes(g, 'label')
  67. piazche = np.zeros((len(esm), 3))
  68. for i, (k, v) in enumerate(esm.items()):
  69. piazche[i][v-1] = 1
  70. return piazche
  71. # def remove_random_node(graph, max_size=40, min_size=10):
  72. # '''
# removes a random node from the graph
  74. # returns the remaining graph matrix and the removed node links
  75. # '''
  76. # if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
  77. # return None
  78. # relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
  79. # choice = np.random.choice(list(relabeled_graph.nodes()))
  80. # remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
  81. # removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
  82. # graph_length = len(remaining_graph)
  83. # # source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
  84. # # target_graph = np.copy(source_graph)
  85. # removed_node_row = np.asarray(removed_node)[0]
  86. # # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
  87. # return remaining_graph, removed_node_row
  88. def prepare_graph_data(graph, max_size=40, min_size=10):
  89. '''
  90. gets a graph as an input
  91. returns a graph with a randomly removed node adj matrix [0], its feature matrix [0], the removed node true links [2]
  92. '''
  93. if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
  94. return None
  95. relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
  96. choice = np.random.choice(list(relabeled_graph.nodes()))
  97. remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
  98. remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
  99. removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
  100. removed_node_row = np.asarray(removed_node)[0]
  101. return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row
"""
  103. Layers:
  104. Graph Convolution
  105. Graph Multihead Attention
  106. """
  107. class GraphConv(nn.Module):
  108. def __init__(self, input_dim, output_dim):
  109. super().__init__()
  110. self.input_dim = input_dim
  111. self.output_dim = output_dim
  112. self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
  113. self.relu = nn.ReLU()
  114. def forward(self, x, adj):
  115. '''
  116. x is the feature matrix constructed in feature_matrix function
  117. adj ham ke is adjacency matrix of the graph
  118. '''
  119. y = torch.matmul(adj, x)
  120. # print(y.shape)
  121. # print(self.weight.shape)
  122. y = torch.matmul(y, self.weight.double())
  123. return y
  124. class GraphAttn(nn.Module):
  125. def __init__(self, heads, model_dim, dropout=0.1):
  126. super().__init__()
  127. self.model_dim = model_dim
  128. self.key_dim = model_dim // heads
  129. self.heads = heads
  130. self.q_linear = nn.Linear(model_dim, model_dim).cuda()
  131. self.v_linear = nn.Linear(model_dim, model_dim).cuda()
  132. self.k_linear = nn.Linear(model_dim, model_dim).cuda()
  133. self.dropout = nn.Dropout(dropout)
  134. self.out = nn.Linear(model_dim, model_dim).cuda()
  135. def forward(self, query, key, value):
  136. # print(q, k, v)
  137. bs = query.size(0) # size of the graph
  138. key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim)
  139. query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim)
  140. value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim)
  141. key = key.transpose(1,2)
  142. query = query.transpose(1,2)
  143. value = value.transpose(1,2)
  144. scores = attention(query, key, value, self.key_dim)
  145. concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)
  146. output = self.out(concat)
  147. output = output.view(bs, self.model_dim)
  148. return output