|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Global seed set to 42\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "42"
- ]
- },
- "execution_count": 1,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import numpy as np\n",
- "import os\n",
- "from importlib import reload\n",
- "from copy import deepcopy\n",
- "import json\n",
- "import pandas as pd\n",
- "from tqdm.notebook import tqdm\n",
- "\n",
- "import ray\n",
- "from ray import tune\n",
- "from ray.tune.schedulers import ASHAScheduler\n",
- "from ray.tune import CLIReporter\n",
- "\n",
- "import torch\n",
- "from torch import nn\n",
- "import pytorch_lightning as pl\n",
- "\n",
- "from pytorch_lightning import Trainer, seed_everything\n",
- "from pytorch_lightning.loggers import TensorBoardLogger\n",
- "from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor\n",
- "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
- "from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback\n",
- "\n",
- "seed_everything(42)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "import src.data\n",
- "reload(src.data)\n",
- "from src.data import CSIDataset\n",
- "\n",
- "from torch.utils.data import DataLoader"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "import src.model\n",
- "reload(src.model)\n",
- "from src.model import CSIModel\n",
- "\n",
- "\n",
- "def experiment(args):\n",
- "    \"\"\"Train and validate one CSIModel configuration as a Ray Tune trial.\n",
- "\n",
- "    Args:\n",
- "        args: Trial config supplied by Ray Tune; must provide the keys\n",
- "            'lr', 'dropout' and 'weight_decay'.\n",
- "\n",
- "    Side effects:\n",
- "        Writes the model config to `config.json` and checkpoints under the\n",
- "        trial's weights directory, TensorBoard logs under 'logs/', and\n",
- "        reports the final validation metrics through `tune.report`.\n",
- "    \"\"\"\n",
- "    # Each trial runs in its own Ray worker process, so the seed set in the\n",
- "    # notebook kernel does not apply here; re-seed to keep trials reproducible.\n",
- "    seed_everything(42)\n",
- "\n",
- "    dataset = 'weibo'\n",
- "    path = f'/media/external_10TB/10TB/ramezani/Omranpour/assets/{dataset}/'\n",
- "\n",
- "    train_set = CSIDataset(pkl_dir=path + 'train/pkls/')\n",
- "    val_set = CSIDataset(pkl_dir=path + 'validation/pkls/')\n",
- "    # NOTE(review): the training loader is not shuffled; confirm this is intentional.\n",
- "    train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n",
- "    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n",
- "\n",
- "    # Architecture sizes are fixed; only lr, dropout and weight_decay are tuned.\n",
- "    conf = {\n",
- "        'capture_input_dim' : 112,\n",
- "        'score_input_dim' : 50,\n",
- "        'd_Wa': 100,\n",
- "        'd_lstm' : 50,\n",
- "        'd_Wr' : 100,\n",
- "        'd_Wu' : 100,\n",
- "        'd_Ws' : 1,\n",
- "        'lr': args['lr'],\n",
- "        'dropout' : args['dropout'],\n",
- "        'weight_decay' : args['weight_decay']\n",
- "    }\n",
- "    model = CSIModel(conf)\n",
- "\n",
- "    name = f\"dataset={dataset}-do={args['dropout']}-lr={args['lr']}-wd={args['weight_decay']}\"\n",
- "    save_dir = f'/media/external_10TB/10TB/ramezani/Omranpour/CSI/weights/{name}/'\n",
- "\n",
- "    # Store the exact config next to the checkpoints; the context manager\n",
- "    # guarantees the file is flushed and closed before training starts.\n",
- "    os.makedirs(save_dir, exist_ok=True)\n",
- "    with open(os.path.join(save_dir, 'config.json'), 'w') as config_file:\n",
- "        json.dump(conf, config_file)\n",
- "\n",
- "    logger = TensorBoardLogger(save_dir='logs/', name=name)\n",
- "    checkpoint = ModelCheckpoint(\n",
- "        dirpath=save_dir,\n",
- "        filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',\n",
- "        monitor='val_acc',\n",
- "        mode='max',\n",
- "        save_top_k=5,\n",
- "        every_n_epochs = 5\n",
- "    )\n",
- "\n",
- "    # batch_size=1 with accumulate_grad_batches=64 gives one optimizer step\n",
- "    # per 64 samples.\n",
- "    trainer = Trainer(\n",
- "        benchmark=True,\n",
- "        gpus=[0],\n",
- "        accumulate_grad_batches=64,\n",
- "        logger=logger,\n",
- "        enable_progress_bar=False,\n",
- "        max_epochs=20,\n",
- "        callbacks=[\n",
- "            checkpoint,\n",
- "        ]\n",
- "    )\n",
- "    trainer.fit(model, train_loader, val_loader)\n",
- "\n",
- "    # validate() returns one metrics dict per dataloader; report the only one.\n",
- "    res = trainer.validate(model, val_loader)[0]\n",
- "    tune.report(**res)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "== Status ==<br>Current time: 2022-03-11 21:39:55 (running for 00:09:12.46)<br>Memory usage on this node: 15.5/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 2.0/18 CPUs, 2.0/2 GPUs, 0.0/70.74 GiB heap, 0.0/34.31 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/CSI-weibo<br>Number of trials: 20/20 (19 PENDING, 1 RUNNING)<br><br>"
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# The weight_decay grid (4 values) is repeated num_samples=5 times, each\n",
- "# repeat drawing a fresh lr and dropout, for 20 trials in total.\n",
- "analysis = tune.run(\n",
- "    experiment,\n",
- "    num_samples=5,\n",
- "    # The Trainer inside `experiment` uses a single device (gpus=[0]), so reserve\n",
- "    # one GPU per trial; reserving two left a GPU idle and serialized the trials.\n",
- "    resources_per_trial={\"cpu\": 2, \"gpu\": 1},\n",
- "    verbose=1,\n",
- "    config={\n",
- "        \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n",
- "        \"lr\": tune.loguniform(1e-4, 1e-1),\n",
- "        \"dropout\": tune.uniform(0., 0.3)\n",
- "    },\n",
- "    name='CSI-weibo'\n",
- ")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis.results_df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Export the per-trial results table; index=False leaves the frame's index out of the CSV.\n",
- "analysis.results_df.to_csv(path_or_buf='results.csv', index=False)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.8"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|