
HGTDR.ipynb 24KB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Nx8IcGZHAKn_"
},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dhv1DYPL3Vm1"
},
"outputs": [],
"source": [
"from torch_geometric.nn import HGTConv, Linear\n",
"from torch_geometric.loader import HGTLoader\n",
"from torch_geometric.data import HeteroData\n",
"import torch.nn.functional as F\n",
"import pickle5 as pickle\n",
"import torch.nn as nn\n",
"import pandas as pd\n",
"from utils import *\n",
"import random\n",
"import torch\n",
"import copy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"node_type1 = 'drug'\n",
"node_type2 = 'disease'\n",
"rel = 'indication'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oOEkKAITR8tc"
},
"outputs": [],
"source": [
"config = {\n",
"    \"num_samples\": 512,\n",
"    \"batch_size\": 164,\n",
"    \"dropout\": 0.5,\n",
"    \"epochs\": 300\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-cRmQ9cEAO_K"
},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uf_2DpGQCCFJ"
},
"outputs": [],
"source": [
"primekg_file = '../data/kg.csv'\n",
"df = pd.read_csv(primekg_file, sep=\",\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
  86. "### Get drugs and diseases which are used in indication relation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"drug_disease_pairs = df[df['relation']==rel]\n",
"drugs, diseases = [], []\n",
"\n",
"for i, row in drug_disease_pairs.iterrows():\n",
"    if row['x_type'] == node_type1:\n",
"        drugs.append(row['x_index'])\n",
"    if row['x_type'] == node_type2:\n",
"        diseases.append(row['x_index'])\n",
"\n",
"    if row['y_type'] == node_type1:\n",
"        drugs.append(row['y_index'])\n",
"    if row['y_type'] == node_type2:\n",
"        diseases.append(row['y_index'])\n",
"\n",
"drugs, diseases = list(set(drugs)), list(set(diseases))"
]
},
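{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sanity check (added here, not part of the original pipeline): report how many unique drugs and diseases participate in at least one indication edge."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sanity check: counts of indication-connected drugs and diseases\n",
"print(f'{len(drugs)} drugs and {len(diseases)} diseases have at least one indication edge')"
]
},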
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Remove drug and disease nodes that do not contribute to at least one indication edge."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"to_remove = df[df['x_type']==node_type1]\n",
"to_remove = to_remove[~to_remove['x_index'].isin(drugs)]\n",
"df.drop(to_remove.index, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"to_remove = df[df['y_type']==node_type1]\n",
"to_remove = to_remove[~to_remove['y_index'].isin(drugs)]\n",
"df.drop(to_remove.index, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"to_remove = df[df['x_type']==node_type2]\n",
"to_remove = to_remove[~to_remove['x_index'].isin(diseases)]\n",
"df.drop(to_remove.index, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"to_remove = df[df['y_type']==node_type2]\n",
"to_remove = to_remove[~to_remove['y_index'].isin(diseases)]\n",
"df.drop(to_remove.index, inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
  167. "### Make HeteroData object for the graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"new_df = pd.DataFrame()\n",
"new_df[0] = df['x_type'] + '::' + df['x_index'].astype(str)\n",
"new_df[1] = df['relation']\n",
"new_df[2] = df['y_type'] + '::' + df['y_index'].astype(str)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = new_df\n",
"df = df.drop_duplicates()\n",
"triplets = df.values.tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jUcSDffvCtKY"
},
"outputs": [],
"source": [
  201. "entity_dictionary = {}\n",
  202. "def insert_entry(entry, ent_type, dic):\n",
  203. " if ent_type not in dic:\n",
  204. " dic[ent_type] = {}\n",
  205. " ent_n_id = len(dic[ent_type])\n",
  206. " if entry not in dic[ent_type]:\n",
  207. " dic[ent_type][entry] = ent_n_id\n",
  208. " return dic\n",
  209. "\n",
  210. "for triple in triplets:\n",
  211. " src = triple[0]\n",
  212. " split_src = src.split('::')\n",
  213. " src_type = split_src[0]\n",
  214. " dest = triple[2]\n",
  215. " split_dest = dest.split('::')\n",
  216. " dest_type = split_dest[0]\n",
  217. " insert_entry(src,src_type,entity_dictionary)\n",
  218. " insert_entry(dest,dest_type,entity_dictionary)"
  219. ]
  220. },
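{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another optional check (an addition, not in the original notebook): list the node types and how many entities of each type were indexed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sanity check: entity counts per node type\n",
"for ent_type, entries in entity_dictionary.items():\n",
"    print(ent_type, len(entries))"
]
},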
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vTybNyrqFLrl"
},
"outputs": [],
"source": [
"edge_dictionary = {}\n",
"for triple in triplets:\n",
"    src = triple[0]\n",
"    split_src = src.split('::')\n",
"    src_type = split_src[0]\n",
"    dest = triple[2]\n",
"    split_dest = dest.split('::')\n",
"    dest_type = split_dest[0]\n",
"\n",
"    src_int_id = entity_dictionary[src_type][src]\n",
"    dest_int_id = entity_dictionary[dest_type][dest]\n",
"\n",
"    pair = (src_int_id, dest_int_id)\n",
"    etype = (src_type, triple[1], dest_type)\n",
"    if etype in edge_dictionary:\n",
"        edge_dictionary[etype] += [pair]\n",
"    else:\n",
"        edge_dictionary[etype] = [pair]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = HeteroData()\n",
"\n",
"for i, key in enumerate(entity_dictionary.keys()):\n",
"    if key != 'drug':\n",
"        # 768-dim placeholder features, overwritten by the text embeddings below\n",
"        data[key].x = torch.ones((len(entity_dictionary[key]), 768)) * i\n",
"    elif key == 'drug':\n",
"        # 767-dim placeholder features, overwritten by the SMILES embeddings below\n",
"        data[key].x = torch.rand((len(entity_dictionary[key]), 767))\n",
"\n",
"    data[key].id = torch.arange(len(entity_dictionary[key]))\n",
"\n",
"for key in edge_dictionary:\n",
"    data[key].edge_index = torch.transpose(torch.IntTensor(edge_dictionary[key]), 0, 1).long().contiguous()"
]
},
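{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (added): printing a `HeteroData` object summarizes every node and edge store, and `validate()` flags malformed stores such as out-of-range edge indices (assuming a PyG version that provides `HeteroData.validate()`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sanity check: inspect node/edge stores and validate index ranges\n",
"print(data)\n",
"data.validate()"
]
},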
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Add initial embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embeddings = pd.read_pickle('../data/entities_embeddings.pkl')\n",
"smiles_embeddings = pd.read_pickle('../data/smiles_embeddings.pkl')\n",
"\n",
"for i, row in smiles_embeddings.iterrows():\n",
"    if row['id'] in entity_dictionary['drug'].keys():\n",
"        drug_id = entity_dictionary['drug'][row['id']]\n",
"        data['drug'].x[drug_id] = torch.Tensor(row['embedding'])\n",
"\n",
"for i, row in embeddings.iterrows():\n",
"    x_type = row['id'].split('::')[0]\n",
"    if x_type in data.node_types and row['id'] in entity_dictionary[x_type] and x_type != 'drug':\n",
"        id_ = entity_dictionary[x_type][row['id']]\n",
"        data[x_type].x[id_][:768] = torch.Tensor(row['embedding'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
  301. "### Load train and validation data of one fold."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
  310. "file = open('../data/CV data/train1.pkl', 'rb')\n",
  311. "train_data = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
  320. "file = open('../data/CV data/val1.pkl', 'rb')\n",
  321. "val_data = pickle.load(file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
  328. "### Creating mask."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# flag 80% of the indication edges as supervision edges;\n",
"# make_batch below keeps the remaining 20% for message passing\n",
"drug_disease_num = train_data[(node_type1, rel, node_type2)]['edge_index'].shape[1]\n",
"mask = random.sample(range(drug_disease_num), int(drug_disease_num*0.8))\n",
"train_data[(node_type1, rel, node_type2)]['mask'] = torch.zeros(drug_disease_num, dtype=torch.bool)\n",
"train_data[(node_type1, rel, node_type2)]['mask'][mask] = True\n",
"\n",
"train_data[(node_type2, rel, node_type1)]['mask'] = torch.zeros(drug_disease_num, dtype=torch.bool)\n",
"train_data[(node_type2, rel, node_type1)]['mask'][mask] = True"
]
},
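{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (added) that the mask flags roughly 80% of the indication edges."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added sanity check: fraction of masked (supervision) edges, expected ~0.80\n",
"mask_tensor = train_data[(node_type1, rel, node_type2)]['mask']\n",
"print(f'{mask_tensor.float().mean().item():.2f} of indication edges are supervision edges')"
]
},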
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ql-F7A42fMWm"
},
"outputs": [],
"source": [
"class HGT(nn.Module):\n",
"    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, dropout):\n",
"        super().__init__()\n",
"\n",
"        # per-node-type input projections (lazy in_channels via Linear(-1, ...))\n",
"        self.lin_dict = nn.ModuleDict()\n",
"        for node_type in train_data.node_types:\n",
"            self.lin_dict[node_type] = Linear(-1, hidden_channels[0])\n",
"\n",
"        self.convs = nn.ModuleList()\n",
"        for i in range(num_layers):\n",
"            conv = HGTConv(hidden_channels[i], hidden_channels[i+1], train_data.metadata(),\n",
"                           num_heads[i], group='mean')\n",
"            self.convs.append(conv)\n",
"\n",
"        self.lin = Linear(sum(hidden_channels[1:]), out_channels)\n",
"\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x_dict, edge_index_dict):\n",
"        x_dict = {\n",
"            node_type: self.dropout(self.lin_dict[node_type](x).relu_())\n",
"            for node_type, x in x_dict.items()\n",
"        }\n",
"        # concatenate the outputs of all HGTConv layers per node type\n",
"        out = {}\n",
"        for i, conv in enumerate(self.convs):\n",
"            x_dict = conv(x_dict, edge_index_dict)\n",
"\n",
"            if not out:\n",
"                out = copy.copy(x_dict)\n",
"            else:\n",
"                out = {\n",
"                    node_type: torch.cat((out[node_type], x_dict[node_type]), dim=1)\n",
"                    for node_type, x in x_dict.items()\n",
"                }\n",
"\n",
"        return F.relu(self.lin(out[node_type1])), F.relu(self.lin(out[node_type2]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "atdEjxJqvLaN"
},
"outputs": [],
"source": [
"class MLPPredictor(nn.Module):\n",
"    def __init__(self, channel_num, dropout):\n",
"        super().__init__()\n",
"        self.L1 = nn.Linear(channel_num * 2, channel_num)\n",
"        self.L2 = nn.Linear(channel_num, 1)\n",
"        self.bn = nn.BatchNorm1d(num_features=channel_num)\n",
"        self.dropout = nn.Dropout(dropout)  # was hard-coded to 0.2, ignoring the argument\n",
"\n",
"    def forward(self, drug_embeddings, disease_embeddings):\n",
"        x = torch.cat((drug_embeddings, disease_embeddings), dim=1)\n",
"        x = F.relu(self.bn(self.L1(x)))\n",
"        x = self.dropout(x)\n",
"        x = self.L2(x)\n",
"        return x"
]
},
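{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy forward pass (added, with made-up random tensors) to confirm the predictor maps a pair of 64-dim embeddings to one logit per pair."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added shape check on random inputs\n",
"_mlp = MLPPredictor(64, 0.5)\n",
"_drug = torch.randn(8, 64)\n",
"_disease = torch.randn(8, 64)\n",
"print(_mlp(_drug, _disease).shape)  # torch.Size([8, 1])"
]
},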
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6fHs5rX76ldq"
},
"outputs": [],
"source": [
"def compute_loss(scores, labels):\n",
"    # class-balanced BCE: positives are weighted by the negative-class frequency\n",
"    # (pos_weight only scales the terms whose label is 1)\n",
"    pos_weights = torch.clone(labels)\n",
"    pos_weights[pos_weights == 1] = ((labels==0).sum() / labels.shape[0])\n",
"    pos_weights[pos_weights == 0] = ((labels==1).sum() / labels.shape[0])\n",
"\n",
"    return F.binary_cross_entropy_with_logits(scores, labels, pos_weight=pos_weights)\n",
"#     return F.binary_cross_entropy_with_logits(scores, labels)"
]
},
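{
"cell_type": "markdown",
"metadata": {},
"source": [
"An illustration on toy scores (added, not original): with one positive and three negatives, the positive term is scaled by the negative-class frequency (0.75 here)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added toy example: 1 positive, 3 negatives -> pos_weight of 0.75 on the positive term\n",
"_scores = torch.tensor([2.0, -1.0, -0.5, 0.3])\n",
"_labels = torch.tensor([1., 0., 0., 0.])\n",
"print(compute_loss(_scores, _labels))"
]
},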
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nmLsh9VigPpI"
},
"outputs": [],
"source": [
"def define_model(dropout):\n",
"    GNN = HGT(hidden_channels=[64, 64, 64, 64],\n",
"              out_channels=64,\n",
"              num_heads=[8, 8, 8],\n",
"              num_layers=3,\n",
"              dropout=dropout)\n",
"\n",
"    pred = MLPPredictor(64, dropout)\n",
"    # Sequential only bundles the two modules for .parameters()/.train()/.eval();\n",
"    # it is never called directly\n",
"    model = nn.Sequential(GNN, pred)\n",
"    model.to(device)\n",
"\n",
"    return GNN, pred, model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kk5vWUiQV7oi"
},
"outputs": [],
"source": [
"def define_loaders(config):\n",
"    kwargs = {'batch_size': config['batch_size'], 'num_workers': 8, 'persistent_workers': True}\n",
"\n",
"    train_loader = HGTLoader(train_data, num_samples=[config['num_samples']] * 3, shuffle=True, input_nodes=(node_type1, None), **kwargs)\n",
"    val_loader = HGTLoader(val_data, num_samples=[config['num_samples']] * 3, shuffle=True, input_nodes=(node_type1, None), **kwargs)\n",
"    return train_loader, val_loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def edge_exists(edges, edge):\n",
"    # True if `edge` (shape [2, 1]) appears as a column of `edges` (shape [2, N])\n",
"    edges = edges.to(device)\n",
"    edge = edge.to(device)\n",
"    return (edges == edge).all(dim=0).sum() > 0"
]
},
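{
"cell_type": "markdown",
"metadata": {},
"source": [
"Usage illustration (added): `edge_exists` checks whether a candidate column appears in an edge index, which the batching code below uses to reject false negatives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# added usage example on a tiny edge index\n",
"_edges = torch.tensor([[0, 1], [2, 3]])  # two edges: (0, 2) and (1, 3)\n",
"print(edge_exists(_edges, torch.tensor([[0], [2]])))  # tensor(True)\n",
"print(edge_exists(_edges, torch.tensor([[0], [3]])))  # tensor(False)"
]
},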
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Make batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_batch(batch):\n",
"\n",
"    batch_size = batch[node_type1].batch_size\n",
"    edge_index = batch[(node_type1, rel, node_type2)]['edge_index']\n",
"    mask = batch[(node_type1, rel, node_type2)]['mask']\n",
"\n",
"    # keep only edges whose drug endpoint is a seed node of this batch\n",
"    batch_index = (edge_index[0] < batch_size)\n",
"    edge_index = edge_index[:, batch_index]\n",
"    mask = mask[batch_index]\n",
"    edge_label_index = edge_index[:, mask]\n",
"    pos_num = edge_label_index.shape[1]\n",
"    edge_label = torch.ones(pos_num)\n",
"\n",
"    # sample one negative drug-disease pair per positive edge\n",
"    neg_edges_source = []\n",
"    neg_edges_dest = []\n",
"    while len(neg_edges_source) < pos_num:\n",
"        source = random.randint(0, batch_size-1)\n",
"        dest = random.randint(0, batch[node_type2].x.shape[0]-1)\n",
"        neg_edge = torch.Tensor([[source], [dest]])\n",
"        if edge_exists(edge_index, neg_edge):\n",
"            continue\n",
"        else:\n",
"            neg_edges_source.append(source)\n",
"            neg_edges_dest.append(dest)\n",
"\n",
"    neg_edges = torch.tensor([neg_edges_source, neg_edges_dest])\n",
"    edge_label_index = torch.cat((edge_label_index, neg_edges), dim=1)\n",
"    edge_label = torch.cat((edge_label, torch.zeros(neg_edges.shape[1])), dim=0)\n",
"    edge_index = edge_index[:, ~mask]  # supervision edges are hidden from message passing\n",
"\n",
"    batch[(node_type1, rel, node_type2)]['edge_index'] = edge_index\n",
"    batch[(node_type1, rel, node_type2)]['edge_label_index'] = edge_label_index\n",
"    batch[(node_type1, rel, node_type2)]['edge_label'] = edge_label\n",
"\n",
"    # give the reverse relation its own flipped copy; swapping rows in place\n",
"    # would also corrupt the forward edge_index, which aliases the same tensor\n",
"    batch[(node_type2, rel, node_type1)]['edge_index'] = edge_index.flip(0)\n",
"\n",
"    return batch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_test_batch(batch):\n",
"\n",
"    batch_size = batch[node_type1].batch_size\n",
"    edge_index = batch[(node_type1, rel, node_type2)]['edge_index']\n",
"    edge_label_index = batch[(node_type1, rel, node_type2)]['edge_label_index']\n",
"    edge_label = batch[(node_type1, rel, node_type2)]['edge_label']\n",
"\n",
"    # map the positive label edges from global node ids to local batch ids\n",
"    source = []\n",
"    dest = []\n",
"    for i in range(edge_label_index.shape[1]):\n",
"        if edge_label_index[0, i] in batch[node_type1]['id'] and edge_label_index[1, i] in batch[node_type2]['id'] \\\n",
"                and ((batch[node_type1]['id'] == edge_label_index[0, i]).nonzero(as_tuple=True)[0]) < batch_size:\n",
"            if edge_label[i] == 1:\n",
"                source.append((batch[node_type1]['id'] == edge_label_index[0, i]).nonzero(as_tuple=True)[0].item())\n",
"                dest.append((batch[node_type2]['id'] == edge_label_index[1, i]).nonzero(as_tuple=True)[0].item())\n",
"\n",
"    edge_label_index = torch.zeros(2, len(source)).long()\n",
"    edge_label_index[0] = torch.tensor(source)\n",
"    edge_label_index[1] = torch.tensor(dest)\n",
"    pos_num = edge_label_index.shape[1]\n",
"    edge_label = torch.ones(pos_num)\n",
"\n",
"    # negatives are rejected if the pair exists anywhere in the full graph\n",
"    neg_edges_source = []\n",
"    neg_edges_dest = []\n",
"    while len(neg_edges_source) < pos_num:\n",
"        source_node = random.randint(0, batch_size-1)\n",
"        dest_node = random.randint(0, batch[node_type2].x.shape[0]-1)\n",
"        neg_edge_in_orig_graph = torch.Tensor([[batch[node_type1]['id'][source_node]], [batch[node_type2]['id'][dest_node]]])\n",
"        if edge_exists(data[(node_type1, rel, node_type2)]['edge_index'], neg_edge_in_orig_graph):\n",
"            continue\n",
"        else:\n",
"            neg_edges_source.append(source_node)\n",
"            neg_edges_dest.append(dest_node)\n",
"\n",
"    neg_edges = torch.tensor([neg_edges_source, neg_edges_dest])\n",
"    edge_label_index = torch.cat((edge_label_index, neg_edges), dim=1)\n",
"    edge_label = torch.cat((edge_label, torch.zeros(neg_edges.shape[1])), dim=0)\n",
"\n",
"    batch[(node_type1, rel, node_type2)]['edge_label_index'] = edge_label_index\n",
"    batch[(node_type1, rel, node_type2)]['edge_label'] = edge_label\n",
"\n",
"    return batch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "51k5xvYymLvw"
},
"outputs": [],
"source": [
"def train(GNN, pred, model, loader, optimizer):\n",
"    model.train()\n",
"    total_examples = total_loss = 0\n",
"    for i, batch in enumerate(iter(loader)):\n",
"        optimizer.zero_grad()\n",
"        batch = make_batch(batch)\n",
"        batch = batch.to(device)\n",
"        edge_label_index = batch[(node_type1, rel, node_type2)]['edge_label_index']\n",
"        edge_label = batch[(node_type1, rel, node_type2)]['edge_label']\n",
"        if edge_label.shape[0] == 0:\n",
"            continue\n",
"\n",
"        drug_embeddings, disease_embeddings = GNN(batch.x_dict, batch.edge_index_dict)\n",
"\n",
"        c = drug_embeddings[edge_label_index[0]]\n",
"        d = disease_embeddings[edge_label_index[1]]\n",
"        out = pred(c, d)[:, 0]\n",
"        loss = compute_loss(out, edge_label)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        total_examples += edge_label_index.shape[1]\n",
"        total_loss += float(loss) * edge_label_index.shape[1]\n",
"\n",
"    return total_loss / total_examples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyvi80_Wo4GE"
},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def test(GNN, pred, model, loader):\n",
"    model.eval()\n",
"\n",
"    out, labels = torch.tensor([]).to(device), torch.tensor([]).to(device)\n",
"    source, dest = torch.tensor([]).to(device), torch.tensor([]).to(device)\n",
"    for batch in iter(loader):\n",
"        batch = make_test_batch(batch)\n",
"        batch = batch.to(device)\n",
"        drug_embeddings, disease_embeddings = GNN(batch.x_dict, batch.edge_index_dict)\n",
"\n",
"        edge_label_index = batch[(node_type1, rel, node_type2)]['edge_label_index']\n",
"        edge_label = batch[(node_type1, rel, node_type2)]['edge_label']\n",
"\n",
"        if edge_label.shape[0] == 0:\n",
"            continue\n",
"\n",
"        c = drug_embeddings[edge_label_index[0]]\n",
"        d = disease_embeddings[edge_label_index[1]]\n",
"        batch_out = pred(c, d)[:, 0]\n",
"        labels = torch.cat((labels, edge_label))\n",
"        out = torch.cat((out, batch_out))\n",
"\n",
"        # keep the original graph ids of the scored pairs\n",
"        drugs = batch[node_type1]['id'][edge_label_index[0]]\n",
"        diseases = batch[node_type2]['id'][edge_label_index[1]]\n",
"        source = torch.cat((source, drugs))\n",
"        dest = torch.cat((dest, diseases))\n",
"\n",
"    loss = compute_loss(out, labels)\n",
"    return out, labels, source, dest, loss.cpu().numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def run(config):\n",
"    losses, val_losses = [], []\n",
"\n",
"    train_loader, val_loader = define_loaders(config)\n",
"    GNN, pred, model = define_model(config['dropout'])\n",
"\n",
"    optimizer = torch.optim.AdamW(model.parameters())\n",
"    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,\n",
"                                                           T_max=config['epochs'],\n",
"                                                           eta_min=0,\n",
"                                                           last_epoch=-1,\n",
"                                                           verbose=False)\n",
"\n",
"    for epoch in range(config['epochs']):\n",
"        loss = train(GNN, pred, model, train_loader, optimizer)\n",
"        out, labels, source, dest, val_loss = test(GNN, pred, model, val_loader)\n",
"        write_to_out(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, ValLoss: {val_loss:.4f} \\n')\n",
"        losses.append(loss)\n",
"        val_losses.append(val_loss)\n",
"        plot_losses(losses, val_losses)\n",
"\n",
"        scheduler.step()\n",
"\n",
"    torch.save(model.state_dict(), '../out/saved_model.h5')\n",
"\n",
"    out, labels, source, dest, val_loss = test(GNN, pred, model, val_loader)\n",
"    AUPR(out, labels)\n",
"    AUROC(out, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run(config)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "geo-pyHGT.ipynb",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "basee",
"language": "python",
"name": "basee"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}