|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):\n", |
|
|
"def Graph_load_batch(min_num_nodes=20, max_num_nodes=1000, name='ENZYMES', node_attributes=True, graph_labels=True):\n", |
|
|
" '''\n", |
|
|
|
|
|
" load many graphs, e.g. enzymes\n", |
|
|
|
|
|
" :return: a list of graphs\n", |
|
|
|
|
|
" '''\n", |
|
|
|
|
|
" print('Loading graph dataset: ' + str(name))\n", |
|
|
" print('Loading graph dataset: ' + str(name))\n", |
|
|
" G = nx.Graph()\n", |
|
|
" G = nx.Graph()\n", |
|
|
" # load data\n", |
|
|
" # load data\n", |
|
|
|
|
|
|
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
"metadata": {}, |
|
|
"metadata": {}, |
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
"source": [ |
|
|
"Constructing" |
|
|
|
|
|
|
|
|
"Constructing feature matrix of a graph" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def feature_matrix(g, max_nodes=40):\n", |
|
|
"def feature_matrix(g, max_nodes=40):\n", |
|
|
" esm = nx.get_node_attributes(g, 'label')\n", |
|
|
|
|
|
" piazche = np.zeros((max_nodes, 3))\n", |
|
|
|
|
|
" for i, (k, v) in enumerate(esm.items()):\n", |
|
|
|
|
|
" # print(i, k , v)\n", |
|
|
|
|
|
" piazche[i][v-1] = 1\n", |
|
|
|
|
|
" return piazche" |
|
|
|
|
|
|
|
|
" esm = nx.get_node_attributes(g, 'label')\n", |
|
|
|
|
|
" piazche = np.zeros((max_nodes, 3))\n", |
|
|
|
|
|
" for i, (k, v) in enumerate(esm.items()):\n", |
|
|
|
|
|
" # print(i, k , v)\n", |
|
|
|
|
|
" piazche[i][v-1] = 1\n", |
|
|
|
|
|
" return piazche" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Removing a random node from a graph\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"Returns remaining graph with removed node links" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 12, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "PkeAYLDQ2W5D" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def remove_random_node(graph, max_size=40, min_size=10):\n", |
|
|
|
|
|
" if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:\n", |
|
|
|
|
|
" return None\n", |
|
|
|
|
|
" relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)\n", |
|
|
|
|
|
" choice = np.random.choice(list(relabeled_graph.nodes()))\n", |
|
|
|
|
|
" remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))\n", |
|
|
|
|
|
" removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]\n", |
|
|
|
|
|
" graph_length = len(remaining_graph)\n", |
|
|
|
|
|
" source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])\n", |
|
|
|
|
|
" # target_graph = np.copy(source_graph)\n", |
|
|
|
|
|
" removed_node_row = np.asarray(removed_node)[0]\n", |
|
|
|
|
|
" # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])\n", |
|
|
|
|
|
" return remaining_graph, removed_node_row" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Prepare graphs for the model\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2] " |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 13, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "FXFQJIvE2Ync" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def prepare_graph_data(graph, max_size=40, min_size=10):\n", |
|
|
|
|
|
" if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:\n", |
|
|
|
|
|
" return None\n", |
|
|
|
|
|
" relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)\n", |
|
|
|
|
|
" choice = np.random.choice(list(relabeled_graph.nodes()))\n", |
|
|
|
|
|
" remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))\n", |
|
|
|
|
|
" remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)\n", |
|
|
|
|
|
" graph_length = len(remaining_graph)\n", |
|
|
|
|
|
" remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])\n", |
|
|
|
|
|
" removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]\n", |
|
|
|
|
|
" removed_node_row = np.asarray(removed_node)[0]\n", |
|
|
|
|
|
" removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])\n", |
|
|
|
|
|
" return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 10, |
|
|
|
|
|
|
|
|
"execution_count": 30, |
|
|
"metadata": { |
|
|
"metadata": { |
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/", |
|
|
|
|
|
"height": 34 |
|
|
|
|
|
}, |
|
|
|
|
|
|
|
|
"colab": {}, |
|
|
"colab_type": "code", |
|
|
"colab_type": "code", |
|
|
"id": "KdO6zZruosHe", |
|
|
|
|
|
"outputId": "e226f7eb-b285-4e02-b576-2b2998c9240f" |
|
|
|
|
|
|
|
|
"id": "tQdkjvY22_Kf" |
|
|
}, |
|
|
}, |
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"469" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 10, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"tags": [] |
|
|
|
|
|
}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"# len(train)" |
|
|
|
|
|
|
|
|
"# coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], [])\n", |
|
|
|
|
|
"coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in train]))\n", |
|
|
|
|
|
"dale = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in test]))" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 12, |
|
|
|
|
|
|
|
|
"execution_count": 47, |
|
|
"metadata": { |
|
|
"metadata": { |
|
|
"colab": {}, |
|
|
"colab": {}, |
|
|
"colab_type": "code", |
|
|
"colab_type": "code", |
|
|
"id": "PkeAYLDQ2W5D" |
|
|
|
|
|
|
|
|
"id": "qChhpZuCpHWv" |
|
|
}, |
|
|
}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def remove_random_node(graph, max_size=40, min_size=10):\n", |
|
|
|
|
|
" if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:\n", |
|
|
|
|
|
" return None\n", |
|
|
|
|
|
" relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)\n", |
|
|
|
|
|
" choice = np.random.choice(list(relabeled_graph.nodes()))\n", |
|
|
|
|
|
" remaining_graph = nx.to_numpy_matrix(relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes()))))\n", |
|
|
|
|
|
" removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]\n", |
|
|
|
|
|
" graph_length = len(remaining_graph)\n", |
|
|
|
|
|
" source_graph = np.pad(remaining_graph, [(0, max_size - graph_length), (0, max_size - graph_length)])\n", |
|
|
|
|
|
" # target_graph = np.copy(source_graph)\n", |
|
|
|
|
|
" removed_node_row = np.asarray(removed_node)[0]\n", |
|
|
|
|
|
" # target_graph[graph_length] = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])\n", |
|
|
|
|
|
" return remaining_graph, removed_node_row" |
|
|
|
|
|
|
|
|
"trainloader = torch.utils.data.DataLoader(coop, collate_fn=lambda x: x[0], batch_size=1)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 13, |
|
|
|
|
|
|
|
|
"execution_count": 32, |
|
|
"metadata": { |
|
|
"metadata": { |
|
|
"colab": {}, |
|
|
"colab": {}, |
|
|
"colab_type": "code", |
|
|
"colab_type": "code", |
|
|
"id": "FXFQJIvE2Ync" |
|
|
|
|
|
|
|
|
"id": "8FHvrVFlqNh2" |
|
|
}, |
|
|
}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def prepare_graph_data(graph, max_size=40, min_size=10):\n", |
|
|
|
|
|
" '''\n", |
|
|
|
|
|
" gets a graph as an input\n", |
|
|
|
|
|
" returns a graph with a randomly removed node adj matrix [0], its feature matrix [1], the removed node true links [2] \n", |
|
|
|
|
|
" '''\n", |
|
|
|
|
|
" if len(graph.nodes()) >= max_size or len(graph.nodes()) < min_size:\n", |
|
|
|
|
|
" return None\n", |
|
|
|
|
|
" relabeled_graph = nx.relabel.convert_node_labels_to_integers(graph)\n", |
|
|
|
|
|
" choice = np.random.choice(list(relabeled_graph.nodes()))\n", |
|
|
|
|
|
" remaining_graph = relabeled_graph.subgraph(filter(lambda x: x != choice, list(relabeled_graph.nodes())))\n", |
|
|
|
|
|
" remaining_graph_adj = nx.to_numpy_matrix(remaining_graph)\n", |
|
|
|
|
|
" graph_length = len(remaining_graph)\n", |
|
|
|
|
|
" remaining_graph_adj = np.pad(remaining_graph_adj, [(0, max_size - graph_length), (0, max_size - graph_length)])\n", |
|
|
|
|
|
" removed_node = nx.to_numpy_matrix(relabeled_graph)[choice]\n", |
|
|
|
|
|
" removed_node_row = np.asarray(removed_node)[0]\n", |
|
|
|
|
|
" removed_node_row = np.pad(removed_node_row, [(0, max_size - len(removed_node_row))])\n", |
|
|
|
|
|
" return remaining_graph_adj, feature_matrix(remaining_graph), removed_node_row" |
|
|
|
|
|
|
|
|
"testloader = torch.utils.data.DataLoader(dale, collate_fn=lambda x: x[0], batch_size=1)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
"id": "JEvec2nosVn7" |
|
|
"id": "JEvec2nosVn7" |
|
|
}, |
|
|
}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"# Model" |
|
|
|
|
|
|
|
|
"# Building Model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Graph convolutional layer for extracting initial features" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
" return y" |
|
|
" return y" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Attention calculator using given key, query and value" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 15, |
|
|
"execution_count": 15, |
|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"def attention(query, key, value, key_dim):\n", |
|
|
"def attention(query, key, value, key_dim):\n", |
|
|
" # print('key:', key.transpose(-2,-1))\n", |
|
|
|
|
|
" scores = torch.matmul(query, key.transpose(-2,-1)) / math.sqrt(key_dim)\n", |
|
|
|
|
|
" scores = torch.matmul(scores, value)\n", |
|
|
|
|
|
" # print('scores:', scores)\n", |
|
|
|
|
|
" scores = F.softmax(scores)\n", |
|
|
|
|
|
" # scores = torch.sigmoid(scores)\n", |
|
|
|
|
|
" return scores" |
|
|
|
|
|
|
|
|
" # print('key:', key.transpose(-2,-1))\n", |
|
|
|
|
|
" scores = torch.matmul(query, key.transpose(-2,-1)) / math.sqrt(key_dim)\n", |
|
|
|
|
|
" scores = torch.matmul(scores, value)\n", |
|
|
|
|
|
" # print('scores:', scores)\n", |
|
|
|
|
|
" scores = F.softmax(scores)\n", |
|
|
|
|
|
" # scores = torch.sigmoid(scores)\n", |
|
|
|
|
|
" return scores" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Graph attention layer for more features" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"class GraphAttn(nn.Module):\n", |
|
|
"class GraphAttn(nn.Module):\n", |
|
|
" def __init__(self, heads, model_dim, dropout=0.1):\n", |
|
|
|
|
|
" super().__init__()\n", |
|
|
|
|
|
" self.model_dim = model_dim\n", |
|
|
|
|
|
" self.key_dim = model_dim // heads\n", |
|
|
|
|
|
" self.heads = heads\n", |
|
|
|
|
|
|
|
|
" def __init__(self, heads, model_dim, dropout=0.1):\n", |
|
|
|
|
|
" super().__init__()\n", |
|
|
|
|
|
" self.model_dim = model_dim\n", |
|
|
|
|
|
" self.key_dim = model_dim // heads\n", |
|
|
|
|
|
" self.heads = heads\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" self.q_linear = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
|
|
|
" self.v_linear = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
|
|
|
" self.k_linear = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
|
|
|
|
|
|
" self.q_linear = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
|
|
|
" self.v_linear = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
|
|
|
" self.k_linear = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" self.dropout = nn.Dropout(dropout)\n", |
|
|
|
|
|
" self.out = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
|
|
|
|
|
|
" self.dropout = nn.Dropout(dropout)\n", |
|
|
|
|
|
" self.out = nn.Linear(model_dim, model_dim).cuda()\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" def forward(self, query, key, value):\n", |
|
|
|
|
|
" # print(q, k, v)\n", |
|
|
|
|
|
" bs = query.size(0)\n", |
|
|
|
|
|
|
|
|
" def forward(self, query, key, value):\n", |
|
|
|
|
|
" # print(q, k, v)\n", |
|
|
|
|
|
" bs = query.size(0)\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" key = self.k_linear(key.float()).view(bs, -1, self.heads, self.key_dim)\n", |
|
|
|
|
|
" query = self.q_linear(query.float()).view(bs, -1, self.heads, self.key_dim)\n", |
|
|
|
|
|
" value = self.v_linear(value.float()).view(bs, -1, self.heads, self.key_dim)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" key = key.transpose(1,2)\n", |
|
|
|
|
|
" query = query.transpose(1,2)\n", |
|
|
|
|
|
" value = value.transpose(1,2)\n", |
|
|
|
|
|
|
|
|
" key = self.k_linear(key.float()).view(bs, -1, self.heads, self.key_dim)\n", |
|
|
|
|
|
" query = self.q_linear(query.float()).view(bs, -1, self.heads, self.key_dim)\n", |
|
|
|
|
|
" value = self.v_linear(value.float()).view(bs, -1, self.heads, self.key_dim)\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" scores = attention(query, key, value, self.key_dim)\n", |
|
|
|
|
|
" concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)\n", |
|
|
|
|
|
" output = self.out(concat)\n", |
|
|
|
|
|
" output = output.view(bs, self.model_dim)\n", |
|
|
|
|
|
|
|
|
" key = key.transpose(1,2)\n", |
|
|
|
|
|
" query = query.transpose(1,2)\n", |
|
|
|
|
|
" value = value.transpose(1,2)\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" return output" |
|
|
|
|
|
|
|
|
" scores = attention(query, key, value, self.key_dim)\n", |
|
|
|
|
|
" concat = scores.transpose(1,2).contiguous().view(bs, -1, self.model_dim)\n", |
|
|
|
|
|
" output = self.out(concat)\n", |
|
|
|
|
|
" output = output.view(bs, self.model_dim)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" return output" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"MLP layer to map features to links" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
" return output" |
|
|
" return output" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Assembeled model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 18, |
|
|
"execution_count": 18, |
|
|
|
|
|
|
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"class Hydra(nn.Module):\n", |
|
|
"class Hydra(nn.Module):\n", |
|
|
" def __init__(self, gcn_input, model_dim, head):\n", |
|
|
|
|
|
" super().__init__()\n", |
|
|
|
|
|
|
|
|
" def __init__(self, gcn_input, model_dim, head):\n", |
|
|
|
|
|
" super().__init__()\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()\n", |
|
|
|
|
|
" self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()\n", |
|
|
|
|
|
" self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()\n", |
|
|
|
|
|
|
|
|
" self.GCN = GraphConv(input_dim=gcn_input, output_dim=model_dim).cuda()\n", |
|
|
|
|
|
" self.GAT = GraphAttn(heads=head, model_dim=model_dim).cuda()\n", |
|
|
|
|
|
" self.MLP = FeedForward(input_size=model_dim, hidden_size=gcn_input).cuda()\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" def forward(self, x, adj):\n", |
|
|
|
|
|
" gcn_outputs = self.GCN(x, adj)\n", |
|
|
|
|
|
" gat_output = self.GAT(gcn_outputs, gcn_outputs, gcn_outputs)\n", |
|
|
|
|
|
" mlp_output = self.MLP(gat_output).reshape(1,-1)\n", |
|
|
|
|
|
|
|
|
" def forward(self, x, adj):\n", |
|
|
|
|
|
" gcn_outputs = self.GCN(x, adj)\n", |
|
|
|
|
|
" gat_output = self.GAT(gcn_outputs, gcn_outputs, gcn_outputs)\n", |
|
|
|
|
|
" mlp_output = self.MLP(gat_output).reshape(1,-1)\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" return mlp_output" |
|
|
|
|
|
|
|
|
" return mlp_output" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Building model with given inputs" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "hKeg1-8P2kqK" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# adj, features, true_links = prepare_graph_data(graphs[0])" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "8QgGRjyt2sFN" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "tcSaiDch3CUG" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# kyle = build_model(3, 243, 9)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/", |
|
|
|
|
|
"height": 173 |
|
|
|
|
|
}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "tjioYsvs2tUR", |
|
|
|
|
|
"outputId": "06033c55-7e06-4d6c-8b98-a4bbff7bf9f2" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:6: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", |
|
|
|
|
|
" \n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"tensor([[0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425,\n", |
|
|
|
|
|
" 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425,\n", |
|
|
|
|
|
" 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425, 0.5425,\n", |
|
|
|
|
|
" 0.5425, 0.5425, 0.5425, 0.5548, 0.5548, 0.5425, 0.5548, 0.5425, 0.5548,\n", |
|
|
|
|
|
" 0.5548, 0.5548, 0.5548, 0.5548]], device='cuda:0',\n", |
|
|
|
|
|
" grad_fn=<ViewBackward>)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": 20, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"tags": [] |
|
|
|
|
|
}, |
|
|
|
|
|
"output_type": "execute_result" |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"# kyle(features, adj)" |
|
|
|
|
|
|
|
|
"# Evaluating Functions" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
"id": "FxNdr2zR3Zgo" |
|
|
"id": "FxNdr2zR3Zgo" |
|
|
}, |
|
|
}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"# Hala Train" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 23, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "t_k4B_3ko8ps" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def fn(batch):\n", |
|
|
|
|
|
" return batch[0]" |
|
|
|
|
|
|
|
|
"# Training Model" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 30, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "tQdkjvY22_Kf" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"# coop = sum([list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in graphs])) for i in range(10)], [])\n", |
|
|
|
|
|
"coop = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in train]))\n", |
|
|
|
|
|
"dale = list(filter(lambda x: x is not None, [prepare_graph_data(g) for g in test]))" |
|
|
|
|
|
|
|
|
"Training the model with given data and number of epochs" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 47, |
|
|
|
|
|
|
|
|
"execution_count": 57, |
|
|
"metadata": { |
|
|
"metadata": { |
|
|
"colab": {}, |
|
|
"colab": {}, |
|
|
"colab_type": "code", |
|
|
"colab_type": "code", |
|
|
"id": "qChhpZuCpHWv" |
|
|
|
|
|
|
|
|
"id": "PEXlurWtpII2" |
|
|
}, |
|
|
}, |
|
|
"outputs": [], |
|
|
"outputs": [], |
|
|
"source": [ |
|
|
"source": [ |
|
|
"trainloader_train = torch.utils.data.DataLoader(coop, collate_fn=fn, batch_size=1)" |
|
|
|
|
|
|
|
|
"def train_model(model, trainloader, epoch, print_every=100):\n", |
|
|
|
|
|
" optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" model.train()\n", |
|
|
|
|
|
" start = time.time()\n", |
|
|
|
|
|
" temp = start\n", |
|
|
|
|
|
" total_loss = 0\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" for i in range(epoch):\n", |
|
|
|
|
|
" for batch, data in enumerate(trainloader, 0):\n", |
|
|
|
|
|
" adj, features, true_links = data\n", |
|
|
|
|
|
" adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()\n", |
|
|
|
|
|
" # print(adj.shape)\n", |
|
|
|
|
|
" # print(features.shape)\n", |
|
|
|
|
|
" # print(true_links.shape)\n", |
|
|
|
|
|
" preds = model(features, adj)\n", |
|
|
|
|
|
" optim.zero_grad()\n", |
|
|
|
|
|
" loss = F.binary_cross_entropy(preds.double(), true_links.double())\n", |
|
|
|
|
|
" writer.add_scalar('Loss/train', float(loss), i)\n", |
|
|
|
|
|
" loss.backward()\n", |
|
|
|
|
|
" optim.step()\n", |
|
|
|
|
|
" total_loss += loss.item()\n", |
|
|
|
|
|
" if (i + 1) % print_every == 0:\n", |
|
|
|
|
|
" loss_avg = total_loss / print_every\n", |
|
|
|
|
|
" print(\"time = %dm, epoch %d, iter = %d, loss = %.3f,\\\n", |
|
|
|
|
|
" %ds per %d iters\" % ((time.time() - start) // 60,\\\n", |
|
|
|
|
|
" epoch + 1, i + 1, loss_avg, time.time() - temp,\\\n", |
|
|
|
|
|
" print_every))\n", |
|
|
|
|
|
" total_loss = 0\n", |
|
|
|
|
|
" temp = time.time()" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
"kyle = build_model(3, 243, 9)" |
|
|
"kyle = build_model(3, 243, 9)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 57, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "PEXlurWtpII2" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def train_model(model, trainloader, epoch, print_every=100):\n", |
|
|
|
|
|
" optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" model.train()\n", |
|
|
|
|
|
" start = time.time()\n", |
|
|
|
|
|
" temp = start\n", |
|
|
|
|
|
" total_loss = 0\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" for i in range(epoch):\n", |
|
|
|
|
|
" for batch, data in enumerate(trainloader, 0):\n", |
|
|
|
|
|
" adj, features, true_links = data\n", |
|
|
|
|
|
" adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()\n", |
|
|
|
|
|
" # print(adj.shape)\n", |
|
|
|
|
|
" # print(features.shape)\n", |
|
|
|
|
|
" # print(true_links.shape)\n", |
|
|
|
|
|
" preds = model(features, adj)\n", |
|
|
|
|
|
" optim.zero_grad()\n", |
|
|
|
|
|
" loss = F.binary_cross_entropy(preds.double(), true_links.double())\n", |
|
|
|
|
|
" writer.add_scalar('Loss/train', float(loss), i)\n", |
|
|
|
|
|
" loss.backward()\n", |
|
|
|
|
|
" optim.step()\n", |
|
|
|
|
|
" total_loss += loss.item()\n", |
|
|
|
|
|
" if (i + 1) % print_every == 0:\n", |
|
|
|
|
|
" loss_avg = total_loss / print_every\n", |
|
|
|
|
|
" print(\"time = %dm, epoch %d, iter = %d, loss = %.3f,\\\n", |
|
|
|
|
|
" %ds per %d iters\" % ((time.time() - start) // 60,\\\n", |
|
|
|
|
|
" epoch + 1, i + 1, loss_avg, time.time() - temp,\\\n", |
|
|
|
|
|
" print_every))\n", |
|
|
|
|
|
" total_loss = 0\n", |
|
|
|
|
|
" temp = time.time()" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
"cell_type": "code", |
|
|
"execution_count": 58, |
|
|
"execution_count": 58, |
|
|
|
|
|
|
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": 32, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": {}, |
|
|
|
|
|
"colab_type": "code", |
|
|
|
|
|
"id": "8FHvrVFlqNh2" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
"source": [ |
|
|
"source": [ |
|
|
"trainloader_test = torch.utils.data.DataLoader(dale, collate_fn=fn, batch_size=1)" |
|
|
|
|
|
|
|
|
"# Testing Model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Testing model and printing the loss" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
" # total_loss = 0\n", |
|
|
" # total_loss = 0\n", |
|
|
"\n", |
|
|
"\n", |
|
|
" for batch, data in enumerate(trainloader, 0):\n", |
|
|
" for batch, data in enumerate(trainloader, 0):\n", |
|
|
" adj, features, true_links = data\n", |
|
|
|
|
|
" adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()\n", |
|
|
|
|
|
" # print(adj.shape)\n", |
|
|
|
|
|
" # print(features.shape)\n", |
|
|
|
|
|
" # print(true_links.shape)\n", |
|
|
|
|
|
" preds = model(features, adj)\n", |
|
|
|
|
|
" loss = F.binary_cross_entropy(preds.double(), true_links.double())\n", |
|
|
|
|
|
" # total_loss += loss.item()\n", |
|
|
|
|
|
" # loss_avg = total_loss / print_every\n", |
|
|
|
|
|
" if (batch + 1) % print_every == 0:\n", |
|
|
|
|
|
" print(\"loss = \", float(loss))\n", |
|
|
|
|
|
" temp = time.time()" |
|
|
|
|
|
|
|
|
" adj, features, true_links = data\n", |
|
|
|
|
|
" adj, features, true_links = torch.tensor(adj).cuda(), torch.tensor(features).cuda(), torch.tensor(true_links).cuda()\n", |
|
|
|
|
|
" # print(adj.shape)\n", |
|
|
|
|
|
" # print(features.shape)\n", |
|
|
|
|
|
" # print(true_links.shape)\n", |
|
|
|
|
|
" preds = model(features, adj)\n", |
|
|
|
|
|
" loss = F.binary_cross_entropy(preds.double(), true_links.double())\n", |
|
|
|
|
|
" # total_loss += loss.item()\n", |
|
|
|
|
|
" # loss_avg = total_loss / print_every\n", |
|
|
|
|
|
" if (batch + 1) % print_every == 0:\n", |
|
|
|
|
|
" print(\"loss = \", float(loss))\n", |
|
|
|
|
|
" temp = time.time()" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
{ |
|
|
{ |
|
|
|
|
|
|
|
|
"test_model(kyle, trainloader_test)" |
|
|
"test_model(kyle, trainloader_test)" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# Evaluating Model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
{ |
|
|
{ |
|
|
"cell_type": "markdown", |
|
|
"cell_type": "markdown", |
|
|
"metadata": { |
|
|
"metadata": { |