You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

GraphTransformer.py 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import math
  2. import numpy as np
  3. import pandas as pd
  4. import networkx as nx
  5. import scipy as sp
  6. import seaborn as sns
  7. import time
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.init as init
  11. import torch.nn.functional as F
  12. from torch.nn.parameter import Parameter
  13. from torch.nn.modules.module import Module
  14. from torch import Tensor
  15. if torch.cuda.is_available():
  16. torch.device('cuda')
  17. """
  18. Utils:
  19. Data Loader
  20. Feature Matrix Constructor
  21. Random Node Remover
  22. """
  23. def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):
  24. '''
  25. load many graphs, e.g. enzymes
  26. :return: a list of graphs
  27. '''
  28. print('Loading graph dataset: ' + str(name))
  29. G = nx.Graph()
  30. # load data
  31. # path = '../dataset/' + name + '/'
  32. path = '/content/gdrive/My Drive/' + name + '/'
  33. data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
  34. if node_attributes:
  35. data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
  36. data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
  37. data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
  38. if graph_labels:
  39. data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
  40. data_tuple = list(map(tuple, data_adj))
  41. G.add_edges_from(data_tuple)
  42. for i in range(data_node_label.shape[0]):
  43. if node_attributes:
  44. G.add_node(i + 1, feature=data_node_att[i])
  45. G.add_node(i + 1, label=data_node_label[i])
  46. G.remove_nodes_from(list(nx.isolates(G)))
  47. graph_num = data_graph_indicator.max()
  48. node_list = np.arange(data_graph_indicator.shape[0]) + 1
  49. graphs = []
  50. max_nodes = 0
  51. for i in range(graph_num):
  52. nodes = node_list[data_graph_indicator == i + 1]
  53. G_sub = G.subgraph(nodes)
  54. if graph_labels:
  55. G_sub.graph['label'] = data_graph_labels[i]
  56. if G_sub.number_of_nodes() >= min_num_nodes and G_sub.number_of_nodes() <= max_num_nodes:
  57. graphs.append(G_sub)
  58. if G_sub.number_of_nodes() > max_nodes:
  59. max_nodes = G_sub.number_of_nodes()
  60. print('Loaded')
  61. return graphs
  62. def feature_matrix(g):
  63. '''
  64. constructs the feautre matrix (N x 3) for the enzymes datasets
  65. '''
  66. esm = nx.get_node_attributes(g, 'label')
  67. piazche = np.zeros((len(esm), 3))
  68. for i, (k, v) in enumerate(esm.items()):
  69. piazche[i][v-1] = 1
  70. return piazche
  71. def remove_random_node(graph, max_size=40, min_size=10):
  72. '''
  73. removes a random node from the gragh
  74. returns the remaining graph matrix and the removed node links
  75. '''
  76. if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
  77. return None
  78. relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
  79. choice = np.random.choice(list(relabeled_graph.nodes()))
  80. remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))
  81. removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
  82. graph_length = len(remaining_graph)
  83. # source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])
  84. # target_graph = np.copy(source_graph)
  85. removed_node_row = np.asarray(removed_node)[0]
  86. # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])
  87. return remaining_graph, removed_node_row
  88. """"
  89. Layers:
  90. Graph Convolution
  91. Graph Multihead Attention
  92. """
  93. class GraphConv(nn.Module):
  94. def __init__(self, input_dim, output_dim):
  95. super().__init__()
  96. self.input_dim = input_dim
  97. self.output_dim = output_dim
  98. self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
  99. self.relu = nn.ReLU()
  100. def forward(self, x, adj):
  101. '''
  102. x is the feature matrix constructed in feature_matrix function
  103. adj ham ke is adjacency matrix of the graph
  104. '''
  105. y = torch.matmul(adj, x)
  106. # print(y.shape)
  107. # print(self.weight.shape)
  108. y = torch.matmul(y, self.weight.double())
  109. return y
  110. class GraphAttn(nn.Module):
  111. def __init__(self, heads, model_dim, dropout=0.1):
  112. super().__init__()
  113. self.model_dim = model_dim
  114. self.key_dim = model_dim // heads
  115. self.heads = heads
  116. self.q_linear = nn.Linear(model_dim, model_dim).cuda()
  117. self.v_linear = nn.Linear(model_dim, model_dim).cuda()
  118. self.k_linear = nn.Linear(model_dim, model_dim).cuda()
  119. self.dropout = nn.Dropout(dropout)
  120. self.out = nn.Linear(model_dim, model_dim).cuda()
  121. def forward(self, query, key, value):
  122. # print(q, k, v)
  123. bs = query.size(0) # size of the graph
  124. key = self.k_linear(key).view(bs, -1, self.heads, self.key_dim)
  125. query = self.q_linear(query).view(bs, -1, self.heads, self.key_dim)
  126. value = self.v_linear(value).view(bs, -1, self.heads, self.key_dim)
  127. key = key.transpose(1,2)
  128. query = query.transpose(1,2)
  129. value = value.transpose(1,2)
  130. scores = attention(query, key, value, self.key_dim)
  131. concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)
  132. output = self.out(concat)
  133. return output