Browse Source

fix bug

master
Pooya Moini 2 years ago
parent
commit
05a9bfbb51
1 changed files with 30 additions and 24 deletions
  1. 30
    24
      train.ipynb

+ 30
- 24
train.ipynb View File

"import ray\n", "import ray\n",
"from ray import tune\n", "from ray import tune\n",
"from ray.tune.schedulers import ASHAScheduler\n", "from ray.tune.schedulers import ASHAScheduler\n",
"from ray.tune import CLIReporter\n",
"\n", "\n",
"import torch\n", "import torch\n",
"from torch import nn\n", "from torch import nn\n",
" train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n", " train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n",
" val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n", " val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n",
" \n", " \n",
" \n",
" conf = {\n", " conf = {\n",
" 'capture_input_dim' : 112,\n", " 'capture_input_dim' : 112,\n",
" 'score_input_dim' : 50,\n", " 'score_input_dim' : 50,\n",
" logger = TensorBoardLogger(save_dir='logs/', name=name)\n", " logger = TensorBoardLogger(save_dir='logs/', name=name)\n",
" checkpoint = ModelCheckpoint(\n", " checkpoint = ModelCheckpoint(\n",
" dirpath=save_dir, \n", " dirpath=save_dir, \n",
" filename='{epoch}-{val_loss:.2f}', \n",
" monitor='val_loss',\n",
" mode='min',\n",
" save_top_k=10, \n",
" filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}', \n",
" monitor='val_acc',\n",
" mode='max',\n",
" save_top_k=5, \n",
" every_n_epochs = 5\n", " every_n_epochs = 5\n",
" )\n", " )\n",
"# reporter = TuneReportCallback(\n",
"# {\n",
"# \"loss\": \"ptl/val_loss\",\n",
"# \"mean_accuracy\": \"ptl/val_acc\"\n",
"# },\n",
"# on=\"validation_end\"\n",
"# )\n",
" os.makedirs(save_dir, exist_ok=True)\n", " os.makedirs(save_dir, exist_ok=True)\n",
" json.dump(conf, open(save_dir + 'config.json', 'w'))\n", " json.dump(conf, open(save_dir + 'config.json', 'w'))\n",
"\n", "\n",
" trainer = Trainer(\n", " trainer = Trainer(\n",
" benchmark=True, \n", " benchmark=True, \n",
" gpus=[1], \n",
" gpus=[0], \n",
" accumulate_grad_batches=64,\n", " accumulate_grad_batches=64,\n",
" logger=logger, \n", " logger=logger, \n",
" enable_progress_bar=False,\n", " enable_progress_bar=False,\n",
" max_epochs=10,\n",
" callbacks=[checkpoint]\n",
" max_epochs=20,\n",
" callbacks=[\n",
" checkpoint,\n",
" ]\n",
" )\n", " )\n",
" trainer.fit(model, train_loader, val_loader)\n", " trainer.fit(model, train_loader, val_loader)\n",
" res = trainer.validate(val_loader)[0]\n",
" tune.report(**res)"
" res = trainer.validate(model, val_loader)[0]\n",
" tune.report(**res)\n",
" return"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"scrolled": true
"scrolled": false
}, },
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/html": [ "text/html": [
"== Status ==<br>Current time: 2022-03-09 01:53:36 (running for 00:01:52.55)<br>Memory usage on this node: 14.1/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 1.0/18 CPUs, 2.0/2 GPUs, 0.0/71.06 GiB heap, 0.0/34.44 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/experiment_2022-03-09_01-51-43<br>Number of trials: 16/16 (15 PENDING, 1 RUNNING)<br><br>"
"== Status ==<br>Current time: 2022-03-11 21:39:55 (running for 00:09:12.46)<br>Memory usage on this node: 15.5/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 2.0/18 CPUs, 2.0/2 GPUs, 0.0/70.74 GiB heap, 0.0/34.31 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/CSI-weibo<br>Number of trials: 20/20 (19 PENDING, 1 RUNNING)<br><br>"
], ],
"text/plain": [ "text/plain": [
"<IPython.core.display.HTML object>" "<IPython.core.display.HTML object>"
"source": [ "source": [
"analysis = tune.run(\n", "analysis = tune.run(\n",
" experiment,\n", " experiment,\n",
" num_samples=4,\n",
" resources_per_trial={\"cpu\": 1, \"gpu\": 2},\n",
" num_samples=5,\n",
" resources_per_trial={\"cpu\": 2, \"gpu\": 2},\n",
" verbose=1,\n", " verbose=1,\n",
" config={\n", " config={\n",
" \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n", " \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n",
" \"lr\": tune.loguniform(1e-5, 1e-1),\n",
" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
" \"dropout\": tune.uniform(0., 0.3)\n", " \"dropout\": tune.uniform(0., 0.3)\n",
" }\n",
" },\n",
" name='CSI-weibo'\n",
")\n" ")\n"
] ]
}, },
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"df = analysis.results_df"
"analysis.results_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis.results_df.to_csv('results.csv', index=False)"
] ]
} }
], ],

Loading…
Cancel
Save