You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

config.py 6.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import argparse
  2. import os
  3. import torch
  4. media_path_root = "path/to/media"
  5. class Config:
  6. def __init__(self):
  7. self.parser = argparse.ArgumentParser(description='Train a student network.')
  8. self.add_arguments()
  9. self.args = self.parse()
  10. self.update_paths()
  11. def add_arguments(self):
  12. # General settings
  13. self.parser.add_argument('--seed', type=int, default=12345, help='Init Seed')
  14. self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
  15. self.parser.add_argument("--do_train", type=int, default=1, help="Whether to run training.")
  16. self.parser.add_argument("--do_eval", type=int, default=1, help="Whether to run evaluation.")
  17. self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use')
  18. self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use')
  19. self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
  20. # Model configurations
  21. self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model')
  22. self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer Hidden size')
  23. self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer Hidden size')
  24. self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer Hidden size')
  25. self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model')
  26. self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers')
  27. self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function')
  28. self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for text-to-video retireval pipeline)')
  29. self.parser.add_argument('--init_tau', default=5.0, type=float, help="annealing init temperature")
  30. self.parser.add_argument('--min_tau', default=0.5, type=float, help="min temperature to anneal to")
  31. self.parser.add_argument('--decay_factor', default=0.045, type=float, help="exp decay factor per epoch")
  32. self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the compelte model')
  33. self.parser.add_argument('--load_complete_model', type=int, default=0, help='1: Load')
  34. self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Load the trained model from which epoch')
  35. self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just backbone')
  36. self.parser.add_argument('--load_backbone', type=int, default=1, help='1: Load')
  37. self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Load a trained backbone model from which epoch')
  38. # Dataloader configurations
  39. self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for dataloader')
  40. self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Binary percentage for dataloader')
  41. # Training configurations
  42. self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
  43. self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training')
  44. self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse','kl'], help='frame2frame loss')
  45. self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether use weighted BCE, to handle unbalanced classes')
  46. self.parser.add_argument('--alpha', type=float, default=1, help='Weight for f2f loss')
  47. self.parser.add_argument('--beta', type=float, default=1, help='Weight for f2t loss')
  48. self.parser.add_argument('--gamma', type=float, default=1, help='Weight for v2t loss')
  49. self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
  50. self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler')
  51. self.parser.add_argument('--scheduler_step_size', type=float, default=5, help='After how many epochs the learning rate will be adjusted')
  52. self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='The factor by which the learning rate will be reduced at each step.')
  53. self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it')
  54. self.parser.add_argument('--do_cluster', type=int, default=1, help='Calculate results with cluster as well or not.')
  55. def parse(self):
  56. return self.parser.parse_args()
  57. def update_paths(self):
  58. self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else "cpu")
  59. self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth"
  60. self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json"
  61. self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}"
  62. self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/"
  63. self.args.log_dir = f'logs/{self.args.project_name}/'
  64. os.makedirs(self.args.st_scores_path, exist_ok=True)
  65. os.makedirs(self.args.trained_backbone_path, exist_ok=True)
  66. os.makedirs(self.args.trained_model_path, exist_ok=True)
  67. os.makedirs(self.args.log_dir, exist_ok=True)
  68. self.args.paths = {'train':{}, 'valid':{}}
  69. for data_split in ['train', 'valid']:
  70. split = 'val_1' if data_split=='valid' else data_split
  71. self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}"
  72. self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}"
  73. self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}"
  74. self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json"