You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

comm.py 905B

123456789101112131415161718192021222324252627
  1. import models
  2. import torch
  3. import os
  4. #Adopted from the ACSNet
  5. def generate_model(opt):
  6. model = getattr(models, opt.model)(opt.nclasses)
  7. if opt.use_gpu:
  8. model.cuda()
  9. torch.backends.cudnn.benchmark = True
  10. if opt.load_ckpt is not None:
  11. model_dict = model.state_dict()
  12. #load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth')
  13. load_ckpt_path = os.path.join('./checkpoints/exp-cvcframe/', 'ck_'+ str(opt.load_ckpt) + '.pth')
  14. print(load_ckpt_path)
  15. assert os.path.isfile(load_ckpt_path), 'No checkpoint found.'
  16. print('Loading checkpoint......')
  17. checkpoint = torch.load(load_ckpt_path)
  18. new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()}
  19. model_dict.update(new_dict)
  20. model.load_state_dict(model_dict)
  21. print('Done')
  22. return model