
my obsessiveness strikes back.

master
Yassaman Ommi, 4 years ago
parent commit 08f7b8518c
1 changed file with 88 additions and 88 deletions

GraphTransformer.py (+88 / -88)

import math
import time
import torch
import numpy as np
import scipy as sp
import pandas as pd
import torch.nn as nn
import networkx as nx
import seaborn as sns
from torch import Tensor
import torch.nn.init as init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


# select the GPU when available; note the original torch.device('cuda') call
# did not bind its result, so it had no effect on its own
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# the def line is missing from this hunk; the signature below is reconstructed
# from the parameters used in the body (function name and defaults are assumptions)
def load_graph_dataset(name='ENZYMES', min_num_nodes=20, max_num_nodes=1000, node_attributes=True, graph_labels=True):
    '''
    load many graphs, e.g. enzymes
    :return: a list of graphs
    '''
    print('Loading graph dataset: ' + str(name))
    G = nx.Graph()
    # load data
    # path = '../dataset/' + name + '/'
    path = '/content/gdrive/My Drive/' + name + '/'
    data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int)
    if node_attributes:
        data_node_att = np.loadtxt(path + name + '_node_attributes.txt', delimiter=',')
    data_node_label = np.loadtxt(path + name + '_node_labels.txt', delimiter=',').astype(int)
    data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int)
    if graph_labels:
        data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int)
    data_tuple = list(map(tuple, data_adj))
    G.add_edges_from(data_tuple)
    # attach features and labels to nodes (node ids in the raw files are 1-based)
    for i in range(data_node_label.shape[0]):
        if node_attributes:
            G.add_node(i + 1, feature=data_node_att[i])
        G.add_node(i + 1, label=data_node_label[i])
    G.remove_nodes_from(list(nx.isolates(G)))
    # split the merged graph into one subgraph per graph id
    graph_num = data_graph_indicator.max()
    node_list = np.arange(data_graph_indicator.shape[0]) + 1
    graphs = []
    max_nodes = 0
    for i in range(graph_num):
        nodes = node_list[data_graph_indicator == i + 1]
        G_sub = G.subgraph(nodes)
        if graph_labels:
            G_sub.graph['label'] = data_graph_labels[i]
        # keep only graphs inside the requested size window
        if G_sub.number_of_nodes() >= min_num_nodes and G_sub.number_of_nodes() <= max_num_nodes:
            graphs.append(G_sub)
            if G_sub.number_of_nodes() > max_nodes:
                max_nodes = G_sub.number_of_nodes()
    print('Loaded')

    return graphs
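A minimal usage sketch, assuming the reconstructed load_graph_dataset name above and that the ENZYMES files sit under the hard-coded Google Drive path:

graphs = load_graph_dataset(name='ENZYMES')
print(len(graphs))                # number of graphs kept within the size window
print(graphs[0].graph['label'])   # graph-level class label of the first graph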




def feature_matrix(g):
    '''
    constructs the feature matrix (N x 3) for the enzymes dataset
    '''
    esm = nx.get_node_attributes(g, 'label')   # {node: label}, labels in {1, 2, 3}
    piazche = np.zeros((len(esm), 3))
    # one-hot encode each node label
    for i, (k, v) in enumerate(esm.items()):
        piazche[i][v - 1] = 1

    return piazche
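A quick sanity check on a hand-built toy graph (hypothetical input, not from the dataset):

g = nx.Graph()
g.add_node(1, label=2)
g.add_node(2, label=1)
feature_matrix(g)   # array([[0., 1., 0.], [1., 0., 0.]])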




def remove_random_node(graph, max_size=40, min_size=10):
    '''
    gets a graph as an input
    returns a graph with a randomly removed node: its adj matrix [0], its feature matrix [1], the removed node's true links [2]
    '''
    if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
        return None
    relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)
    choice = np.random.choice(list(relabeled_graph.nodes()))
    remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))
    remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)
    graph_length = len(remaining_graph)
    # zero-pad the adjacency matrix out to a fixed max_size x max_size
    remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])
    # the removed node's row of the full adjacency matrix holds its true links
    removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]
    removed_node_row = np.asarray(removed_node)[0]
    removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])

    return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row
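A hedged usage sketch: sample one training example from a loaded graph. The function returns None for graphs outside the [min_size, max_size) window, so callers must filter:

sample = remove_random_node(graphs[0])
if sample is not None:
    adj, features, target_links = sample
    print(adj.shape)            # (40, 40), zero-padded
    print(target_links.shape)   # (40,), the removed node's true links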


"""" """"
Layers: Layers:


class GraphAttn(nn.Module):
    def __init__(self, heads, model_dim, dropout=0.1):
        super().__init__()
        self.model_dim = model_dim
        self.key_dim = model_dim // heads
        self.heads = heads

        self.q_linear = nn.Linear(model_dim, model_dim).cuda()
        self.v_linear = nn.Linear(model_dim, model_dim).cuda()
        self.k_linear = nn.Linear(model_dim, model_dim).cuda()

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(model_dim, model_dim).cuda()

    def forward(self, query, key, value):
        # print(q, k, v)
        bs = query.size(0)

        # project, then split the model dimension into (heads, key_dim)
        key = self.k_linear(key.float()).view(bs, -1, self.heads, self.key_dim)
        query = self.q_linear(query.float()).view(bs, -1, self.heads, self.key_dim)
        value = self.v_linear(value.float()).view(bs, -1, self.heads, self.key_dim)

        # move the head axis ahead of the sequence axis: (bs, heads, seq, key_dim)
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)

        scores = attention(query, key, value, self.key_dim)
        # concatenate the heads back into a single model_dim vector per position
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.model_dim)
        output = self.out(concat)
        output = output.view(bs, self.model_dim)

        return output
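forward calls an attention helper that never appears in this hunk. A minimal sketch of what it presumably computes, i.e. standard scaled dot-product attention; the name and signature come from the call site, the body is an assumption:

def attention(query, key, value, key_dim, dropout=None):
    # softmax(Q K^T / sqrt(d_k)) V, computed per head
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key_dim)
    scores = F.softmax(scores, dim=-1)
    if dropout is not None:
        scores = dropout(scores)
    return torch.matmul(scores, value)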




class FeedForward(nn.Module):
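The body of FeedForward is collapsed in this diff (unchanged lines). Judging from the FeedForward(input_size=model_dim, hidden_size=gcn_input) call in Hydra below, it is a small two-layer MLP; the sketch here is an assumption, not the file's actual code:

class FeedForward(nn.Module):
    # assumed body: project to hidden_size, apply ReLU and dropout, project back
    def __init__(self, input_size, hidden_size, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size).cuda()
        self.linear2 = nn.Linear(hidden_size, input_size).cuda()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))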




class Hydra(nn.Module):
    def __init__(self, gcn_input, model_dim, head):
        super().__init__()

        self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()
        self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()
        self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()

    def forward(self, x, adj):
        gcn_outputs = self.GCN(x, adj)
        # self-attention over the GCN output; GraphAttn.forward expects
        # (query, key, value), so the single-argument call in the diff would fail
        gat_output = self.GAT(gcn_outputs, gcn_outputs, gcn_outputs)
        mlp_output = self.MLP(gat_output).reshape(1, -1)

        return mlp_output
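GraphConv is also used here without being defined in this hunk. A minimal sketch of a plain graph-convolution layer matching the GraphConv(input_dim=..., output_dim=...) call; the body is an assumption (a basic A·X·W aggregation, without the usual adjacency normalization):

class GraphConv(nn.Module):
    # assumed implementation of the missing layer
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim) * 0.01)

    def forward(self, x, adj):
        support = torch.matmul(x.float(), self.weight)   # X W
        return torch.matmul(adj.float(), support)        # A (X W)

With the sketches above in place, a hedged end-to-end smoke test (requires a CUDA device, since the layers hard-code .cuda(); all dimensions are illustrative):

model = Hydra(gcn_input=3, model_dim=64, head=8)
x = torch.randn(40, 3).cuda()      # node features, e.g. one-hot enzyme labels
adj = torch.rand(40, 40).cuda()    # padded adjacency from remove_random_node
out = model(x, adj)                # flattened prediction, shape (1, 40 * 64)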
