Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

checkpoint.py 1.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. from loguru import logger
  5. import torch
  6. import os
  7. import shutil
  8. def load_ckpt(model, ckpt):
  9. model_state_dict = model.state_dict()
  10. load_dict = {}
  11. for key_model, v in model_state_dict.items():
  12. if key_model not in ckpt:
  13. logger.warning(
  14. "{} is not in the ckpt. Please double check and see if this is desired.".format(
  15. key_model
  16. )
  17. )
  18. continue
  19. v_ckpt = ckpt[key_model]
  20. if v.shape != v_ckpt.shape:
  21. logger.warning(
  22. "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
  23. key_model, v_ckpt.shape, key_model, v.shape
  24. )
  25. )
  26. continue
  27. load_dict[key_model] = v_ckpt
  28. model.load_state_dict(load_dict, strict=False)
  29. return model
  30. def save_checkpoint(state, is_best, save_dir, model_name=""):
  31. if not os.path.exists(save_dir):
  32. os.makedirs(save_dir)
  33. filename = os.path.join(save_dir, model_name + "_ckpt.pth.tar")
  34. torch.save(state, filename)
  35. if is_best:
  36. best_filename = os.path.join(save_dir, "best_ckpt.pth.tar")
  37. shutil.copyfile(filename, best_filename)