|
|
@@ -35,6 +35,7 @@ |
|
|
|
"import ray\n", |
|
|
|
"from ray import tune\n", |
|
|
|
"from ray.tune.schedulers import ASHAScheduler\n", |
|
|
|
"from ray.tune import CLIReporter\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch import nn\n", |
|
|
@@ -96,7 +97,6 @@ |
|
|
|
" train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n", |
|
|
|
" val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n", |
|
|
|
" \n", |
|
|
|
" \n", |
|
|
|
" conf = {\n", |
|
|
|
" 'capture_input_dim' : 112,\n", |
|
|
|
" 'score_input_dim' : 50,\n", |
|
|
@@ -116,47 +116,43 @@ |
|
|
|
" logger = TensorBoardLogger(save_dir='logs/', name=name)\n", |
|
|
|
" checkpoint = ModelCheckpoint(\n", |
|
|
|
" dirpath=save_dir, \n", |
|
|
|
" filename='{epoch}-{val_loss:.2f}', \n", |
|
|
|
" monitor='val_loss',\n", |
|
|
|
" mode='min',\n", |
|
|
|
" save_top_k=10, \n", |
|
|
|
" filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}', \n", |
|
|
|
" monitor='val_acc',\n", |
|
|
|
" mode='max',\n", |
|
|
|
" save_top_k=5, \n", |
|
|
|
" every_n_epochs = 5\n", |
|
|
|
" )\n", |
|
|
|
"# reporter = TuneReportCallback(\n", |
|
|
|
"# {\n", |
|
|
|
"# \"loss\": \"ptl/val_loss\",\n", |
|
|
|
"# \"mean_accuracy\": \"ptl/val_acc\"\n", |
|
|
|
"# },\n", |
|
|
|
"# on=\"validation_end\"\n", |
|
|
|
"# )\n", |
|
|
|
" os.makedirs(save_dir, exist_ok=True)\n", |
|
|
|
" json.dump(conf, open(save_dir + 'config.json', 'w'))\n", |
|
|
|
"\n", |
|
|
|
" trainer = Trainer(\n", |
|
|
|
" benchmark=True, \n", |
|
|
|
" gpus=[1], \n", |
|
|
|
" gpus=[0], \n", |
|
|
|
" accumulate_grad_batches=64,\n", |
|
|
|
" logger=logger, \n", |
|
|
|
" enable_progress_bar=False,\n", |
|
|
|
" max_epochs=10,\n", |
|
|
|
" callbacks=[checkpoint]\n", |
|
|
|
" max_epochs=20,\n", |
|
|
|
" callbacks=[\n", |
|
|
|
" checkpoint,\n", |
|
|
|
" ]\n", |
|
|
|
" )\n", |
|
|
|
" trainer.fit(model, train_loader, val_loader)\n", |
|
|
|
" res = trainer.validate(val_loader)[0]\n", |
|
|
|
" tune.report(**res)" |
|
|
|
" res = trainer.validate(model, val_loader)[0]\n", |
|
|
|
" tune.report(**res)\n", |
|
|
|
" return" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": { |
|
|
|
"scrolled": true |
|
|
|
"scrolled": false |
|
|
|
}, |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"data": { |
|
|
|
"text/html": [ |
|
|
|
"== Status ==<br>Current time: 2022-03-09 01:53:36 (running for 00:01:52.55)<br>Memory usage on this node: 14.1/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 1.0/18 CPUs, 2.0/2 GPUs, 0.0/71.06 GiB heap, 0.0/34.44 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/experiment_2022-03-09_01-51-43<br>Number of trials: 16/16 (15 PENDING, 1 RUNNING)<br><br>" |
|
|
|
"== Status ==<br>Current time: 2022-03-11 21:39:55 (running for 00:09:12.46)<br>Memory usage on this node: 15.5/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 2.0/18 CPUs, 2.0/2 GPUs, 0.0/70.74 GiB heap, 0.0/34.31 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/CSI-weibo<br>Number of trials: 20/20 (19 PENDING, 1 RUNNING)<br><br>" |
|
|
|
], |
|
|
|
"text/plain": [ |
|
|
|
"<IPython.core.display.HTML object>" |
|
|
@@ -169,14 +165,15 @@ |
|
|
|
"source": [ |
|
|
|
"analysis = tune.run(\n", |
|
|
|
" experiment,\n", |
|
|
|
" num_samples=4,\n", |
|
|
|
" resources_per_trial={\"cpu\": 1, \"gpu\": 2},\n", |
|
|
|
" num_samples=5,\n", |
|
|
|
" resources_per_trial={\"cpu\": 2, \"gpu\": 2},\n", |
|
|
|
" verbose=1,\n", |
|
|
|
" config={\n", |
|
|
|
" \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n", |
|
|
|
" \"lr\": tune.loguniform(1e-5, 1e-1),\n", |
|
|
|
" \"lr\": tune.loguniform(1e-4, 1e-1),\n", |
|
|
|
" \"dropout\": tune.uniform(0., 0.3)\n", |
|
|
|
" }\n", |
|
|
|
" },\n", |
|
|
|
" name='CSI-weibo'\n", |
|
|
|
")\n" |
|
|
|
] |
|
|
|
}, |
|
|
@@ -186,7 +183,16 @@ |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"df = analysis.results_df" |
|
|
|
"analysis.results_df" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"execution_count": null, |
|
|
|
"metadata": {}, |
|
|
|
"outputs": [], |
|
|
|
"source": [ |
|
|
|
"analysis.results_df.to_csv('results.csv', index=False)" |
|
|
|
] |
|
|
|
} |
|
|
|
], |