
padding added

master
Yassaman Ommi 4 years ago
parent commit cbaad92456
1 changed file with 21 additions and 1 deletion

GraphTransformer.py  (+21 -1)

    print('Loaded')
    return graphs



def feature_matrix(g):
    '''
    constructs the feature matrix (N x 3) for the ENZYMES dataset
    piazche[i][v-1] = 1
    return piazche
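
The body between the docstring and the two tail lines above is elided by the diff hunk; a minimal sketch of how that one-hot construction might look, assuming the ENZYMES node label is stored as an integer 1-3 under a `label` attribute (the attribute name and the tensor setup are assumptions, not shown in the commit):

import torch
import networkx as nx

def feature_matrix_sketch(g: nx.Graph) -> torch.Tensor:
    # N x 3 one-hot matrix: row i gets a 1 in column v-1 for a node whose label is v
    piazche = torch.zeros(g.number_of_nodes(), 3)
    for i, (node, data) in enumerate(g.nodes(data=True)):
        v = data['label']        # assumed attribute name for the ENZYMES node label
        piazche[i][v - 1] = 1
    return piazche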



# def remove_random_node(graph, max_size=40, min_size=10):
#     '''
#     removes a random node from the graph
def prepare_graph_data(graph, max_size=40, min_size=10):
    '''
    gets a graph as an input
    returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2]
    '''
    if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
        return None
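
Only the signature, docstring, and size guard of prepare_graph_data are visible in this hunk. Given the commit message ("padding added") and the documented return order, a hedged sketch of what the elided body could do: drop one random node, zero-pad the remaining adjacency and feature matrices to max_size, and return the removed node's true links. Every name and step below beyond the documented signature is an assumption:

import random
import numpy as np
import networkx as nx
import torch

def prepare_graph_data_sketch(graph, max_size=40, min_size=10):
    if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:
        return None
    nodes = list(graph.nodes())
    target = random.choice(nodes)                  # node to remove (assumed uniform choice)
    full_adj = nx.to_numpy_array(graph)
    true_links = full_adj[nodes.index(target)]     # links of the removed node [2]
    remaining = graph.copy()
    remaining.remove_node(target)
    adj = nx.to_numpy_array(remaining)
    n = adj.shape[0]
    padded_adj = np.zeros((max_size, max_size))    # "padding added": zero-pad to max_size
    padded_adj[:n, :n] = adj
    features = feature_matrix(remaining)           # N x 3 features from the helper above
    padded_feat = torch.zeros(max_size, features.shape[1])
    padded_feat[:n] = features
    return torch.from_numpy(padded_adj), padded_feat, torch.from_numpy(true_links)
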
        y = torch.matmul(y, self.weight.double())
        return y
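
The torch.matmul(y, self.weight.double()) tail above belongs to a graph-convolution layer; a self-contained sketch of such a layer under the simple propagation rule y = A · X · W (the initialization and the neighbour-aggregation matmul are assumptions; only the final projection and return appear in the diff):

import torch
import torch.nn as nn

class GraphConvSketch(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # learnable projection of aggregated node features, shape (input_dim, output_dim)
        self.weight = nn.Parameter(torch.empty(input_dim, output_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, adj):
        # aggregate neighbour features, then project: y = A @ X @ W
        y = torch.matmul(adj.double(), x.double())
        y = torch.matmul(y, self.weight.double())
        return y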



class GraphAttn(nn.Module):
    def __init__(self, heads, model_dim, dropout=0.1):
        super().__init__()


        return output
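
Of GraphAttn only the constructor signature and the final return are visible here. Since the file is a graph transformer, a plausible sketch is standard multi-head self-attention over the node embeddings produced by the GCN stage; every layer name below is an assumption:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttnSketch(nn.Module):
    def __init__(self, heads, model_dim, dropout=0.1):
        super().__init__()
        assert model_dim % heads == 0
        self.heads = heads
        self.d_k = model_dim // heads
        self.q_linear = nn.Linear(model_dim, model_dim)
        self.k_linear = nn.Linear(model_dim, model_dim)
        self.v_linear = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(model_dim, model_dim)

    def forward(self, x):
        # x: (batch, nodes, model_dim) node embeddings
        bs = x.size(0)
        q = self.q_linear(x).view(bs, -1, self.heads, self.d_k).transpose(1, 2)
        k = self.k_linear(x).view(bs, -1, self.heads, self.d_k).transpose(1, 2)
        v = self.v_linear(x).view(bs, -1, self.heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(bs, -1, self.heads * self.d_k)
        output = self.out(context)
        return output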



class FeedForward(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        output = self.fully_connected2(relu)
        output = self.sigmoid(output)
        return output
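
The FeedForward fragment names fully_connected2, a ReLU activation, and a sigmoid; a minimal sketch consistent with those names and with how Hydra sizes the module below (fully_connected1 and the layer widths are assumptions):

import torch.nn as nn

class FeedForwardSketch(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fully_connected1 = nn.Linear(input_size, hidden_size)
        self.fully_connected2 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        relu = self.relu(self.fully_connected1(x))
        output = self.fully_connected2(relu)
        output = self.sigmoid(output)
        return output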


class Hydra(nn.Module):
    def __init__(self, gcn_input, model_dim, head):
        super().__init__()

        self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()
        self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()
        self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()

    def forward(self, x, adj):
        gcn_outputs = self.GCN(x, adj)
        gat_output = self.GAT(gcn_outputs)
        mlp_output = self.MLP(gat_output)

        return mlp_output
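
A hypothetical smoke test for the Hydra pipeline added in this commit; the dimensions, dtype, and the presence of a CUDA device are all assumptions, since the diff does not show how Hydra is invoked:

import torch

# gcn_input=3 matches the N x 3 ENZYMES features; model_dim and head are arbitrary here
model = Hydra(gcn_input=3, model_dim=16, head=4)

x = torch.rand(1, 40, 3).double().cuda()       # padded node features (batch, max_size, 3)
adj = torch.rand(1, 40, 40).double().cuda()    # padded adjacency (batch, max_size, max_size)

out = model(x, adj)                            # GCN -> GAT -> MLP
print(out.shape)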
