|
|
@@ -0,0 +1,214 @@ |
|
|
|
{ |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 1, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"name": "stderr", |
|
|
|
"output_type": "stream", |
|
|
|
"text": [ |
|
|
|
"Global seed set to 42\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/plain": [ |
|
|
|
"42" |
|
|
|
] |
|
|
|
}, |
|
|
|
"execution_count": 1, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "execute_result" |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"import numpy as np\n", |
|
|
|
"import os\n", |
|
|
|
"from importlib import reload\n", |
|
|
|
"from copy import deepcopy\n", |
|
|
|
"import json\n", |
|
|
|
"import pandas as pd\n", |
|
|
|
"from tqdm.notebook import tqdm\n", |
|
|
|
"\n", |
|
|
|
"import ray\n", |
|
|
|
"from ray import tune\n", |
|
|
|
"from ray.tune.schedulers import ASHAScheduler\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch import nn\n", |
|
|
|
"import pytorch_lightning as pl\n", |
|
|
|
"\n", |
|
|
|
"from pytorch_lightning import Trainer, seed_everything\n", |
|
|
|
"from pytorch_lightning.loggers import TensorBoardLogger\n", |
|
|
|
"from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor\n", |
|
|
|
"from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", |
|
|
|
"from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback\n", |
|
|
|
"\n", |
|
|
|
"seed_everything(42)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## data" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 2, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import src.data\n", |
|
|
|
"reload(src.data)\n", |
|
|
|
"from src.data import CSIDataset\n", |
|
|
|
"\n", |
|
|
|
"from torch.utils.data import DataLoader" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": {}, |
|
|
|
"source": [ |
|
|
|
"## Model" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": 3, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"import src.model\n", |
|
|
|
"reload(src.model)\n", |
|
|
|
"from src.model import CSIModel\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def experiment(args):\n", |
|
|
|
" dataset = 'weibo'\n", |
|
|
|
" path = f'/media/external_10TB/10TB/ramezani/Omranpour/assets/{dataset}/'\n", |
|
|
|
" \n", |
|
|
|
" train_set = CSIDataset(pkl_dir=path + 'train/pkls/')\n", |
|
|
|
" val_set = CSIDataset(pkl_dir=path + 'validation/pkls/')\n", |
|
|
|
" train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n", |
|
|
|
" val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n", |
|
|
|
" \n", |
|
|
|
" \n", |
|
|
|
" conf = {\n", |
|
|
|
" 'capture_input_dim' : 112,\n", |
|
|
|
" 'score_input_dim' : 50,\n", |
|
|
|
" 'd_Wa': 100,\n", |
|
|
|
" 'd_lstm' : 50,\n", |
|
|
|
" 'd_Wr' : 100,\n", |
|
|
|
" 'd_Wu' : 100,\n", |
|
|
|
" 'd_Ws' : 1,\n", |
|
|
|
" 'lr': args['lr'],\n", |
|
|
|
" 'dropout' : args['dropout'],\n", |
|
|
|
" 'weight_decay' : args['weight_decay']\n", |
|
|
|
" }\n", |
|
|
|
" model = CSIModel(conf)\n", |
|
|
|
"\n", |
|
|
|
" name = f\"dataset={dataset}-do={args['dropout']}-lr={args['lr']}-wd={args['weight_decay']}\"\n", |
|
|
|
" save_dir = f'/media/external_10TB/10TB/ramezani/Omranpour/CSI/weights/{name}/'\n", |
|
|
|
" logger = TensorBoardLogger(save_dir='logs/', name=name)\n", |
|
|
|
" checkpoint = ModelCheckpoint(\n", |
|
|
|
" dirpath=save_dir, \n", |
|
|
|
" filename='{epoch}-{val_loss:.2f}', \n", |
|
|
|
" monitor='val_loss',\n", |
|
|
|
" mode='min',\n", |
|
|
|
" save_top_k=10, \n", |
|
|
|
" every_n_epochs = 5\n", |
|
|
|
" )\n", |
|
|
|
"# reporter = TuneReportCallback(\n", |
|
|
|
"# {\n", |
|
|
|
"# \"loss\": \"ptl/val_loss\",\n", |
|
|
|
"# \"mean_accuracy\": \"ptl/val_acc\"\n", |
|
|
|
"# },\n", |
|
|
|
"# on=\"validation_end\"\n", |
|
|
|
"# )\n", |
|
|
|
" os.makedirs(save_dir, exist_ok=True)\n", |
|
|
|
" json.dump(conf, open(save_dir + 'config.json', 'w'))\n", |
|
|
|
"\n", |
|
|
|
" trainer = Trainer(\n", |
|
|
|
" benchmark=True, \n", |
|
|
|
" gpus=[1], \n", |
|
|
|
" accumulate_grad_batches=64,\n", |
|
|
|
" logger=logger, \n", |
|
|
|
" enable_progress_bar=False,\n", |
|
|
|
" max_epochs=10,\n", |
|
|
|
" callbacks=[checkpoint]\n", |
|
|
|
" )\n", |
|
|
|
" trainer.fit(model, train_loader, val_loader)\n", |
|
|
|
" res = trainer.validate(val_loader)[0]\n", |
|
|
|
" tune.report(**res)" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"scrolled": true |
|
|
|
}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/html": [ |
|
|
|
"== Status ==<br>Current time: 2022-03-09 01:53:36 (running for 00:01:52.55)<br>Memory usage on this node: 14.1/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 1.0/18 CPUs, 2.0/2 GPUs, 0.0/71.06 GiB heap, 0.0/34.44 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/experiment_2022-03-09_01-51-43<br>Number of trials: 16/16 (15 PENDING, 1 RUNNING)<br><br>" |
|
|
|
], |
|
|
|
"text/plain": [ |
|
|
|
"<IPython.core.display.HTML object>" |
|
|
|
] |
|
|
|
}, |
|
|
|
"metadata": {}, |
|
|
|
"output_type": "display_data" |
|
|
|
} |
|
|
|
], |
|
|
|
"source": [ |
|
|
|
"analysis = tune.run(\n", |
|
|
|
" experiment,\n", |
|
|
|
" num_samples=4,\n", |
|
|
|
" resources_per_trial={\"cpu\": 1, \"gpu\": 2},\n", |
|
|
|
" verbose=1,\n", |
|
|
|
" config={\n", |
|
|
|
" \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n", |
|
|
|
" \"lr\": tune.loguniform(1e-5, 1e-1),\n", |
|
|
|
" \"dropout\": tune.uniform(0., 0.3)\n", |
|
|
|
" }\n", |
|
|
|
")\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"df = analysis.results_df" |
|
|
|
] |
|
|
|
} |
|
|
|
], |
|
|
|
"metadata": { |
|
|
|
"kernelspec": { |
|
|
|
"display_name": "Python 3", |
|
|
|
"language": "python", |
|
|
|
"name": "python3" |
|
|
|
}, |
|
|
|
"language_info": { |
|
|
|
"codemirror_mode": { |
|
|
|
"name": "ipython", |
|
|
|
"version": 3 |
|
|
|
}, |
|
|
|
"file_extension": ".py", |
|
|
|
"mimetype": "text/x-python", |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.6.8" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 2 |
|
|
|
} |