| import pickle | import pickle | ||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import gc | |||||
| from datetime import datetime | from datetime import datetime | ||||
| def cv(args, out_dir): | def cv(args, out_dir): | ||||
| torch.cuda.set_per_process_memory_fraction(0.6, 0) | torch.cuda.set_per_process_memory_fraction(0.6, 0) | ||||
| # Clear any cached memory | # Clear any cached memory | ||||
| gc.collect() | |||||
| torch.cuda.empty_cache() | torch.cuda.empty_cache() | ||||
| save_args(args, os.path.join(out_dir, 'args.json')) | save_args(args, os.path.join(out_dir, 'args.json')) | ||||
| test_loss_file = os.path.join(out_dir, 'test_loss.pkl') | test_loss_file = os.path.join(out_dir, 'test_loss.pkl') | ||||
| logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss)) | logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss)) | ||||
| logging.info("-" * n_delimiter) | logging.info("-" * n_delimiter) | ||||
| losses.append(inner_loss) | losses.append(inner_loss) | ||||
| gc.collect() | |||||
| torch.cuda.empty_cache() | torch.cuda.empty_cache() | ||||
| time.sleep(10) | time.sleep(10) | ||||
| min_ls, min_idx = arg_min(losses) | min_ls, min_idx = arg_min(losses) |