Browse Source

add gc for clearing gpu cache

main
MahsaYazdani 1 year ago
parent
commit
59f597e94a
1 changed files with 3 additions and 0 deletions
  1. 3
    0
      predictor/cross_validation.py

+ 3
- 0
predictor/cross_validation.py View File

@@ -5,6 +5,7 @@ import time
import pickle
import torch
import torch.nn as nn
import gc

from datetime import datetime

@@ -106,6 +107,7 @@ def create_model(data, hidden_size, gpu_id=None):
def cv(args, out_dir):
torch.cuda.set_per_process_memory_fraction(0.6, 0)
# Clear any cached memory
gc.collect()
torch.cuda.empty_cache()
save_args(args, os.path.join(out_dir, 'args.json'))
test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
@@ -153,6 +155,7 @@ def cv(args, out_dir):
logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
logging.info("-" * n_delimiter)
losses.append(inner_loss)
gc.collect()
torch.cuda.empty_cache()
time.sleep(10)
min_ls, min_idx = arg_min(losses)

Loading…
Cancel
Save