@@ -35,6 +35,7 @@
 "import ray\n",
 "from ray import tune\n",
 "from ray.tune.schedulers import ASHAScheduler\n",
+"from ray.tune import CLIReporter\n",
 "\n",
 "import torch\n",
 "from torch import nn\n",
@@ -96,7 +97,6 @@
 " train_loader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=1)\n",
 " val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)\n",
 " \n",
-" \n",
 " conf = {\n",
 " 'capture_input_dim' : 112,\n",
 " 'score_input_dim' : 50,\n",
@@ -116,47 +116,43 @@
 " logger = TensorBoardLogger(save_dir='logs/', name=name)\n",
 " checkpoint = ModelCheckpoint(\n",
 " dirpath=save_dir, \n",
-" filename='{epoch}-{val_loss:.2f}', \n",
-" monitor='val_loss',\n",
-" mode='min',\n",
-" save_top_k=10, \n",
+" filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}', \n",
+" monitor='val_acc',\n",
+" mode='max',\n",
+" save_top_k=5, \n",
 " every_n_epochs = 5\n",
 " )\n",
-"# reporter = TuneReportCallback(\n",
-"# {\n",
-"# \"loss\": \"ptl/val_loss\",\n",
-"# \"mean_accuracy\": \"ptl/val_acc\"\n",
-"# },\n",
-"# on=\"validation_end\"\n",
-"# )\n",
 " os.makedirs(save_dir, exist_ok=True)\n",
 " json.dump(conf, open(save_dir + 'config.json', 'w'))\n",
 "\n",
 " trainer = Trainer(\n",
 " benchmark=True, \n",
-" gpus=[1], \n",
+" gpus=[0], \n",
 " accumulate_grad_batches=64,\n",
 " logger=logger, \n",
 " enable_progress_bar=False,\n",
-" max_epochs=10,\n",
-" callbacks=[checkpoint]\n",
+" max_epochs=20,\n",
+" callbacks=[\n",
+" checkpoint,\n",
+" ]\n",
 " )\n",
 " trainer.fit(model, train_loader, val_loader)\n",
-" res = trainer.validate(val_loader)[0]\n",
-" tune.report(**res)"
+" res = trainer.validate(model, val_loader)[0]\n",
+" tune.report(**res)\n",
+" return"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"scrolled": true
+"scrolled": false
 },
 "outputs": [
 {
 "data": {
 "text/html": [
-"== Status ==<br>Current time: 2022-03-09 01:53:36 (running for 00:01:52.55)<br>Memory usage on this node: 14.1/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 1.0/18 CPUs, 2.0/2 GPUs, 0.0/71.06 GiB heap, 0.0/34.44 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/experiment_2022-03-09_01-51-43<br>Number of trials: 16/16 (15 PENDING, 1 RUNNING)<br><br>"
+"== Status ==<br>Current time: 2022-03-11 21:39:55 (running for 00:09:12.46)<br>Memory usage on this node: 15.5/125.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 2.0/18 CPUs, 2.0/2 GPUs, 0.0/70.74 GiB heap, 0.0/34.31 GiB objects (0.0/1.0 accelerator_type:GTX)<br>Result logdir: /home/ramezani/ray_results/CSI-weibo<br>Number of trials: 20/20 (19 PENDING, 1 RUNNING)<br><br>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"
@@ -169,14 +165,15 @@
 "source": [
 "analysis = tune.run(\n",
 " experiment,\n",
-" num_samples=4,\n",
-" resources_per_trial={\"cpu\": 1, \"gpu\": 2},\n",
+" num_samples=5,\n",
+" resources_per_trial={\"cpu\": 2, \"gpu\": 2},\n",
 " verbose=1,\n",
 " config={\n",
 " \"weight_decay\": tune.grid_search([0., 0.1, 0.01, 0.001]),\n",
-" \"lr\": tune.loguniform(1e-5, 1e-1),\n",
+" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
 " \"dropout\": tune.uniform(0., 0.3)\n",
-" }\n",
+" },\n",
+" name='CSI-weibo'\n",
 ")\n"
 ]
 },
@@ -186,7 +183,16 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"df = analysis.results_df"
+"analysis.results_df"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"analysis.results_df.to_csv('results.csv', index=False)"
 ]
 }
 ],