import argparse
import os

import torch

media_path_root = "path/to/media"

class Config:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description='Train a student network.')
        self.add_arguments()
        self.args = self.parse()
        self.update_paths()

    def add_arguments(self):
        # General settings
        self.parser.add_argument('--seed', type=int, default=12345, help='Initialization seed')
        self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
        self.parser.add_argument('--do_train', type=int, default=1, help='Whether to run training.')
        self.parser.add_argument('--do_eval', type=int, default=1, help='Whether to run evaluation.')
        self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use')
        self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use')
        self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
        # '--project_name' is referenced by update_paths() to build the log directory but was
        # never defined; added here with a placeholder default (assumed name).
        self.parser.add_argument('--project_name', type=str, default='student_distillation', help='Project name used to build the log directory')

        # Model configurations
        self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model')
        self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer hidden size')
        self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer hidden size')
        self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer hidden size')
        self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model')
        self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers')
        self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function')
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for the text-to-video retrieval pipeline)')
        self.parser.add_argument('--init_tau', type=float, default=5.0, help='Annealing init temperature')
        self.parser.add_argument('--min_tau', type=float, default=0.5, help='Min temperature to anneal to')
        self.parser.add_argument('--decay_factor', type=float, default=0.045, help='Exponential decay factor per epoch')

        # Checkpoint storing / loading
        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='1: load a previously trained complete model, 0: do not')
        self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Epoch of the trained model to load')

        self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just the backbone')
        self.parser.add_argument('--load_backbone', type=int, default=1, help='1: load a previously trained backbone, 0: do not')
        self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Epoch of the trained backbone to load')

        # Dataloader configurations
        self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for the dataloader')
        self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Binary percentage for the dataloader')

        # Training configurations
        self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
        self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training')
        self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse', 'kl'], help='Frame-to-frame loss')
        self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether to use weighted BCE to handle unbalanced classes')
        self.parser.add_argument('--alpha', type=float, default=1, help='Weight for the f2f loss')
        self.parser.add_argument('--beta', type=float, default=1, help='Weight for the f2t loss')
        self.parser.add_argument('--gamma', type=float, default=1, help='Weight for the v2t loss')
        self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
        self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler')
        self.parser.add_argument('--scheduler_step_size', type=int, default=5, help='Number of epochs between learning rate adjustments')
        self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='Factor by which the learning rate is reduced at each step')
        self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it')
        self.parser.add_argument('--do_cluster', type=int, default=1, help='Whether to also compute results with clustering')

    def parse(self):
        return self.parser.parse_args()

    def update_paths(self):
        self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else 'cpu')

        self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth"
        self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json"
        self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}"
        self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/"
        self.args.log_dir = f'logs/{self.args.project_name}/'

        # st_scores_path points to a JSON file, so create its parent directory rather than the path itself
        os.makedirs(os.path.dirname(self.args.st_scores_path), exist_ok=True)
        os.makedirs(self.args.trained_backbone_path, exist_ok=True)
        os.makedirs(self.args.trained_model_path, exist_ok=True)
        os.makedirs(self.args.log_dir, exist_ok=True)

        self.args.paths = {'train': {}, 'valid': {}}

        for data_split in ['train', 'valid']:
            split = 'val_1' if data_split == 'valid' else data_split
            self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}"
            self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}"
            self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}"
            self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json"
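

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module, added for illustration):
    # constructing Config() parses the command-line flags and resolves the device
    # and data paths that a training script would then consume, e.g.
    #   python config.py --dataset msrvtt --backbone efficientnet_b0 --epochs 30
    # (the script name is an assumption)
    config = Config()
    print(config.args.device)
    print(config.args.paths['train'])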