You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

_model_loading.py 2.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from os import path
  2. from typing import TYPE_CHECKING
  3. import torch
  4. from ..models.model import Model
  5. if TYPE_CHECKING:
  6. from ..configs.base_config import BaseConfig
  7. def load_model(the_model: Model, conf: 'BaseConfig') -> Model:
  8. dev_name = conf.dev_name
  9. if conf.final_model_dir is not None:
  10. load_dir = conf.final_model_dir
  11. if not path.exists(load_dir):
  12. raise Exception('No path %s' % load_dir)
  13. else:
  14. load_dir = conf.save_dir
  15. print('>>> loading from: ' + load_dir, flush=True)
  16. if not path.exists(load_dir):
  17. raise Exception('Problem in loading the model. %s does not exist!' % load_dir)
  18. map_location = None
  19. if dev_name == 'cpu':
  20. map_location = 'cpu'
  21. elif dev_name is not None and ':' in dev_name:
  22. map_location = dev_name
  23. the_model.to(conf.device)
  24. if path.isfile(load_dir):
  25. if not path.exists(load_dir):
  26. raise Exception('Problem in loading the model. %s does not exist!' % load_dir)
  27. the_model.load_state_dict(torch.load(load_dir, map_location=map_location))
  28. print('Loaded the model at %s' % load_dir, flush=True)
  29. else:
  30. if conf.epoch is not None:
  31. epoch = int(conf.epoch)
  32. elif path.exists(load_dir + '/GeneralInfo'):
  33. checkpoint = torch.load(load_dir + '/GeneralInfo', map_location=map_location)
  34. epoch = checkpoint['best_val_epoch']
  35. print(f'Epoch has not been specified in the config, '
  36. f'using {epoch} which is best_val_epoch instead')
  37. else:
  38. raise Exception('Either epoch or pt dir (final_model) must be given as GeneralInfo was not found')
  39. if not path.exists('%s/%d' % (load_dir, epoch)):
  40. raise Exception('Problem in loading the model. %s/%d does not exist!' %
  41. (load_dir, epoch))
  42. checkpoint = torch.load('%s/%d' % (load_dir, epoch),
  43. map_location=map_location)
  44. # backward compatibility
  45. if 'model_state_dict' in checkpoint:
  46. checkpoint = checkpoint['model_state_dict']
  47. the_model.load_state_dict(checkpoint)
  48. print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True)
  49. the_model.to(conf.device)
  50. return the_model