123456789101112131415161718192021222324252627 |
- import models
- import torch
- import os
-
# Adapted from ACSNet.
-
def generate_model(opt, ckpt_dir='./checkpoints/exp-cvcframe/'):
    """Build the network named by ``opt.model`` and optionally warm-start it.

    Args:
        opt: Options object. Attributes read here:
            ``model`` -- name of a constructor looked up on the project
            ``models`` module; ``nclasses`` -- passed to that constructor;
            ``use_gpu`` -- if truthy, move the model to CUDA and enable
            cuDNN benchmarking; ``load_ckpt`` -- checkpoint id, or ``None``
            to skip loading.
        ckpt_dir: Directory containing ``ck_<load_ckpt>.pth`` files.
            Defaults to the location that was previously hard-coded.

    Returns:
        The constructed model, with matching checkpoint weights loaded when
        ``opt.load_ckpt`` is set.

    Raises:
        FileNotFoundError: ``opt.load_ckpt`` is set but the checkpoint file
            does not exist.
    """
    model = getattr(models, opt.model)(opt.nclasses)
    if opt.use_gpu:
        model.cuda()
        torch.backends.cudnn.benchmark = True

    if opt.load_ckpt is not None:
        load_ckpt_path = os.path.join(ckpt_dir, 'ck_' + str(opt.load_ckpt) + '.pth')
        print(load_ckpt_path)
        # Explicit check instead of `assert`, which `python -O` strips out.
        if not os.path.isfile(load_ckpt_path):
            raise FileNotFoundError('No checkpoint found: ' + load_ckpt_path)
        print('Loading checkpoint......')
        # Deserialize onto CPU so checkpoints saved from GPU also load on
        # CPU-only runs; load_state_dict copies each tensor into the model's
        # existing parameters, wherever they live.
        checkpoint = torch.load(load_ckpt_path, map_location='cpu')
        model_dict = model.state_dict()
        # Partial load: keep only entries whose names exist in this model,
        # overlaying them on the model's current weights.
        matched = {k: v for k, v in checkpoint.items() if k in model_dict}
        skipped = len(checkpoint) - len(matched)
        if skipped:
            # Surface dropped keys instead of discarding them silently.
            print('Ignored {} checkpoint entries not present in the model.'.format(skipped))
        model_dict.update(matched)
        model.load_state_dict(model_dict)
        print('Done')

    return model
|