main.py

import os
import torch
from config import Config
from src.model import SaliencyNet
from torch.utils.data import DataLoader
from src.train import SalientFrameSamplerTrainer
from dataloaders.t2v_dataloader import VideoDescriptionDataloader


def load_weights(path, which_epoch):
    """Return the latest checkpoint in `path` (capped at `which_epoch`) and its epoch number."""
    weights, last_epoch = None, None
    available_models = [name for name in os.listdir(path) if name.endswith(".pt")]
    if available_models:
        # Checkpoints are named `epoch_{N}.pt`, so the epoch number sits between
        # the 6-character prefix and the `.pt` suffix.
        last_epoch = max([int(name[6:-3]) for name in available_models])
        last_epoch = min(last_epoch, which_epoch)
        weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt'))
    return weights, last_epoch


def set_seeds(seed: int):
    """Seed Python hashing and PyTorch (CPU and CUDA) for reproducible runs."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main(cfg):
    set_seeds(cfg.seed)

    model = SaliencyNet(cfg)
    model.to(cfg.device)

    # Either resume the full model, load only the pretrained backbone, or start from scratch.
    if cfg.load_complete_model:
        weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch)
        if weights:
            model.load_state_dict(weights)
            print(f'Complete {cfg.backbone} model trained on {cfg.similarity_matrix_loss} and '
                  f'{cfg.saliency_matching_loss} losses loaded from epoch #{last_epoch}')
    elif cfg.load_backbone:
        weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch)
        if weights:
            model.pretrained.load_state_dict(weights)
            print(f'{cfg.backbone} backbone trained on {cfg.similarity_matrix_loss} loss '
                  f'loaded from epoch #{last_epoch}')
    else:
        print(f'{cfg.backbone} backbone loaded from scratch.')

    train_dataset = VideoDescriptionDataloader(cfg, data_type='train')
    val_dataset = VideoDescriptionDataloader(cfg, data_type='valid')

    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers)

    trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg)
    trainer.train()


if __name__ == "__main__":
    # "spawn" is required so DataLoader worker processes can safely use CUDA.
    torch.multiprocessing.set_start_method("spawn")
    cfg = Config().args
    main(cfg)
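
main() touches only a small, fixed set of attributes on the Config object (seed, device, backbone, checkpoint paths and epochs, loss names, worker count). Below is a minimal sketch of a stand-in namespace that satisfies that interface, handy for exercising main() without the real argument parser; the attribute names mirror the accesses above, but every value shown is an illustrative assumption, not the repository's defaults.

# Hypothetical stand-in for Config().args, for smoke-testing main() in isolation.
# Values are placeholders; the real settings come from config.py / command-line args.
from types import SimpleNamespace

import torch

dummy_cfg = SimpleNamespace(
    seed=42,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    backbone='clip',                      # assumed backbone identifier
    load_complete_model=False,
    load_backbone=False,
    trained_model_path='checkpoints/model',
    trained_model_epoch=0,
    trained_backbone_path='checkpoints/backbone',
    trained_backbone_epoch=0,
    similarity_matrix_loss='similarity',  # assumed loss names (used here only in log messages)
    saliency_matching_loss='saliency',
    workers=2,
)
# main(dummy_cfg)  # would still need the dataset files expected by VideoDescriptionDataloader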