from os import path
from typing import TYPE_CHECKING

import torch

from ..models.model import Model

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


def load_model(the_model: Model, conf: 'BaseConfig') -> Model:
    """Loads the weights of `the_model` from the location specified in `conf`."""
    dev_name = conf.dev_name

    # Resolve the path to load from: an explicit final model path takes
    # precedence over the default save directory.
    if conf.final_model_dir is not None:
        load_dir = conf.final_model_dir
        if not path.exists(load_dir):
            raise Exception('No path %s' % load_dir)
    else:
        load_dir = conf.save_dir

    print('>>> loading from: ' + load_dir, flush=True)
    if not path.exists(load_dir):
        raise Exception('Problem in loading the model. %s does not exist!' % load_dir)

    # Remap checkpoint tensors onto the target device while loading.
    map_location = None
    if dev_name == 'cpu':
        map_location = 'cpu'
    elif dev_name is not None and ':' in dev_name:
        map_location = dev_name

    the_model.to(conf.device)

    if path.isfile(load_dir):
        # load_dir points directly at a state-dict file.
        the_model.load_state_dict(torch.load(load_dir, map_location=map_location))
        print('Loaded the model at %s' % load_dir, flush=True)
    else:
        # load_dir is a checkpoint directory; decide which epoch to load.
        if conf.epoch is not None:
            epoch = int(conf.epoch)
        elif path.exists(path.join(load_dir, 'GeneralInfo')):
            checkpoint = torch.load(path.join(load_dir, 'GeneralInfo'), map_location=map_location)
            epoch = checkpoint['best_val_epoch']
            print(f'Epoch has not been specified in the config; '
                  f'using best_val_epoch ({epoch}) instead')
        else:
            raise Exception('Either epoch or the final model path (final_model_dir) '
                            'must be given, as GeneralInfo was not found')

        epoch_path = path.join(load_dir, str(epoch))
        if not path.exists(epoch_path):
            raise Exception('Problem in loading the model. %s does not exist!' % epoch_path)
        checkpoint = torch.load(epoch_path, map_location=map_location)

        # Backward compatibility: newer checkpoints wrap the weights under
        # 'model_state_dict'; older ones store the state dict directly.
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']

        the_model.load_state_dict(checkpoint)
        print(f'Loaded the model from epoch {epoch} of {load_dir}', flush=True)

    the_model.to(conf.device)
    return the_model
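
# Example usage (a minimal sketch; `SomeConfig` and `SomeModel` are hypothetical
# placeholders for a concrete BaseConfig subclass and Model subclass in this repo):
#
#     conf = SomeConfig()
#     the_model = SomeModel(conf)
#     the_model = load_model(the_model, conf)  # moves the model to conf.device and loads weights
#     the_model.eval()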