import argparse
import os

import torch

# Root directory under which datasets, teacher outputs and saved models live.
# NOTE(review): placeholder value — point this at the real media location.
media_path_root = "path/to/media"


class Config:
    """Command-line configuration for training a student network.

    Builds an ``argparse`` parser, parses the arguments, then derives the
    torch device and all filesystem paths from the parsed values. The
    output directories (trained backbone / complete model / logs, and the
    parent directory of the ST-scores file) are created on construction.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before.
    """

    def __init__(self, argv=None):
        self.parser = argparse.ArgumentParser(description='Train a student network.')
        self.add_arguments()
        self.args = self.parse(argv)
        self.update_paths()

    def add_arguments(self):
        """Register every supported command-line option on ``self.parser``."""
        # General settings
        self.parser.add_argument('--seed', type=int, default=12345, help='Init Seed')
        self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
        self.parser.add_argument("--do_train", type=int, default=1, help="Whether to run training.")
        self.parser.add_argument("--do_eval", type=int, default=1, help="Whether to run evaluation.")
        self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use')
        self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use')
        self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
        # update_paths() uses this for the log directory; it was previously
        # referenced but never registered, which crashed every Config().
        self.parser.add_argument('--project_name', type=str, default=None,
                                 help='Project name used for the log directory '
                                      '(default: <dataset>_<teacher_model>_<backbone>)')

        # Model configurations
        self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model')
        self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer Hidden size')
        self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer Hidden size')
        self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer Hidden size')
        self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model')
        self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers')
        self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function')
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for text-to-video retrieval pipeline)')
        self.parser.add_argument('--init_tau', default=5.0, type=float, help="annealing init temperature")
        self.parser.add_argument('--min_tau', default=0.5, type=float, help="min temperature to anneal to")
        self.parser.add_argument('--decay_factor', default=0.045, type=float, help="exp decay factor per epoch")
        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='1: Load')
        self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Load the trained model from which epoch')
        self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just backbone')
        self.parser.add_argument('--load_backbone', type=int, default=1, help='1: Load')
        self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Load a trained backbone model from which epoch')

        # Dataloader configurations
        self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for dataloader')
        self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Binary percentage for dataloader')

        # Training configurations
        self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
        self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training')
        self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse', 'kl'], help='frame2frame loss')
        self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether use weighted BCE, to handle unbalanced classes')
        self.parser.add_argument('--alpha', type=float, default=1, help='Weight for f2f loss')
        self.parser.add_argument('--beta', type=float, default=1, help='Weight for f2t loss')
        self.parser.add_argument('--gamma', type=float, default=1, help='Weight for v2t loss')
        self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
        self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler')
        # NOTE(review): described as an epoch count but parsed as float —
        # confirm the scheduler consuming it accepts a non-integer step size.
        self.parser.add_argument('--scheduler_step_size', type=float, default=5, help='After how many epochs the learning rate will be adjusted')
        self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='The factor by which the learning rate will be reduced at each step.')
        self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it')
        self.parser.add_argument('--do_cluster', type=int, default=1, help='Calculate results with cluster as well or not.')

    def parse(self, argv=None):
        """Parse ``argv`` (or ``sys.argv[1:]`` when ``None``) into a Namespace."""
        return self.parser.parse_args(argv)

    def update_paths(self):
        """Resolve the torch device and derive all data/model/log paths.

        Replaces ``args.device`` (an int index) with a ``torch.device``, which
        falls back to CPU when CUDA is unavailable. It then sets the path
        attributes on ``self.args``, creates the output directories, and
        builds ``args.paths[split][kind]`` for the 'train' and 'valid' splits.
        """
        args = self.args
        args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else "cpu")

        args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{args.backbone}_pretrained.pth"
        args.st_scores_path = f"{media_path_root}/{args.dataset}/{args.teacher_model}/st_scores/st_scores.json"
        args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{args.backbone}/{args.similarity_matrix_loss}"
        args.trained_model_path = f"{args.trained_backbone_path}/complete_model/"

        # Fall back to a descriptive name when --project_name was not given
        # (getattr also covers Namespaces built without the option).
        if getattr(args, 'project_name', None) is None:
            args.project_name = f"{args.dataset}_{args.teacher_model}_{args.backbone}"
        args.log_dir = f'logs/{args.project_name}/'

        # st_scores_path names a JSON *file*: create only its parent directory,
        # otherwise a directory called "st_scores.json" would block the file.
        os.makedirs(os.path.dirname(args.st_scores_path), exist_ok=True)
        for directory in (args.trained_backbone_path, args.trained_model_path, args.log_dir):
            os.makedirs(directory, exist_ok=True)

        args.paths = {'train': {}, 'valid': {}}
        for data_split in ['train', 'valid']:
            # The validation split is stored on disk under the name 'val_1'.
            split = 'val_1' if data_split == 'valid' else data_split
            teacher_root = f"{media_path_root}/{args.dataset}/{args.teacher_model}"
            args.paths[data_split]['frames'] = f"{media_path_root}/{args.dataset}/sampled_frames/{split}"
            args.paths[data_split]['frame_features'] = f"{teacher_root}/embeddings/frames/{split}"
            args.paths[data_split]['text_features'] = f"{teacher_root}/embeddings/description/{split}"
            args.paths[data_split]['gt_scores'] = f"{teacher_root}/gt_scores/gt_{split}.json"