A PyTorch implementation of the paper "CSI: a hybrid deep neural network for fake news detection"
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

train.ipynb 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stderr",
  10. "output_type": "stream",
  11. "text": [
  12. "Global seed set to 42\n"
  13. ]
  14. },
  15. {
  16. "data": {
  17. "text/plain": [
  18. "42"
  19. ]
  20. },
  21. "execution_count": 1,
  22. "metadata": {},
  23. "output_type": "execute_result"
  24. }
  25. ],
  26. "source": [
  27. "import numpy as np\n",
  28. "import os\n",
  29. "from importlib import reload\n",
  30. "from copy import deepcopy\n",
  31. "import json\n",
  32. "import pandas as pd\n",
  33. "from tqdm.notebook import tqdm\n",
  34. "\n",
  35. "import ray\n",
  36. "from ray import tune\n",
  37. "from ray.tune.schedulers import ASHAScheduler\n",
  38. "\n",
  39. "import torch\n",
  40. "from torch import nn\n",
  41. "import pytorch_lightning as pl\n",
  42. "\n",
  43. "from pytorch_lightning import Trainer, seed_everything\n",
  44. "from pytorch_lightning.loggers import TensorBoardLogger\n",
  45. "from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor\n",
  46. "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
  47. "from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback\n",
  48. "\n",
  49. "seed_everything(42)"
  50. ]
  51. },
  52. {
  53. "cell_type": "markdown",
  54. "metadata": {},
  55. "source": [
  56. "## data"
  57. ]
  58. },
  59. {
  60. "cell_type": "code",
  61. "execution_count": 2,
  62. "metadata": {},
  63. "outputs": [],
  64. "source": [
  65. "import src.data\n",
  66. "reload(src.data)\n",
  67. "from src.data import CSIDataset\n",
  68. "\n",
  69. "from torch.utils.data import DataLoader"
  70. ]
  71. },
  72. {
  73. "cell_type": "markdown",
  74. "metadata": {},
  75. "source": [
  76. "## Model"
  77. ]
  78. },
  79. {
  80. "cell_type": "code",
  81. "execution_count": 3,
  82. "metadata": {},
  83. "outputs": [],
  84. "source": [
  85. "import src.model\n",
  86. "reload(src.model)\n",
  87. "from src.model import CSIModel\n",
  88. "\n",
  89. "\n",
  90. "def experiment(args):\n",
  91. "    \"\"\"Train a CSIModel on the weibo dataset for one Ray Tune trial\n",
  92. "    and report the final validation metrics back to Tune.\n",
  93. "\n",
  94. "    args: dict with keys 'lr', 'dropout', 'weight_decay', supplied by\n",
  95. "    tune.run's config search space.\n",
  96. "    \"\"\"\n",
  97. "    dataset = 'weibo'\n",
  98. "    # NOTE(review): hardcoded absolute path -- consider moving to an env var or config file.\n",
  99. "    path = f'/media/external_10TB/10TB/ramezani/Omranpour/assets/{dataset}/'\n",
 100. "\n",
 101. "    train_set = CSIDataset(pkl_dir=path + 'train/pkls/')\n",
 102. "    val_set = CSIDataset(pkl_dir=path + 'validation/pkls/')\n",
 103. "    # batch_size=1 pairs with accumulate_grad_batches=64 below, so gradients\n",
 104. "    # are accumulated over 64 single-sample steps before each optimizer update.\n",
 105. "    train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n",
 106. "    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n",
 107. "\n",
 108. "    # Model hyperparameters; only lr/dropout/weight_decay come from the search space.\n",
 109. "    conf = {\n",
 110. "        'capture_input_dim': 112,\n",
 111. "        'score_input_dim': 50,\n",
 112. "        'd_Wa': 100,\n",
 113. "        'd_lstm': 50,\n",
 114. "        'd_Wr': 100,\n",
 115. "        'd_Wu': 100,\n",
 116. "        'd_Ws': 1,\n",
 117. "        'lr': args['lr'],\n",
 118. "        'dropout': args['dropout'],\n",
 119. "        'weight_decay': args['weight_decay']\n",
 120. "    }\n",
 121. "    model = CSIModel(conf)\n",
 122. "\n",
 123. "    name = f\"dataset={dataset}-do={args['dropout']}-lr={args['lr']}-wd={args['weight_decay']}\"\n",
 124. "    save_dir = f'/media/external_10TB/10TB/ramezani/Omranpour/CSI/weights/{name}/'\n",
 125. "    logger = TensorBoardLogger(save_dir='logs/', name=name)\n",
 126. "    checkpoint = ModelCheckpoint(\n",
 127. "        dirpath=save_dir,\n",
 128. "        filename='{epoch}-{val_loss:.2f}',\n",
 129. "        monitor='val_loss',\n",
 130. "        mode='min',\n",
 131. "        save_top_k=10,\n",
 132. "        every_n_epochs=5\n",
 133. "    )\n",
 134. "    os.makedirs(save_dir, exist_ok=True)\n",
 135. "    # Persist the trial config next to its checkpoints; use a context manager\n",
 136. "    # so the file handle is closed (the original leaked an open handle).\n",
 137. "    with open(save_dir + 'config.json', 'w') as f:\n",
 138. "        json.dump(conf, f)\n",
 139. "\n",
 140. "    trainer = Trainer(\n",
 141. "        benchmark=True,\n",
 142. "        gpus=[1],\n",
 143. "        accumulate_grad_batches=64,\n",
 144. "        logger=logger,\n",
 145. "        enable_progress_bar=False,\n",
 146. "        max_epochs=10,\n",
 147. "        callbacks=[checkpoint]\n",
 148. "    )\n",
 149. "    trainer.fit(model, train_loader, val_loader)\n",
 150. "    # Fix: the original called trainer.validate(val_loader), which passes the\n",
 151. "    # dataloader as Trainer.validate's `model` argument; pass both explicitly.\n",
 152. "    res = trainer.validate(model, dataloaders=val_loader)[0]\n",
 153. "    tune.report(**res)"
  147. ]
  148. },
  149. {
  150. "cell_type": "code",
  151. "execution_count": null,
  152. "metadata": {
  153. "scrolled": true
  154. },
  155. "outputs": [
  156. {
  157. "data": {
  158. "text/html": [
  159. "== Status ==<br>Current time: 2022-03-09 01:53:36 (running for 00:01:52.55)<br>Memory usage on this node: 14.1/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 1.0/18 CPUs, 2.0/2 GPUs, 0.0/71.06 GiB heap, 0.0/34.44 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/experiment_2022-03-09_01-51-43<br>Number of trials: 16/16 (15 PENDING, 1 RUNNING)<br><br>"
  160. ],
  161. "text/plain": [
  162. "<IPython.core.display.HTML object>"
  163. ]
  164. },
  165. "metadata": {},
  166. "output_type": "display_data"
  167. }
  168. ],
  169. "source": [
  170. "analysis = tune.run(\n",
  171. " experiment,\n",
  172. " num_samples=4,  # 4 samples per grid point x 4 grid values = 16 trials (matches status output)\n",
  173. " resources_per_trial={\"cpu\": 1, \"gpu\": 2},  # reserves both GPUs per trial, so trials run one at a time (status shows 1 RUNNING, 15 PENDING)\n",
  174. " verbose=1,\n",
  175. " config={\n",
  176. " \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),  # exhaustive over these 4 values\n",
  177. " \"lr\": tune.loguniform(1e-5, 1e-1),\n",
  178. " \"dropout\": tune.uniform(0., 0.3)\n",
  179. " }\n",
  180. ")\n"
  181. ]
  182. },
  183. {
  184. "cell_type": "code",
  185. "execution_count": null,
  186. "metadata": {},
  187. "outputs": [],
  188. "source": [
  189. "df = analysis.results_df  # per-trial results collected as a pandas DataFrame"
  190. ]
  191. }
  192. ],
  193. "metadata": {
  194. "kernelspec": {
  195. "display_name": "Python 3",
  196. "language": "python",
  197. "name": "python3"
  198. },
  199. "language_info": {
  200. "codemirror_mode": {
  201. "name": "ipython",
  202. "version": 3
  203. },
  204. "file_extension": ".py",
  205. "mimetype": "text/x-python",
  206. "name": "python",
  207. "nbconvert_exporter": "python",
  208. "pygments_lexer": "ipython3",
  209. "version": "3.6.8"
  210. }
  211. },
  212. "nbformat": 4,
  213. "nbformat_minor": 2
  214. }