|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- from os import path
- from typing import TYPE_CHECKING
- import torch
- from ..models.model import Model
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
-
-
def _resolve_load_dir(conf: 'BaseConfig'):
    """Return ``conf.final_model_dir`` when it is set, otherwise ``conf.save_dir``.

    Raises:
        FileNotFoundError: ``conf.final_model_dir`` is set but does not exist.
    """
    if conf.final_model_dir is None:
        return conf.save_dir
    load_dir = conf.final_model_dir
    if not path.exists(load_dir):
        raise FileNotFoundError('No path %s' % load_dir)
    return load_dir


def _map_location(dev_name):
    """Translate a device name into a ``torch.load`` ``map_location``.

    ``'cpu'`` maps to ``'cpu'``; a name containing ``':'`` (an explicit device
    index) is passed through unchanged; anything else, including ``None``,
    maps to ``None`` so ``torch.load`` uses its default placement.
    """
    if dev_name == 'cpu':
        return 'cpu'
    if dev_name is not None and ':' in dev_name:
        return dev_name
    return None


def _unwrap_state_dict(checkpoint):
    """Return the model state dict held by ``checkpoint``.

    Checkpoints may be either a bare state dict or a dict that wraps it under
    the ``'model_state_dict'`` key; both forms are accepted.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


def _resolve_epoch(conf: 'BaseConfig', load_dir, map_location):
    """Choose which epoch checkpoint inside ``load_dir`` to load.

    Uses ``int(conf.epoch)`` when ``conf.epoch`` is set; otherwise reads the
    ``'best_val_epoch'`` entry of ``<load_dir>/GeneralInfo``.

    Raises:
        ValueError: neither ``conf.epoch`` nor ``GeneralInfo`` is available.
    """
    if conf.epoch is not None:
        return int(conf.epoch)

    info_path = path.join(load_dir, 'GeneralInfo')
    if not path.exists(info_path):
        raise ValueError('Either epoch or pt dir (final_model) must be given '
                         'as GeneralInfo was not found')

    general_info = torch.load(info_path, map_location=map_location)
    epoch = general_info['best_val_epoch']
    print(f'Epoch has not been specified in the config, '
          f'using {epoch} which is best_val_epoch instead', flush=True)
    return epoch


def load_model(the_model: 'Model', conf: 'BaseConfig') -> 'Model':
    """Load weights into ``the_model`` and move it to ``conf.device``.

    The source is ``conf.final_model_dir`` if set, otherwise ``conf.save_dir``:

    * If that path is a file, it is loaded as a state dict (optionally wrapped
      under ``'model_state_dict'``).
    * Otherwise it is treated as a directory holding one checkpoint per epoch,
      named by the epoch number. The epoch comes from ``conf.epoch`` or, when
      that is ``None``, from ``best_val_epoch`` in ``<dir>/GeneralInfo``.

    Args:
        the_model: Model whose parameters are overwritten in place.
        conf: Config providing ``dev_name``, ``device``, ``final_model_dir``,
            ``save_dir`` and ``epoch``.

    Returns:
        The same ``the_model`` instance, loaded and placed on ``conf.device``.

    Raises:
        FileNotFoundError: the load path or the selected epoch file is missing.
        ValueError: loading from a directory with no ``conf.epoch`` and no
            ``GeneralInfo`` file.
    """
    load_dir = _resolve_load_dir(conf)

    print(f'>>> loading from: {load_dir}', flush=True)
    if not path.exists(load_dir):
        raise FileNotFoundError(
            'Problem in loading the model. %s does not exist!' % load_dir)

    map_location = _map_location(conf.dev_name)

    # Place the model first so the state dict is copied into parameters that
    # already live on the target device.
    the_model.to(conf.device)

    # NOTE: torch.load unpickles its input; only use trusted checkpoint files.
    if path.isfile(load_dir):
        checkpoint = torch.load(load_dir, map_location=map_location)
        the_model.load_state_dict(_unwrap_state_dict(checkpoint))
        print('Loaded the model at %s' % load_dir, flush=True)
    else:
        epoch = _resolve_epoch(conf, load_dir, map_location)
        # '%d' keeps the original integer formatting of the epoch file name.
        ckpt_path = path.join(load_dir, '%d' % epoch)
        if not path.exists(ckpt_path):
            raise FileNotFoundError(
                'Problem in loading the model. %s/%d does not exist!'
                % (load_dir, epoch))
        checkpoint = torch.load(ckpt_path, map_location=map_location)
        the_model.load_state_dict(_unwrap_state_dict(checkpoint))
        print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True)

    # Ensure final placement on the configured device after loading.
    the_model.to(conf.device)

    return the_model
|