import os

import torch
from torch.utils.data import DataLoader

from config import Config
from src.model import SaliencyNet
from src.train import SalientFrameSamplerTrainer
from dataloaders.t2v_dataloader import VideoDescriptionDataloader


def load_weights(path, which_epoch):
    """Load the most recent checkpoint in `path`, capped at `which_epoch`."""
    weights, last_epoch = None, None
    available_models = [name for name in os.listdir(path) if name.endswith(".pt")]
    if available_models:
        # Checkpoints are named `epoch_<n>.pt`; pick the latest one, but do not
        # go past the requested epoch.
        last_epoch = max(int(name[6:-3]) for name in available_models)
        last_epoch = min(last_epoch, which_epoch)
        weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt'))
    return weights, last_epoch


def set_seeds(seed: int):
    """Seed Python hashing and PyTorch (CPU and CUDA) for reproducible runs."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main(cfg):
    set_seeds(cfg.seed)

    model = SaliencyNet(cfg)
    model.to(cfg.device)

    # Optionally resume from a full-model checkpoint, a backbone-only
    # checkpoint, or start the backbone from scratch.
    if cfg.load_complete_model:
        weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch)
        if weights:
            model.load_state_dict(weights)
            print(f'Complete model - {cfg.backbone} trained on {cfg.similarity_matrix_loss} and '
                  f'{cfg.saliency_matching_loss} losses loaded from epoch #{last_epoch}')
    elif cfg.load_backbone:
        weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch)
        if weights:
            model.pretrained.load_state_dict(weights)
            print(f'{cfg.backbone} backbone trained on {cfg.similarity_matrix_loss} loss '
                  f'loaded from epoch #{last_epoch}')
    else:
        print(f'{cfg.backbone} backbone loaded from scratch.')

    train_dataset = VideoDescriptionDataloader(cfg, data_type='train')
    val_dataset = VideoDescriptionDataloader(cfg, data_type='valid')

    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers)

    trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg)
    trainer.train()


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    cfg = Config().args
    main(cfg)