123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- import os
- import torch
- from config import Config
- from src.model import SaliencyNet
- from torch.utils.data import DataLoader
- from src.train import SalientFrameSamplerTrainer
- from dataloaders.t2v_dataloader import VideoDescriptionDataloader
-
def load_weights(path, which_epoch, map_location=None):
    """Load the newest ``epoch_<N>.pt`` checkpoint whose epoch is <= ``which_epoch``.

    Only files named ``epoch_<N>.pt`` (``N`` a non-negative integer) are
    considered; any other file in ``path`` -- including ``.pt`` files such as
    ``best.pt`` -- is ignored. If there is no checkpoint for exactly
    ``which_epoch`` (e.g. checkpoints were only saved every few epochs), the
    closest earlier existing checkpoint is loaded, rather than trying to open
    a file that does not exist.

    Args:
        path: Directory that holds the checkpoint files.
        which_epoch: Upper bound on the epoch to load. ``None`` selects the
            most recent checkpoint.
        map_location: Forwarded to ``torch.load``. The default ``None``
            keeps ``torch.load``'s default behaviour.

    Returns:
        ``(weights, epoch)`` where ``weights`` is the object returned by
        ``torch.load`` for the chosen file, or ``(None, None)`` when no
        matching checkpoint exists.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if ``path`` cannot be listed.
    """
    prefix, suffix = 'epoch_', '.pt'

    # Map epoch number -> actual filename, so zero-padded names
    # (e.g. "epoch_01.pt") are loaded by their real name.
    candidates = {}
    for name in os.listdir(path):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        number = name[len(prefix):-len(suffix)]
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if number.isdecimal():
            candidates[int(number)] = name

    if which_epoch is not None:
        candidates = {epoch: name for epoch, name in candidates.items() if epoch <= which_epoch}

    if not candidates:
        return None, None

    last_epoch = max(candidates)
    weights = torch.load(os.path.join(path, candidates[last_epoch]), map_location=map_location)
    return weights, last_epoch
-
def set_seeds(seed: int):
    """Seed Python's and torch's random number generators and make cuDNN deterministic.

    Args:
        seed: Seed applied to Python's ``random`` module and to torch on all
            devices.
    """
    import random  # local import keeps this function self-contained

    # Only affects interpreters started after this point (e.g. spawned
    # DataLoader workers); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (not just the current one); silently ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # NOTE(review): NumPy's global RNG is not seeded here; add
    # np.random.seed(seed) if the data pipeline draws from it.

    # Autotuning may choose different kernels between runs; turn it off and
    # ask cuDNN for deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
-
def main(cfg):
    """Build the saliency model, optionally restore saved weights, and train it.

    Args:
        cfg: Parsed configuration (``Config().args``). Fields read directly
            here: ``seed``, ``device``, ``backbone``, ``load_complete_model``,
            ``trained_model_path``, ``trained_model_epoch``, ``load_backbone``,
            ``trained_backbone_path``, ``trained_backbone_epoch``,
            ``similarity_matrix_loss``, ``saliency_matching_loss`` and
            ``workers``. The whole object is also passed to ``SaliencyNet``,
            ``VideoDescriptionDataloader`` and ``SalientFrameSamplerTrainer``.
    """
    set_seeds(cfg.seed)

    model = SaliencyNet(cfg)
    model.to(cfg.device)

    # Weight restoration: a complete-model checkpoint takes precedence over a
    # backbone-only checkpoint; at most one of the two is attempted.
    if cfg.load_complete_model:
        weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch)
        # `is not None`: a loaded state dict must not be judged by truthiness.
        if weights is not None:
            model.load_state_dict(weights)
            print(f'Complete model ({cfg.backbone}) trained on {cfg.similarity_matrix_loss} and {cfg.saliency_matching_loss} losses loaded from epoch #{last_epoch}')
        else:
            print(f'No complete-model checkpoint found in {cfg.trained_model_path}; continuing with the weights SaliencyNet was constructed with.')

    elif cfg.load_backbone:
        weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch)
        if weights is not None:
            # `model.pretrained` is the submodule these backbone checkpoints target.
            model.pretrained.load_state_dict(weights)
            print(f'{cfg.backbone} backbone trained on {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
        else:
            print(f'No backbone checkpoint found in {cfg.trained_backbone_path}; {cfg.backbone} backbone loaded from scratch.')

    else:
        print(f'{cfg.backbone} backbone loaded from scratch.')

    train_dataset = VideoDescriptionDataloader(cfg, data_type='train')
    val_dataset = VideoDescriptionDataloader(cfg, data_type='valid')

    # No batch_size is passed, so DataLoader uses its default of 1.
    # NOTE(review): presumably intentional (one video per step) -- confirm.
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers)

    trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg)
    trainer.train()
-
if __name__ == "__main__":
    # Select the "spawn" start method before anything else runs: DataLoader
    # workers (see main) are then started as fresh interpreters instead of
    # forked copies of this process. This may only be set once per program.
    torch.multiprocessing.set_start_method("spawn")
    # Parsed run configuration, bound at module level as `cfg`.
    cfg = Config().args

    main(cfg)
|