
Commit bfeee4f944 ("final") by Arash Zolghadr, 8 months ago, on branch master.

3 changed files with 835 additions and 0 deletions:

1. README.md (+18, -0)
2. src/HGTDR.ipynb (+769, -0)
3. src/utils.py (+48, -0)

README.md (+18, -0)

@@ -0,0 +1,18 @@
# Source code for "HGTDR: Heterogeneous Knowledge Graph-Based Drug Repurposing with Heterogeneous Graph Transformer"
Requirements:

- torch==1.13.1
- torch-geometric==2.3.1
- torch-sparse==0.6.17
- torch-scatter==2.1.1
- pandas==1.4.2
- matplotlib==3.5.1
- scikit-learn==1.0.2
- pickle5
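Assuming a matching Python environment, the pinned versions can be installed with, for example:

```
pip install torch==1.13.1 pandas==1.4.2 matplotlib==3.5.1 scikit-learn==1.0.2 pickle5
pip install torch-geometric==2.3.1 torch-scatter==2.1.1 torch-sparse==0.6.17
```

Note that torch-scatter and torch-sparse are usually installed from the PyTorch Geometric wheel index matching your torch/CUDA build.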

src/HGTDR.ipynb (+769, -0)

@@ -0,0 +1,769 @@
# Imports

from torch_geometric.nn import HGTConv, Linear
from torch_geometric.loader import HGTLoader
from torch_geometric.data import HeteroData
import torch.nn.functional as F
import pickle5 as pickle
import torch.nn as nn
import pandas as pd
from utils import *
import random
import torch
import copy

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
node_type1 = 'drug'
node_type2 = 'disease'
rel = 'indication'

config = {
    "num_samples": 512,
    "batch_size": 164,
    "dropout": 0.5,
    "epochs": 300
}
# Load data

primekg_file = '../data/kg.csv'
df = pd.read_csv(primekg_file, sep=",")
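Only a handful of PrimeKG columns are used below; a quick peek at them (a sketch, assuming kg.csv is in place):

```python
print(df[['relation', 'x_type', 'x_index', 'y_type', 'y_index']].head())
```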
### Get the drugs and diseases that participate in an indication relation.

drug_disease_pairs = df[df['relation'] == rel]
drugs, diseases = [], []

for i, row in drug_disease_pairs.iterrows():
    if row['x_type'] == node_type1:
        drugs.append(row['x_index'])
    if row['x_type'] == node_type2:
        diseases.append(row['x_index'])

    if row['y_type'] == node_type1:
        drugs.append(row['y_index'])
    if row['y_type'] == node_type2:
        diseases.append(row['y_index'])

drugs, diseases = list(set(drugs)), list(set(diseases))
### Remove drug and disease nodes that do not contribute to at least one indication edge.

to_remove = df[df['x_type'] == node_type1]
to_remove = to_remove[~to_remove['x_index'].isin(drugs)]
df.drop(to_remove.index, inplace=True)

to_remove = df[df['y_type'] == node_type1]
to_remove = to_remove[~to_remove['y_index'].isin(drugs)]
df.drop(to_remove.index, inplace=True)

to_remove = df[df['x_type'] == node_type2]
to_remove = to_remove[~to_remove['x_index'].isin(diseases)]
df.drop(to_remove.index, inplace=True)

to_remove = df[df['y_type'] == node_type2]
to_remove = to_remove[~to_remove['y_index'].isin(diseases)]
df.drop(to_remove.index, inplace=True)
### Make a HeteroData object for the graph.

new_df = pd.DataFrame()
new_df[0] = df['x_type'] + '::' + df['x_index'].astype(str)
new_df[1] = df['relation']
new_df[2] = df['y_type'] + '::' + df['y_index'].astype(str)

df = new_df.drop_duplicates()
triplets = df.values.tolist()
entity_dictionary = {}

def insert_entry(entry, ent_type, dic):
    # Assign each entity a contiguous integer id within its node type.
    if ent_type not in dic:
        dic[ent_type] = {}
    ent_n_id = len(dic[ent_type])
    if entry not in dic[ent_type]:
        dic[ent_type][entry] = ent_n_id
    return dic

for triple in triplets:
    src, dest = triple[0], triple[2]
    src_type = src.split('::')[0]
    dest_type = dest.split('::')[0]
    insert_entry(src, src_type, entity_dictionary)
    insert_entry(dest, dest_type, entity_dictionary)
edge_dictionary = {}
for triple in triplets:
    src, dest = triple[0], triple[2]
    src_type = src.split('::')[0]
    dest_type = dest.split('::')[0]

    src_int_id = entity_dictionary[src_type][src]
    dest_int_id = entity_dictionary[dest_type][dest]

    pair = (src_int_id, dest_int_id)
    etype = (src_type, triple[1], dest_type)
    if etype in edge_dictionary:
        edge_dictionary[etype].append(pair)
    else:
        edge_dictionary[etype] = [pair]
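To make the two dictionaries concrete, here is a toy walk-through with an invented triplet (the ids are hypothetical; it reuses insert_entry from above):

```python
toy_entities = {}
toy_triple = ['drug::14', 'indication', 'disease::3']  # invented example

insert_entry(toy_triple[0], 'drug', toy_entities)
insert_entry(toy_triple[2], 'disease', toy_entities)
# toy_entities == {'drug': {'drug::14': 0}, 'disease': {'disease::3': 0}}
# i.e. each node type gets its own contiguous, zero-based integer id space.

toy_edges = {('drug', 'indication', 'disease'): [(0, 0)]}
# edges are grouped under their (src_type, relation, dst_type) signature,
# which is exactly the key format HeteroData expects below.
```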
data = HeteroData()

for i, key in enumerate(entity_dictionary.keys()):
    if key != 'drug':
        data[key].x = torch.ones((len(entity_dictionary[key]), 768)) * i
    else:
        # Drug features are 767-dimensional, matching the SMILES embeddings loaded below.
        data[key].x = torch.rand((len(entity_dictionary[key]), 767))

    data[key].id = torch.arange(len(entity_dictionary[key]))

for key in edge_dictionary:
    data[key].edge_index = torch.tensor(edge_dictionary[key], dtype=torch.long).t().contiguous()
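A few quick checks on the resulting graph (a sketch; the exact type names and sizes depend on the PrimeKG snapshot):

```python
print(data.node_types)           # e.g. ['drug', 'disease', 'gene/protein', ...]
print(data['drug'].x.shape)      # torch.Size([num_drugs, 767])
print(data['disease'].x.shape)   # torch.Size([num_diseases, 768])
print(data[(node_type1, rel, node_type2)].edge_index.shape)  # torch.Size([2, num_indication_edges])
```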
### Add initial embeddings.

embeddings = pd.read_pickle('../data/entities_embeddings.pkl')
smiles_embeddings = pd.read_pickle('../data/smiles_embeddings.pkl')

for i, row in smiles_embeddings.iterrows():
    if row['id'] in entity_dictionary['drug']:
        drug_id = entity_dictionary['drug'][row['id']]
        data['drug'].x[drug_id] = torch.Tensor(row['embedding'])

for i, row in embeddings.iterrows():
    x_type = row['id'].split('::')[0]
    if x_type in data.node_types and row['id'] in entity_dictionary[x_type] and x_type != 'drug':
        id_ = entity_dictionary[x_type][row['id']]
        data[x_type].x[id_][:768] = torch.Tensor(row['embedding'])
### Load the train and validation data of one fold.

with open('../data/CV data/train1.pkl', 'rb') as file:
    train_data = pickle.load(file)

with open('../data/CV data/val1.pkl', 'rb') as file:
    val_data = pickle.load(file)
### Create the supervision mask.

# Mark a random 80% of the drug-disease indication edges as supervision edges;
# the same mask is reused for the reverse (disease-drug) relation.
drug_disease_num = train_data[(node_type1, rel, node_type2)]['edge_index'].shape[1]
mask = random.sample(range(drug_disease_num), int(drug_disease_num * 0.8))
train_data[(node_type1, rel, node_type2)]['mask'] = torch.zeros(drug_disease_num, dtype=torch.bool)
train_data[(node_type1, rel, node_type2)]['mask'][mask] = True

train_data[(node_type2, rel, node_type1)]['mask'] = torch.zeros(drug_disease_num, dtype=torch.bool)
train_data[(node_type2, rel, node_type1)]['mask'][mask] = True
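A small sanity check on the mask (sketch):

```python
m = train_data[(node_type1, rel, node_type2)]['mask']
print(f'{m.sum().item()} of {m.numel()} indication edges are supervision edges')  # roughly 80%
```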
### Define the model.

class HGT(nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, dropout):
        super().__init__()

        # Per-node-type input projections (lazy input size via Linear(-1, ...)).
        self.lin_dict = nn.ModuleDict()
        for node_type in train_data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels[0])

        self.convs = nn.ModuleList()
        for i in range(num_layers):
            conv = HGTConv(hidden_channels[i], hidden_channels[i + 1], train_data.metadata(),
                           num_heads[i], group='mean')
            self.convs.append(conv)

        self.lin = Linear(sum(hidden_channels[1:]), out_channels)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            node_type: self.dropout(self.lin_dict[node_type](x).relu_())
            for node_type, x in x_dict.items()
        }
        # Concatenate the outputs of all HGTConv layers (layer-wise skip connections).
        out = {}
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            if out == {}:
                out = copy.copy(x_dict)
            else:
                out = {
                    node_type: torch.cat((out[node_type], x_dict[node_type]), dim=1)
                    for node_type in x_dict
                }

        return F.relu(self.lin(out[node_type1])), F.relu(self.lin(out[node_type2]))
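For orientation, a sketch of the widths this class produces under the configuration used in define_model below (the numbers follow directly from the arguments):

```python
hidden_channels = [64, 64, 64, 64]  # input projection width + three HGTConv widths
# forward() concatenates the three per-layer outputs for each node type,
# so self.lin maps sum(hidden_channels[1:]) channels down to out_channels.
print(sum(hidden_channels[1:]))  # 192
```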
class MLPPredictor(nn.Module):
    def __init__(self, channel_num, dropout):
        super().__init__()
        self.L1 = nn.Linear(channel_num * 2, channel_num)
        self.L2 = nn.Linear(channel_num, 1)
        self.bn = nn.BatchNorm1d(num_features=channel_num)
        self.dropout = nn.Dropout(dropout)  # was hard-coded to 0.2, silently ignoring the argument

    def forward(self, drug_embeddings, disease_embeddings):
        x = torch.cat((drug_embeddings, disease_embeddings), dim=1)
        x = F.relu(self.bn(self.L1(x)))
        x = self.dropout(x)
        x = self.L2(x)
        return x
def compute_loss(scores, labels):
    # Scale the positive terms by the fraction of negative labels in the batch
    # (pos_weight affects only the y == 1 terms of the BCE loss).
    pos_weights = torch.clone(labels)
    pos_weights[pos_weights == 1] = (labels == 0).sum() / labels.shape[0]
    pos_weights[pos_weights == 0] = (labels == 1).sum() / labels.shape[0]

    return F.binary_cross_entropy_with_logits(scores, labels, pos_weight=pos_weights)
    # Unweighted alternative:
    # return F.binary_cross_entropy_with_logits(scores, labels)
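As a worked example of the weighting (a hypothetical batch, not from the data):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])  # 2 positives, 8 negatives
scores = torch.zeros(10)  # logits of 0, i.e. probability 0.5 everywhere

pos_weights = torch.clone(labels)
pos_weights[pos_weights == 1] = (labels == 0).sum() / labels.shape[0]  # positives weighted 0.8
pos_weights[pos_weights == 0] = (labels == 1).sum() / labels.shape[0]  # 0.2 at negatives

loss = F.binary_cross_entropy_with_logits(scores, labels, pos_weight=pos_weights)
# pos_weight scales only the y == 1 terms, so here each positive example's loss
# is multiplied by the negative-class fraction (0.8); the values written at
# y == 0 positions have no effect on the result.
```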
def define_model(dropout):
    GNN = HGT(hidden_channels=[64, 64, 64, 64],
              out_channels=64,
              num_heads=[8, 8, 8],
              num_layers=3,
              dropout=dropout)

    pred = MLPPredictor(64, dropout)
    model = nn.Sequential(GNN, pred)
    model.to(device)

    return GNN, pred, model

def define_loaders(config):
    kwargs = {'batch_size': config['batch_size'], 'num_workers': 8, 'persistent_workers': True}

    train_loader = HGTLoader(train_data, num_samples=[config['num_samples']] * 3, shuffle=True,
                             input_nodes=(node_type1, None), **kwargs)
    val_loader = HGTLoader(val_data, num_samples=[config['num_samples']] * 3, shuffle=True,
                           input_nodes=(node_type1, None), **kwargs)
    return train_loader, val_loader
def edge_exists(edges, edge):
    # True if `edge` (a 2x1 pair) appears as a column of `edges` (a 2xN edge_index).
    edges = edges.to(device)
    edge = edge.to(device)
    return (edges == edge).all(dim=0).sum() > 0
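A minimal usage sketch with toy tensors (invented values):

```python
edges = torch.tensor([[0, 1, 2],
                      [5, 6, 7]])  # three edges stored as a 2xN edge_index
print(edge_exists(edges, torch.tensor([[1], [6]])))  # tensor(True):  (1, 6) is a column
print(edge_exists(edges, torch.tensor([[1], [7]])))  # tensor(False): (1, 7) is not
```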
### Make batches.

def make_batch(batch):
    batch_size = batch[node_type1].batch_size
    edge_index = batch[(node_type1, rel, node_type2)]['edge_index']
    mask = batch[(node_type1, rel, node_type2)]['mask']

    # Keep only edges whose drug endpoint is one of the batch's seed nodes.
    batch_index = (edge_index[0] < batch_size)
    edge_index = edge_index[:, batch_index]
    mask = mask[batch_index]
    edge_label_index = edge_index[:, mask]
    pos_num = edge_label_index.shape[1]
    edge_label = torch.ones(pos_num)

    # Sample an equal number of negative drug-disease pairs not present in this batch.
    neg_edges_source = []
    neg_edges_dest = []
    while len(neg_edges_source) < pos_num:
        source = random.randint(0, batch_size - 1)
        dest = random.randint(0, batch[node_type2].x.shape[0] - 1)
        neg_edge = torch.Tensor([[source], [dest]])
        if edge_exists(edge_index, neg_edge):
            continue
        neg_edges_source.append(source)
        neg_edges_dest.append(dest)

    neg_edges = torch.tensor([neg_edges_source, neg_edges_dest])
    edge_label_index = torch.cat((edge_label_index, neg_edges), dim=1)
    edge_label = torch.cat((edge_label, torch.zeros(neg_edges.shape[1])), dim=0)
    # The supervision edges are withheld from the message-passing graph.
    edge_index = edge_index[:, ~mask]

    batch[(node_type1, rel, node_type2)]['edge_index'] = edge_index
    batch[(node_type1, rel, node_type2)]['edge_label_index'] = edge_label_index
    batch[(node_type1, rel, node_type2)]['edge_label'] = edge_label

    # The reverse relation gets the same edges with source and destination swapped.
    # (The original swapped the rows in place, which also flipped the forward
    # edge_index sharing the same tensor; flip(0) returns a new tensor instead.)
    batch[(node_type2, rel, node_type1)]['edge_index'] = edge_index.flip(0)

    return batch
def make_test_batch(batch):
    batch_size = batch[node_type1].batch_size
    edge_index = batch[(node_type1, rel, node_type2)]['edge_index']
    edge_label_index = batch[(node_type1, rel, node_type2)]['edge_label_index']
    edge_label = batch[(node_type1, rel, node_type2)]['edge_label']

    # Keep the positive label edges whose endpoints were sampled into this batch,
    # translating global node ids into batch-local ids.
    source = []
    dest = []
    for i in range(edge_label_index.shape[1]):
        if edge_label_index[0, i] in batch[node_type1]['id'] and edge_label_index[1, i] in batch[node_type2]['id'] \
                and (batch[node_type1]['id'] == edge_label_index[0, i]).nonzero(as_tuple=True)[0] < batch_size:
            if edge_label[i] == 1:
                source.append((batch[node_type1]['id'] == edge_label_index[0, i]).nonzero(as_tuple=True)[0])
                dest.append((batch[node_type2]['id'] == edge_label_index[1, i]).nonzero(as_tuple=True)[0])

    edge_label_index = torch.zeros(2, len(source)).long()
    edge_label_index[0] = torch.tensor(source)
    edge_label_index[1] = torch.tensor(dest)
    pos_num = edge_label_index.shape[1]
    edge_label = torch.ones(pos_num)

    # Negative pairs are checked against the full graph, not just the batch.
    neg_edges_source = []
    neg_edges_dest = []
    while len(neg_edges_source) < pos_num:
        source_node = random.randint(0, batch_size - 1)
        dest_node = random.randint(0, batch[node_type2].x.shape[0] - 1)
        neg_edge_in_orig_graph = torch.Tensor([[batch[node_type1]['id'][source_node]],
                                               [batch[node_type2]['id'][dest_node]]])
        if edge_exists(data[(node_type1, rel, node_type2)]['edge_index'], neg_edge_in_orig_graph):
            continue
        neg_edges_source.append(source_node)
        neg_edges_dest.append(dest_node)

    neg_edges = torch.tensor([neg_edges_source, neg_edges_dest])
    edge_label_index = torch.cat((edge_label_index, neg_edges), dim=1)
    edge_label = torch.cat((edge_label, torch.zeros(neg_edges.shape[1])), dim=0)

    batch[(node_type1, rel, node_type2)]['edge_label_index'] = edge_label_index
    batch[(node_type1, rel, node_type2)]['edge_label'] = edge_label

    return batch
### Train

def train(GNN, pred, model, loader, optimizer):
    model.train()
    total_examples = total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        batch = make_batch(batch)
        batch = batch.to(device)
        edge_label_index = batch[(node_type1, rel, node_type2)]['edge_label_index']
        edge_label = batch[(node_type1, rel, node_type2)]['edge_label']
        if edge_label.shape[0] == 0:
            continue

        drug_embeddings, disease_embeddings = GNN(batch.x_dict, batch.edge_index_dict)

        c = drug_embeddings[edge_label_index[0]]
        d = disease_embeddings[edge_label_index[1]]
        out = pred(c, d)[:, 0]
        loss = compute_loss(out, edge_label)
        loss.backward()
        optimizer.step()

        total_examples += edge_label_index.shape[1]
        total_loss += float(loss) * edge_label_index.shape[1]

    return total_loss / total_examples
### Test

@torch.no_grad()
def test(GNN, pred, model, loader):
    model.eval()

    out, labels = torch.tensor([]).to(device), torch.tensor([]).to(device)
    source, dest = torch.tensor([]).to(device), torch.tensor([]).to(device)
    for batch in loader:
        batch = make_test_batch(batch)
        batch = batch.to(device)
        drug_embeddings, disease_embeddings = GNN(batch.x_dict, batch.edge_index_dict)

        edge_label_index = batch[(node_type1, rel, node_type2)]['edge_label_index']
        edge_label = batch[(node_type1, rel, node_type2)]['edge_label']

        if edge_label.shape[0] == 0:
            continue

        c = drug_embeddings[edge_label_index[0]]
        d = disease_embeddings[edge_label_index[1]]
        batch_out = pred(c, d)[:, 0]
        labels = torch.cat((labels, edge_label))
        out = torch.cat((out, batch_out))

        # Map batch-local indices back to global node ids for later inspection.
        drugs = batch[node_type1]['id'][edge_label_index[0]]
        diseases = batch[node_type2]['id'][edge_label_index[1]]
        source = torch.cat((source, drugs))
        dest = torch.cat((dest, diseases))

    loss = compute_loss(out, labels)
    return out, labels, source, dest, loss.cpu().numpy()
### Run

def run(config):
    losses, val_losses = [], []

    train_loader, val_loader = define_loaders(config)
    GNN, pred, model = define_model(config['dropout'])

    optimizer = torch.optim.AdamW(model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=config['epochs'],
                                                           eta_min=0,
                                                           last_epoch=-1,
                                                           verbose=False)

    for epoch in range(config['epochs']):
        loss = train(GNN, pred, model, train_loader, optimizer)
        out, labels, source, dest, val_loss = test(GNN, pred, model, val_loader)
        write_to_out(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, ValLoss: {val_loss:.4f} \n')
        losses.append(loss)
        val_losses.append(val_loss)
        plot_losses(losses, val_losses)

        scheduler.step()

    torch.save(model.state_dict(), '../out/saved_model.h5')

    out, labels, source, dest, val_loss = test(GNN, pred, model, val_loader)
    AUPR(out, labels)
    AUROC(out, labels)
run(config)
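To reuse the trained weights later without retraining, a loading sketch along these lines should work (same paths as above; despite the .h5 extension, the file is a regular torch checkpoint):

```python
GNN, pred, model = define_model(config['dropout'])
model.load_state_dict(torch.load('../out/saved_model.h5', map_location=device))
model.eval()
```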

src/utils.py (+48, -0)

@@ -0,0 +1,48 @@
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
from matplotlib import pyplot


def AUROC(scores, labels):
    scores = scores.cpu().numpy()
    labels = labels.cpu().numpy()

    ns_probs = [0 for _ in range(len(labels))]  # no-skill baseline: predict 0 for everything
    lr_auc = roc_auc_score(labels, scores)
    write_to_out('AUROC: %.3f \n' % lr_auc)
    ns_fpr, ns_tpr, _ = roc_curve(labels, ns_probs)
    lr_fpr, lr_tpr, _ = roc_curve(labels, scores)
    pyplot.plot(ns_fpr, ns_tpr, linestyle='--', label='No Skill')
    pyplot.plot(lr_fpr, lr_tpr, label='HGT')  # was mislabelled 'Logistic'
    pyplot.xlabel('False Positive Rate')
    pyplot.ylabel('True Positive Rate')
    pyplot.legend()
    pyplot.savefig('../out/AUROC', dpi=180)
    pyplot.show()


def AUPR(scores, labels):
    scores = scores.cpu().numpy()
    labels = labels.cpu().numpy()

    lr_precision, lr_recall, _ = precision_recall_curve(labels, scores)
    lr_auc = auc(lr_recall, lr_precision)
    write_to_out('AUPR: %.3f \n' % lr_auc)
    no_skill = len(labels[labels == 1]) / len(labels)
    pyplot.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill')
    pyplot.plot(lr_recall, lr_precision, label='HGT')
    pyplot.xlabel('Recall')
    pyplot.ylabel('Precision')
    pyplot.legend()
    pyplot.savefig('../out/AUPR', dpi=180)
    pyplot.show()


def plot_losses(losses, val_losses):
    pyplot.plot(range(len(losses)), losses, label="loss")
    pyplot.plot(range(len(losses)), val_losses, label="val_loss")
    pyplot.legend()
    pyplot.savefig('../out/losses', dpi=200)
    pyplot.clf()


def write_to_out(text):
    print(text)
    with open('../out/out.txt', 'a') as out_file:
        out_file.write(text)
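A quick smoke test for these helpers (synthetic tensors, assuming this runs from src/ with an ../out directory present):

```python
import torch
from utils import AUROC, AUPR

scores = torch.tensor([0.9, 0.8, 0.4, 0.1])  # invented predictions
labels = torch.tensor([1., 1., 0., 0.])      # invented ground truth
AUROC(scores, labels)  # writes 'AUROC: 1.000' to ../out/out.txt and plots the curve
AUPR(scores, labels)   # same for the precision-recall curve
```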
