Browse Source

fix bug

master
Pooya Moini 2 years ago
parent
commit
05a9bfbb51
1 changed files with 30 additions and 24 deletions
  1. 30
    24
      train.ipynb

+ 30
- 24
train.ipynb View File

@@ -35,6 +35,7 @@
"import ray\n",
"from ray import tune\n",
"from ray.tune.schedulers import ASHAScheduler\n",
"from ray.tune import CLIReporter\n",
"\n",
"import torch\n",
"from torch import nn\n",
@@ -96,7 +97,6 @@
" train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n",
" val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n",
" \n",
" \n",
" conf = {\n",
" 'capture_input_dim' : 112,\n",
" 'score_input_dim' : 50,\n",
@@ -116,47 +116,43 @@
" logger = TensorBoardLogger(save_dir='logs/', name=name)\n",
" checkpoint = ModelCheckpoint(\n",
" dirpath=save_dir, \n",
" filename='{epoch}-{val_loss:.2f}', \n",
" monitor='val_loss',\n",
" mode='min',\n",
" save_top_k=10, \n",
" filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}', \n",
" monitor='val_acc',\n",
" mode='max',\n",
" save_top_k=5, \n",
" every_n_epochs = 5\n",
" )\n",
"# reporter = TuneReportCallback(\n",
"# {\n",
"# \"loss\": \"ptl/val_loss\",\n",
"# \"mean_accuracy\": \"ptl/val_acc\"\n",
"# },\n",
"# on=\"validation_end\"\n",
"# )\n",
" os.makedirs(save_dir, exist_ok=True)\n",
" json.dump(conf, open(save_dir + 'config.json', 'w'))\n",
"\n",
" trainer = Trainer(\n",
" benchmark=True, \n",
" gpus=[1], \n",
" gpus=[0], \n",
" accumulate_grad_batches=64,\n",
" logger=logger, \n",
" enable_progress_bar=False,\n",
" max_epochs=10,\n",
" callbacks=[checkpoint]\n",
" max_epochs=20,\n",
" callbacks=[\n",
" checkpoint,\n",
" ]\n",
" )\n",
" trainer.fit(model, train_loader, val_loader)\n",
" res = trainer.validate(val_loader)[0]\n",
" tune.report(**res)"
" res = trainer.validate(model, val_loader)[0]\n",
" tune.report(**res)\n",
" return"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Current time: 2022-03-09 01:53:36 (running for 00:01:52.55)<br>Memory usage on this node: 14.1/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 1.0/18 CPUs, 2.0/2 GPUs, 0.0/71.06 GiB heap, 0.0/34.44 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/experiment_2022-03-09_01-51-43<br>Number of trials: 16/16 (15 PENDING, 1 RUNNING)<br><br>"
"== Status ==<br>Current time: 2022-03-11 21:39:55 (running for 00:09:12.46)<br>Memory usage on this node: 15.5/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 2.0/18 CPUs, 2.0/2 GPUs, 0.0/70.74 GiB heap, 0.0/34.31 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/CSI-weibo<br>Number of trials: 20/20 (19 PENDING, 1 RUNNING)<br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -169,14 +165,15 @@
"source": [
"analysis = tune.run(\n",
" experiment,\n",
" num_samples=4,\n",
" resources_per_trial={\"cpu\": 1, \"gpu\": 2},\n",
" num_samples=5,\n",
" resources_per_trial={\"cpu\": 2, \"gpu\": 2},\n",
" verbose=1,\n",
" config={\n",
" \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n",
" \"lr\": tune.loguniform(1e-5, 1e-1),\n",
" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
" \"dropout\": tune.uniform(0., 0.3)\n",
" }\n",
" },\n",
" name='CSI-weibo'\n",
")\n"
]
},
@@ -186,7 +183,16 @@
"metadata": {},
"outputs": [],
"source": [
"df = analysis.results_df"
"analysis.results_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis.results_df.to_csv('results.csv', index=False)"
]
}
],

Loading…
Cancel
Save