12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Nx8IcGZHAKn_"
- },
- "source": [
- "# Imports"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Dhv1DYPL3Vm1"
- },
- "outputs": [],
- "source": [
- "from torch_geometric.nn import HGTConv, Linear\n",
- "from torch_geometric.loader import HGTLoader\n",
- "from torch_geometric.data import HeteroData\n",
- "import torch.nn.functional as F\n",
- "import pickle5 as pickle\n",
- "import torch.nn as nn\n",
- "import pandas as pd\n",
- "from utils import *\n",
- "import random\n",
- "import torch\n",
- "import copy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Endpoint node types and relation name of the edge type that is predicted.\n",
- "node_type1, node_type2, rel = 'drug', 'disease', 'indication'\n",
- "\n",
- "# Use the first GPU when CUDA is available; otherwise fall back to the CPU.\n",
- "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "oOEkKAITR8tc"
- },
- "outputs": [],
- "source": [
- "# Hyper-parameters read by define_loaders(), define_model() and run().\n",
- "config = dict(\n",
- "    num_samples=512,  # nodes sampled per node type and iteration by HGTLoader\n",
- "    batch_size=164,   # seed nodes per mini-batch\n",
- "    dropout=0.5,\n",
- "    epochs=300,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-cRmQ9cEAO_K"
- },
- "source": [
- "# Load data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "uf_2DpGQCCFJ"
- },
- "outputs": [],
- "source": [
- "# Knowledge-graph edge list; the cells below read the columns\n",
- "# relation, x_type, x_index, y_type and y_index.\n",
- "primekg_file = '../data/kg.csv'\n",
- "df = pd.read_csv(primekg_file, delimiter=',')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Get the drugs and diseases that take part in an indication relation."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Rows of the knowledge graph that belong to the target relation.\n",
- "drug_disease_pairs = df[df['relation'] == rel]\n",
- "\n",
- "# A node can sit on either side (x or y) of an edge, so the ids are gathered\n",
- "# from both columns. Boolean-mask selection replaces a row-by-row iterrows()\n",
- "# loop; unique() removes repeated ids.\n",
- "drugs = pd.concat([\n",
- "    drug_disease_pairs.loc[drug_disease_pairs['x_type'] == node_type1, 'x_index'],\n",
- "    drug_disease_pairs.loc[drug_disease_pairs['y_type'] == node_type1, 'y_index'],\n",
- "]).unique().tolist()\n",
- "diseases = pd.concat([\n",
- "    drug_disease_pairs.loc[drug_disease_pairs['x_type'] == node_type2, 'x_index'],\n",
- "    drug_disease_pairs.loc[drug_disease_pairs['y_type'] == node_type2, 'y_index'],\n",
- "]).unique().tolist()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Remove drug and disease nodes that do not contribute to at least one indication edge. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Drop edges whose x endpoint is a node_type1 node that is not listed in `drugs`.\n",
- "to_remove = df[(df['x_type'] == node_type1) & ~df['x_index'].isin(drugs)]\n",
- "df.drop(index=to_remove.index, inplace=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Drop edges whose y endpoint is a node_type1 node that is not listed in `drugs`.\n",
- "to_remove = df[(df['y_type'] == node_type1) & ~df['y_index'].isin(drugs)]\n",
- "df.drop(index=to_remove.index, inplace=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Drop edges whose x endpoint is a node_type2 node that is not listed in `diseases`.\n",
- "to_remove = df[(df['x_type'] == node_type2) & ~df['x_index'].isin(diseases)]\n",
- "df.drop(index=to_remove.index, inplace=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Drop edges whose y endpoint is a node_type2 node that is not listed in `diseases`.\n",
- "to_remove = df[(df['y_type'] == node_type2) & ~df['y_index'].isin(diseases)]\n",
- "df.drop(index=to_remove.index, inplace=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Make HeteroData object for the graph."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Three-column triple table: endpoints are encoded as '<type>::<index>',\n",
- "# the middle column is the relation name.\n",
- "new_df = pd.DataFrame({\n",
- "    0: df['x_type'] + '::' + df['x_index'].astype(str),\n",
- "    1: df['relation'],\n",
- "    2: df['y_type'] + '::' + df['y_index'].astype(str),\n",
- "})"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Remove repeated triples, then materialise them as [head, relation, tail] lists.\n",
- "df = new_df.drop_duplicates()\n",
- "triplets = df.to_numpy().tolist()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jUcSDffvCtKY"
- },
- "outputs": [],
- "source": [
- "# entity type -> {entry: integer id}; filled by insert_entry().\n",
- "entity_dictionary = {}\n",
- "\n",
- "\n",
- "def insert_entry(entry, ent_type, dic):\n",
- "    \"\"\"Give `entry` the next free integer id within its entity type.\n",
- "\n",
- "    `dic` maps an entity type to a {entry: id} table. A missing table is\n",
- "    created on first use, and an entry that is already present keeps its id.\n",
- "\n",
- "    Args:\n",
- "        entry: key of the entity to register.\n",
- "        ent_type: entity type under which the entry is stored.\n",
- "        dic: nested dictionary to update in place.\n",
- "\n",
- "    Returns:\n",
- "        The same `dic` object that was passed in.\n",
- "    \"\"\"\n",
- "    ids_of_type = dic.setdefault(ent_type, {})\n",
- "    # len() is evaluated before the insertion, so new entries get 0, 1, 2, ...\n",
- "    ids_of_type.setdefault(entry, len(ids_of_type))\n",
- "    return dic\n",
- "\n",
- "# Register both endpoints of every triple; the entity type is the text\n",
- "# in front of the '::' separator.\n",
- "for head, _, tail in triplets:\n",
- "    insert_entry(head, head.split('::')[0], entity_dictionary)\n",
- "    insert_entry(tail, tail.split('::')[0], entity_dictionary)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "vTybNyrqFLrl"
- },
- "outputs": [],
- "source": [
- "# (head type, relation, tail type) -> list of (head id, tail id) pairs.\n",
- "edge_dictionary = {}\n",
- "for head, relation, tail in triplets:\n",
- "    head_type = head.split('::')[0]\n",
- "    tail_type = tail.split('::')[0]\n",
- "\n",
- "    # Edges are stored with the per-type integer ids from entity_dictionary.\n",
- "    pair = (entity_dictionary[head_type][head], entity_dictionary[tail_type][tail])\n",
- "    etype = (head_type, relation, tail_type)\n",
- "    edge_dictionary.setdefault(etype, []).append(pair)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "data = HeteroData()\n",
- "\n",
- "# Placeholder features: a constant per node type, random values for drugs.\n",
- "# NOTE(review): drug features are 767 wide while every other type is 768 wide;\n",
- "# confirm this matches the width of the stored drug embeddings.\n",
- "for type_idx, (node_type, entries) in enumerate(entity_dictionary.items()):\n",
- "    num_nodes = len(entries)\n",
- "    if node_type == 'drug':\n",
- "        data[node_type].x = torch.rand((num_nodes, 767))\n",
- "    else:\n",
- "        data[node_type].x = torch.ones((num_nodes, 768)) * type_idx\n",
- "\n",
- "    # Position of every node inside its own type.\n",
- "    data[node_type].id = torch.arange(num_nodes)\n",
- "\n",
- "# Store each edge list as a contiguous 2 x num_edges int64 tensor.\n",
- "for etype, pairs in edge_dictionary.items():\n",
- "    data[etype].edge_index = torch.IntTensor(pairs).t().long().contiguous()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Add initial embeddings."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "embeddings = pd.read_pickle('../data/entities_embeddings.pkl')\n",
- "smiles_embeddings = pd.read_pickle('../data/smiles_embeddings.pkl')\n",
- "\n",
- "# Drugs: replace the random feature row wherever a stored embedding exists.\n",
- "for entry_id, vector in zip(smiles_embeddings['id'], smiles_embeddings['embedding']):\n",
- "    if entry_id in entity_dictionary['drug']:\n",
- "        drug_id = entity_dictionary['drug'][entry_id]\n",
- "        data['drug'].x[drug_id] = torch.Tensor(vector)\n",
- "\n",
- "# Every other node type: write the stored embedding into the first 768 columns.\n",
- "for entry_id, vector in zip(embeddings['id'], embeddings['embedding']):\n",
- "    x_type = entry_id.split('::')[0]\n",
- "    if x_type == 'drug' or x_type not in data.node_types:\n",
- "        continue\n",
- "    if entry_id in entity_dictionary[x_type]:\n",
- "        id_ = entity_dictionary[x_type][entry_id]\n",
- "        data[x_type].x[id_][:768] = torch.Tensor(vector)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Load train and validation data of one fold."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# The with-block closes the file handle as soon as the pickle has been read.\n",
- "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
- "with open('../data/CV data/train1.pkl', 'rb') as file:\n",
- "    train_data = pickle.load(file)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# The with-block closes the file handle as soon as the pickle has been read.\n",
- "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
- "with open('../data/CV data/val1.pkl', 'rb') as file:\n",
- "    val_data = pickle.load(file)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Create the edge mask."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "drug_disease_num = train_data[(node_type1, rel, node_type2)]['edge_index'].shape[1]\n",
- "\n",
- "# Mark a random 80% of the edges. make_batch() uses the marked edges as\n",
- "# positive prediction targets and keeps the unmarked ones for message passing.\n",
- "mask = random.sample(range(drug_disease_num), int(drug_disease_num * 0.8))\n",
- "edge_mask = torch.zeros(drug_disease_num, dtype=torch.bool)\n",
- "edge_mask[mask] = True\n",
- "\n",
- "train_data[(node_type1, rel, node_type2)]['mask'] = edge_mask\n",
- "# The reverse relation gets its own copy of the same mask values.\n",
- "train_data[(node_type2, rel, node_type1)]['mask'] = edge_mask.clone()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Define model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ql-F7A42fMWm"
- },
- "outputs": [],
- "source": [
- "class HGT(nn.Module):\n",
- "    \"\"\"HGTConv encoder that returns embeddings for node_type1 and node_type2.\n",
- "\n",
- "    Node features are projected per node type, passed through a stack of\n",
- "    HGTConv layers, and the outputs of all layers are concatenated before a\n",
- "    final shared linear projection. Node types and graph metadata are read\n",
- "    from the module-level `train_data`.\n",
- "\n",
- "    Args:\n",
- "        hidden_channels: layer widths; entry 0 is the width of the input\n",
- "            projection and entry i + 1 the output width of layer i, so\n",
- "            num_layers + 1 entries are read.\n",
- "        out_channels: width of the returned embeddings.\n",
- "        num_heads: number of attention heads of each layer.\n",
- "        num_layers: number of HGTConv layers.\n",
- "        dropout: dropout probability applied after the input projection.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, dropout):\n",
- "        super().__init__()\n",
- "\n",
- "        # Lazy (-1) input size: every node type may have its own feature width.\n",
- "        self.lin_dict = nn.ModuleDict({\n",
- "            node_type: Linear(-1, hidden_channels[0])\n",
- "            for node_type in train_data.node_types\n",
- "        })\n",
- "\n",
- "        self.convs = nn.ModuleList([\n",
- "            HGTConv(hidden_channels[i], hidden_channels[i + 1], train_data.metadata(),\n",
- "                    num_heads[i], group='mean')\n",
- "            for i in range(num_layers)\n",
- "        ])\n",
- "\n",
- "        # Input is the concatenation of the outputs of all conv layers.\n",
- "        self.lin = Linear(sum(hidden_channels[1:]), out_channels)\n",
- "\n",
- "        self.dropout = nn.Dropout(dropout)\n",
- "\n",
- "    def forward(self, x_dict, edge_index_dict):\n",
- "        \"\"\"Return (node_type1 embeddings, node_type2 embeddings).\"\"\"\n",
- "        projected = {}\n",
- "        for node_type, x in x_dict.items():\n",
- "            projected[node_type] = self.dropout(self.lin_dict[node_type](x).relu_())\n",
- "        x_dict = projected\n",
- "\n",
- "        out = {}\n",
- "        for conv in self.convs:\n",
- "            x_dict = conv(x_dict, edge_index_dict)\n",
- "\n",
- "            if not out:\n",
- "                # First layer: start the running concatenation.\n",
- "                out = dict(x_dict)\n",
- "            else:\n",
- "                out = {\n",
- "                    node_type: torch.cat((out[node_type], features), dim=1)\n",
- "                    for node_type, features in x_dict.items()\n",
- "                }\n",
- "\n",
- "        drug_out = F.relu(self.lin(out[node_type1]))\n",
- "        disease_out = F.relu(self.lin(out[node_type2]))\n",
- "        return drug_out, disease_out"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "atdEjxJqvLaN"
- },
- "outputs": [],
- "source": [
- "class MLPPredictor(nn.Module):\n",
- "    \"\"\"Two-layer MLP that scores (drug, disease) pairs from their embeddings.\n",
- "\n",
- "    Args:\n",
- "        channel_num: width of each input embedding.\n",
- "        dropout: dropout probability applied to the hidden layer.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, channel_num, dropout):\n",
- "        super().__init__()\n",
- "        self.L1 = nn.Linear(channel_num * 2, channel_num)\n",
- "        self.L2 = nn.Linear(channel_num, 1)\n",
- "        self.bn = nn.BatchNorm1d(num_features=channel_num)\n",
- "        # Honour the caller-supplied rate instead of a fixed 0.2, so the\n",
- "        # `dropout` argument actually takes effect.\n",
- "        self.dropout = nn.Dropout(dropout)\n",
- "\n",
- "    def forward(self, drug_embeddings, disease_embeddings):\n",
- "        \"\"\"Return one raw score (logit) per pair, shaped (num_pairs, 1).\"\"\"\n",
- "        x = torch.cat((drug_embeddings, disease_embeddings), dim=1)\n",
- "        x = F.relu(self.bn(self.L1(x)))\n",
- "        x = self.dropout(x)\n",
- "        return self.L2(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "6fHs5rX76ldq"
- },
- "outputs": [],
- "source": [
- "def compute_loss(scores, labels):\n",
- "    \"\"\"Class-balanced binary cross-entropy on raw logits.\n",
- "\n",
- "    Every example is weighted by the share of the opposite class in `labels`,\n",
- "    so the positive and the negative class carry the same total weight.\n",
- "\n",
- "    Args:\n",
- "        scores: logits, same shape as `labels`.\n",
- "        labels: float tensor holding 1 for positive and 0 for negative pairs.\n",
- "\n",
- "    Returns:\n",
- "        Scalar tensor with the mean weighted loss.\n",
- "    \"\"\"\n",
- "    total = labels.shape[0]\n",
- "    pos_fraction = ((labels == 1).sum() / total).to(labels.dtype)\n",
- "    neg_fraction = ((labels == 0).sum() / total).to(labels.dtype)\n",
- "\n",
- "    # Build the weights in one step: two chained in-place masked writes can\n",
- "    # collide when the value written first equals the second mask value.\n",
- "    weights = torch.where(labels == 1, neg_fraction, pos_fraction)\n",
- "\n",
- "    # `weight` rescales the loss of every element. `pos_weight` only scales the\n",
- "    # positive term, which would leave the negative examples unweighted.\n",
- "    return F.binary_cross_entropy_with_logits(scores, labels, weight=weights)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "nmLsh9VigPpI"
- },
- "outputs": [],
- "source": [
- "def define_model(dropout):\n",
- "    \"\"\"Build the HGT encoder, the MLP pair scorer and a container for both.\n",
- "\n",
- "    Args:\n",
- "        dropout: dropout probability passed to both sub-models.\n",
- "\n",
- "    Returns:\n",
- "        (GNN, pred, model), where `model` is an nn.Sequential wrapping GNN and\n",
- "        pred that has been moved to `device`.\n",
- "    \"\"\"\n",
- "    width = 64\n",
- "    layers = 3\n",
- "\n",
- "    GNN = HGT(hidden_channels=[width] * (layers + 1),\n",
- "              out_channels=width,\n",
- "              num_heads=[8] * layers,\n",
- "              num_layers=layers,\n",
- "              dropout=dropout)\n",
- "    pred = MLPPredictor(width, dropout)\n",
- "\n",
- "    # The container gives one handle for .to(), .train()/.eval() and parameters().\n",
- "    model = nn.Sequential(GNN, pred).to(device)\n",
- "    return GNN, pred, model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "kk5vWUiQV7oi"
- },
- "outputs": [],
- "source": [
- "def define_loaders(config):\n",
- "    \"\"\"Create the HGTLoader pair for `train_data` and `val_data`.\n",
- "\n",
- "    Args:\n",
- "        config: dictionary providing 'batch_size' and 'num_samples'.\n",
- "\n",
- "    Returns:\n",
- "        (train_loader, val_loader); both are seeded with node_type1 nodes.\n",
- "    \"\"\"\n",
- "    def build(graph):\n",
- "        # Both loaders share every setting; only the graph differs.\n",
- "        return HGTLoader(graph,\n",
- "                         num_samples=[config['num_samples']] * 3,\n",
- "                         shuffle=True,\n",
- "                         input_nodes=(node_type1, None),\n",
- "                         batch_size=config['batch_size'],\n",
- "                         num_workers=8,\n",
- "                         persistent_workers=True)\n",
- "\n",
- "    train_loader = build(train_data)\n",
- "    val_loader = build(val_data)\n",
- "    return train_loader, val_loader"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def edge_exists(edges, edge):\n",
- "    \"\"\"Tell whether some column of `edges` equals `edge`.\n",
- "\n",
- "    Both tensors are moved to `device` before they are compared. Callers in\n",
- "    this notebook pass a 2 x 1 tensor as `edge`.\n",
- "\n",
- "    Returns:\n",
- "        A 0-dim boolean tensor.\n",
- "    \"\"\"\n",
- "    column_matches = (edges.to(device) == edge.to(device)).all(dim=0)\n",
- "    return column_matches.any()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Make batches."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def make_batch(batch):\n",
- "    \"\"\"Turn a sampled training batch into a link-prediction example.\n",
- "\n",
- "    Only (node_type1, rel, node_type2) edges whose source is one of the first\n",
- "    `batch_size` node_type1 nodes are kept. Edges flagged by the store's 'mask'\n",
- "    become positive targets, an equal number of random non-edges become\n",
- "    negative targets, and the unflagged edges stay in the graph for message\n",
- "    passing in both directions.\n",
- "\n",
- "    Args:\n",
- "        batch: mini-batch produced by the training loader.\n",
- "\n",
- "    Returns:\n",
- "        The same batch object. The forward store receives 'edge_index',\n",
- "        'edge_label_index' (positives followed by negatives) and 'edge_label'\n",
- "        (ones followed by zeros); the reverse store receives the flipped\n",
- "        'edge_index'.\n",
- "    \"\"\"\n",
- "    forward_key = (node_type1, rel, node_type2)\n",
- "    reverse_key = (node_type2, rel, node_type1)\n",
- "\n",
- "    batch_size = batch[node_type1].batch_size\n",
- "    edge_index = batch[forward_key]['edge_index']\n",
- "    mask = batch[forward_key]['mask']\n",
- "\n",
- "    # Keep the edges that start at one of the first `batch_size` source nodes.\n",
- "    from_seed = edge_index[0] < batch_size\n",
- "    edge_index = edge_index[:, from_seed]\n",
- "    mask = mask[from_seed]\n",
- "\n",
- "    pos_edges = edge_index[:, mask]\n",
- "    pos_num = pos_edges.shape[1]\n",
- "\n",
- "    # A set of (source, dest) tuples gives O(1) rejection tests instead of one\n",
- "    # tensor comparison over all edges per candidate.\n",
- "    known_edges = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))\n",
- "    num_dest = batch[node_type2].x.shape[0]\n",
- "\n",
- "    neg_edges_source = []\n",
- "    neg_edges_dest = []\n",
- "    while len(neg_edges_source) < pos_num:\n",
- "        source = random.randint(0, batch_size - 1)\n",
- "        dest = random.randint(0, num_dest - 1)\n",
- "        if (source, dest) in known_edges:\n",
- "            continue\n",
- "        neg_edges_source.append(source)\n",
- "        neg_edges_dest.append(dest)\n",
- "\n",
- "    # An explicit dtype keeps the index tensor integral when no negatives exist.\n",
- "    neg_edges = torch.tensor([neg_edges_source, neg_edges_dest], dtype=torch.long)\n",
- "    edge_label_index = torch.cat((pos_edges, neg_edges), dim=1)\n",
- "    edge_label = torch.cat((torch.ones(pos_num), torch.zeros(neg_edges.shape[1])), dim=0)\n",
- "\n",
- "    # The target edges are removed from the graph the encoder sees.\n",
- "    message_edges = edge_index[:, ~mask]\n",
- "\n",
- "    batch[forward_key]['edge_index'] = message_edges\n",
- "    batch[forward_key]['edge_label_index'] = edge_label_index\n",
- "    batch[forward_key]['edge_label'] = edge_label\n",
- "\n",
- "    # flip(0) swaps the source and destination rows into a new tensor. Storing\n",
- "    # the same tensor under both keys and swapping its rows in place would also\n",
- "    # rewrite the forward edges.\n",
- "    batch[reverse_key]['edge_index'] = message_edges.flip(0)\n",
- "\n",
- "    return batch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def make_test_batch(batch):\n",
- "    \"\"\"Attach evaluation pairs (positives plus sampled negatives) to a batch.\n",
- "\n",
- "    The ids in the batch's 'edge_label_index' are matched against the 'id'\n",
- "    attribute of the node stores. A pair labelled 1 is kept when its source id\n",
- "    belongs to one of the first `batch_size` node_type1 nodes and its\n",
- "    destination id occurs among the node_type2 nodes of the batch; it is then\n",
- "    expressed with positions inside the batch. The same number of negatives\n",
- "    is drawn at random; a draw is rejected when the pair is an edge of the\n",
- "    full graph `data`.\n",
- "\n",
- "    Assumes the ids inside one batch are unique.\n",
- "\n",
- "    Args:\n",
- "        batch: mini-batch produced by the validation loader.\n",
- "\n",
- "    Returns:\n",
- "        The same batch object with 'edge_label_index' (positives followed by\n",
- "        negatives, batch positions) and 'edge_label' (ones followed by zeros)\n",
- "        replaced on the (node_type1, rel, node_type2) store.\n",
- "    \"\"\"\n",
- "    target_key = (node_type1, rel, node_type2)\n",
- "\n",
- "    batch_size = batch[node_type1].batch_size\n",
- "    source_ids = batch[node_type1]['id']\n",
- "    dest_ids = batch[node_type2]['id']\n",
- "\n",
- "    # id -> position tables are built once, replacing a scan of the id tensors\n",
- "    # for every candidate pair.\n",
- "    source_position = {node_id: pos for pos, node_id in enumerate(source_ids[:batch_size].tolist())}\n",
- "    dest_position = {node_id: pos for pos, node_id in enumerate(dest_ids.tolist())}\n",
- "\n",
- "    candidates = batch[target_key]['edge_label_index']\n",
- "    candidate_labels = batch[target_key]['edge_label']\n",
- "\n",
- "    pos_source = []\n",
- "    pos_dest = []\n",
- "    for source_id, dest_id, label in zip(candidates[0].tolist(),\n",
- "                                         candidates[1].tolist(),\n",
- "                                         candidate_labels.tolist()):\n",
- "        if label == 1 and source_id in source_position and dest_id in dest_position:\n",
- "            pos_source.append(source_position[source_id])\n",
- "            pos_dest.append(dest_position[dest_id])\n",
- "    pos_num = len(pos_source)\n",
- "\n",
- "    full_graph_edges = data[target_key]['edge_index']\n",
- "    num_dest = batch[node_type2].x.shape[0]\n",
- "\n",
- "    neg_edges_source = []\n",
- "    neg_edges_dest = []\n",
- "    while len(neg_edges_source) < pos_num:\n",
- "        source_node = random.randint(0, batch_size - 1)\n",
- "        dest_node = random.randint(0, num_dest - 1)\n",
- "        # The rejection test runs on the ids stored in 'id', because the full\n",
- "        # graph is not indexed by batch positions.\n",
- "        neg_edge_in_orig_graph = torch.Tensor([[source_ids[source_node]], [dest_ids[dest_node]]])\n",
- "        if edge_exists(full_graph_edges, neg_edge_in_orig_graph):\n",
- "            continue\n",
- "        neg_edges_source.append(source_node)\n",
- "        neg_edges_dest.append(dest_node)\n",
- "\n",
- "    # An explicit dtype keeps the index tensor integral when it is empty.\n",
- "    edge_label_index = torch.tensor([pos_source + neg_edges_source,\n",
- "                                     pos_dest + neg_edges_dest], dtype=torch.long)\n",
- "    edge_label = torch.cat((torch.ones(pos_num), torch.zeros(len(neg_edges_source))), dim=0)\n",
- "\n",
- "    batch[target_key]['edge_label_index'] = edge_label_index\n",
- "    batch[target_key]['edge_label'] = edge_label\n",
- "\n",
- "    return batch"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Train"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "51k5xvYymLvw"
- },
- "outputs": [],
- "source": [
- "def train(GNN, pred, model, loader, optimizer):\n",
- "    \"\"\"Run one training epoch.\n",
- "\n",
- "    Args:\n",
- "        GNN: encoder returning (node_type1, node_type2) embeddings.\n",
- "        pred: pair scorer applied to the gathered embeddings.\n",
- "        model: container of both sub-models, used to switch to train mode.\n",
- "        loader: iterable of mini-batches.\n",
- "        optimizer: optimizer updating the model parameters.\n",
- "\n",
- "    Returns:\n",
- "        Mean loss per scored pair over the epoch.\n",
- "    \"\"\"\n",
- "    target_key = (node_type1, rel, node_type2)\n",
- "\n",
- "    model.train()\n",
- "    loss_sum = 0\n",
- "    pair_count = 0\n",
- "    for batch in loader:\n",
- "        optimizer.zero_grad()\n",
- "        batch = make_batch(batch).to(device)\n",
- "        pairs = batch[target_key]['edge_label_index']\n",
- "        targets = batch[target_key]['edge_label']\n",
- "        if targets.shape[0] == 0:\n",
- "            # Nothing to score in this batch.\n",
- "            continue\n",
- "\n",
- "        drug_embeddings, disease_embeddings = GNN(batch.x_dict, batch.edge_index_dict)\n",
- "\n",
- "        # Gather the embeddings of both endpoints of every labelled pair.\n",
- "        logits = pred(drug_embeddings[pairs[0]], disease_embeddings[pairs[1]])[:, 0]\n",
- "        loss = compute_loss(logits, targets)\n",
- "        loss.backward()\n",
- "        optimizer.step()\n",
- "\n",
- "        num_pairs = pairs.shape[1]\n",
- "        pair_count += num_pairs\n",
- "        loss_sum += float(loss) * num_pairs\n",
- "\n",
- "    return loss_sum / pair_count"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Test"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Vyvi80_Wo4GE"
- },
- "outputs": [],
- "source": [
- "@torch.no_grad()\n",
- "def test(GNN, pred, model, loader):\n",
- "    \"\"\"Score the evaluation pairs of every batch of `loader`.\n",
- "\n",
- "    Returns:\n",
- "        (out, labels, source, dest, loss): the concatenated logits, their\n",
- "        labels, the 'id' values of both endpoints of every scored pair, and\n",
- "        the loss over all pairs as a NumPy value.\n",
- "    \"\"\"\n",
- "    target_key = (node_type1, rel, node_type2)\n",
- "\n",
- "    model.eval()\n",
- "\n",
- "    # Running concatenations, started from empty tensors on the target device.\n",
- "    out = torch.tensor([]).to(device)\n",
- "    labels = torch.tensor([]).to(device)\n",
- "    source = torch.tensor([]).to(device)\n",
- "    dest = torch.tensor([]).to(device)\n",
- "\n",
- "    for batch in loader:\n",
- "        batch = make_test_batch(batch).to(device)\n",
- "        drug_embeddings, disease_embeddings = GNN(batch.x_dict, batch.edge_index_dict)\n",
- "\n",
- "        pairs = batch[target_key]['edge_label_index']\n",
- "        targets = batch[target_key]['edge_label']\n",
- "        if targets.shape[0] == 0:\n",
- "            continue\n",
- "\n",
- "        batch_out = pred(drug_embeddings[pairs[0]], disease_embeddings[pairs[1]])[:, 0]\n",
- "        labels = torch.cat((labels, targets))\n",
- "        out = torch.cat((out, batch_out))\n",
- "\n",
- "        # Translate batch positions back to the ids stored on the nodes.\n",
- "        source = torch.cat((source, batch[node_type1]['id'][pairs[0]]))\n",
- "        dest = torch.cat((dest, batch[node_type2]['id'][pairs[1]]))\n",
- "\n",
- "    loss = compute_loss(out, labels)\n",
- "    return out, labels, source, dest, loss.cpu().numpy()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Run"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def run(config):\n",
- "    \"\"\"Train the model, save its weights and report validation metrics.\n",
- "\n",
- "    After every epoch the train and validation loss are logged with\n",
- "    write_to_out() and plotted with plot_losses(). After the last epoch the\n",
- "    state dict is written to '../out/saved_model.h5' and AUPR / AUROC are\n",
- "    computed on the validation loader.\n",
- "\n",
- "    Args:\n",
- "        config: dictionary providing 'epochs', 'dropout', 'batch_size' and\n",
- "            'num_samples'.\n",
- "    \"\"\"\n",
- "    losses, val_losses = [], []\n",
- "\n",
- "    train_loader, val_loader = define_loaders(config)\n",
- "    GNN, pred, model = define_model(config['dropout'])\n",
- "\n",
- "    optimizer = torch.optim.AdamW(model.parameters())\n",
- "    # eta_min and last_epoch keep their defaults (0 and -1). The `verbose`\n",
- "    # argument is deprecated in recent PyTorch releases, so it is not passed.\n",
- "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])\n",
- "\n",
- "    for epoch in range(config['epochs']):\n",
- "        loss = train(GNN, pred, model, train_loader, optimizer)\n",
- "        *_, val_loss = test(GNN, pred, model, val_loader)\n",
- "        write_to_out(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, ValLoss: {val_loss:.4f} \\n')\n",
- "        losses.append(loss)\n",
- "        val_losses.append(val_loss)\n",
- "        plot_losses(losses, val_losses)\n",
- "\n",
- "        scheduler.step()\n",
- "\n",
- "    torch.save(model.state_dict(), '../out/saved_model.h5')\n",
- "\n",
- "    out, labels, *_ = test(GNN, pred, model, val_loader)\n",
- "    AUPR(out, labels)\n",
- "    AUROC(out, labels)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Train and evaluate with the hyper-parameters defined in `config`.\n",
- "run(config=config)"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "collapsed_sections": [],
- "name": "geo-pyHGT.ipynb",
- "provenance": []
- },
- "gpuClass": "standard",
- "kernelspec": {
- "display_name": "basee",
- "language": "python",
- "name": "basee"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 1
- }
|