Browse Source

add gc for clearing gpu cache

main
MahsaYazdani 2 years ago
parent
commit
59f597e94a
1 changed files with 3 additions and 0 deletions
  1. 3
    0
      predictor/cross_validation.py

+ 3
- 0
predictor/cross_validation.py View File

import pickle import pickle
import torch import torch
import torch.nn as nn import torch.nn as nn
import gc


from datetime import datetime from datetime import datetime


def cv(args, out_dir): def cv(args, out_dir):
torch.cuda.set_per_process_memory_fraction(0.6, 0) torch.cuda.set_per_process_memory_fraction(0.6, 0)
# Clear any cached memory # Clear any cached memory
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
save_args(args, os.path.join(out_dir, 'args.json')) save_args(args, os.path.join(out_dir, 'args.json'))
test_loss_file = os.path.join(out_dir, 'test_loss.pkl') test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss)) logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
logging.info("-" * n_delimiter) logging.info("-" * n_delimiter)
losses.append(inner_loss) losses.append(inner_loss)
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
time.sleep(10) time.sleep(10)
min_ls, min_idx = arg_min(losses) min_ls, min_idx = arg_min(losses)

Loading…
Cancel
Save