You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import os
  2. import pickle
  3. import matplotlib.pyplot as plt
  4. import torch
  5. import json
  6. import random
  7. def calc_stat(numbers):
  8. mu = sum(numbers) / len(numbers)
  9. sigma = (sum([(x - mu) ** 2 for x in numbers]) / len(numbers)) ** 0.5
  10. return mu, sigma
  11. def conf_inv(mu, sigma, n):
  12. delta = 2.776 * sigma / (n ** 0.5) # 95%
  13. return mu - delta, mu + delta
  14. def arg_min(lst):
  15. m = float('inf')
  16. idx = 0
  17. for i, v in enumerate(lst):
  18. if v < m:
  19. m = v
  20. idx = i
  21. return m, idx
  22. def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int):
  23. save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  24. torch.save(state_dict, save_to)
  25. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  26. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  27. outdated = sorted(epochs, reverse=True)[keep:]
  28. for n in outdated:
  29. os.remove(os.path.join(model_dir, '{}.pkl'.format(n)))
  30. def find_best_model(model_dir: str):
  31. model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl']
  32. epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])]
  33. best_epoch = max(epochs)
  34. return os.path.join(model_dir, '{}.pkl'.format(best_epoch))
  35. def save_args(args, save_to: str):
  36. args_dict = args.__dict__
  37. with open(save_to, 'w') as f:
  38. json.dump(args_dict, f, indent=2)
  39. def read_map(map_file):
  40. d = {}
  41. with open(map_file, 'r') as f:
  42. f.readline()
  43. for line in f:
  44. k, v = line.rstrip().split('\t')
  45. d[k] = int(v)
  46. return d
  47. def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None):
  48. if train_rate is not None and (train_rate < 0 or train_rate > 1):
  49. raise ValueError("train rate should be in [0, 1], found {}".format(train_rate))
  50. elif test_rate is not None:
  51. if test_rate < 0 or test_rate > 1:
  52. raise ValueError("test rate should be in [0, 1], found {}".format(test_rate))
  53. train_rate = 1 - test_rate
  54. elif train_rate is None and test_rate is None:
  55. raise ValueError("Either train_rate or test_rate should be given.")
  56. evidence = list(range(n_samples))
  57. train_size = int(len(evidence) * train_rate)
  58. random.shuffle(evidence)
  59. train_indices = evidence[:train_size]
  60. test_indices = evidence[train_size:]
  61. return train_indices, test_indices