You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.2KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import os
  2. import pickle
  3. import matplotlib.pyplot as plt
  4. import torch
  5. import json
  6. import random
  7. def arg_min(lst):
  8. m = float('inf')
  9. idx = 0
  10. for i, v in enumerate(lst):
  11. if v < m:
  12. m = v
  13. idx = i
  14. return m, idx
  15. def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int):
  16. save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  17. torch.save(state_dict, save_to)
  18. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  19. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  20. outdated = sorted(epochs, reverse=True)[keep:]
  21. for n in outdated:
  22. os.remove(os.path.join(model_dir, '{}.pkl'.format(n)))
  23. def find_best_model(model_dir: str):
  24. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  25. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  26. best_epoch = max(epochs)
  27. return os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  28. def save_args(args, save_to: str):
  29. args_dict = args.__dict__
  30. with open(save_to, 'w') as f:
  31. json.dump(args_dict, f, indent=2)
  32. def read_map(map_file):
  33. d = {}
  34. with open(map_file, 'r') as f:
  35. f.readline()
  36. for line in f:
  37. k, v = line.rstrip().split('\t')
  38. d[k] = int(v)
  39. return d
  40. def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None):
  41. if train_rate is not None and (train_rate < 0 or train_rate > 1):
  42. raise ValueError("train rate should be in [0, 1], found {}".format(train_rate))
  43. elif test_rate is not None:
  44. if test_rate < 0 or test_rate > 1:
  45. raise ValueError("test rate should be in [0, 1], found {}".format(test_rate))
  46. train_rate = 1 - test_rate
  47. elif train_rate is None and test_rate is None:
  48. raise ValueError("Either train_rate or test_rate should be given.")
  49. evidence = list(range(n_samples))
  50. train_size = int(len(evidence) * train_rate)
  51. random.shuffle(evidence)
  52. train_indices = evidence[:train_size]
  53. test_indices = evidence[train_size:]
  54. return train_indices, test_indices