first commit

main
Parsa Haghighi Naeini, 2 months ago
commit bcb96dfebb
46 changed files with 417765 additions and 0 deletions
  1. config.py  (+89, -0)
  2. dataloaders/__init__.py  (+0, -0)
  3. dataloaders/__pycache__/__init__.cpython-310.pyc  (BIN)
  4. dataloaders/__pycache__/__init__.cpython-36.pyc  (BIN)
  5. dataloaders/__pycache__/__init__.cpython-37.pyc  (BIN)
  6. dataloaders/__pycache__/__init__.cpython-38.pyc  (BIN)
  7. dataloaders/__pycache__/dataloader_activitynet.cpython-38.pyc  (BIN)
  8. dataloaders/__pycache__/dataloader_activitynet_for_BCE.cpython-310.pyc  (BIN)
  9. dataloaders/__pycache__/dataloader_activitynet_for_BCE.cpython-37.pyc  (BIN)
  10. dataloaders/__pycache__/dataloader_activitynet_for_BCE.cpython-38.pyc  (BIN)
  11. dataloaders/__pycache__/rawvideo_util.cpython-36.pyc  (BIN)
  12. dataloaders/__pycache__/rawvideo_util.cpython-37.pyc  (BIN)
  13. dataloaders/rawvideo_util.py  (+62, -0)
  14. dataloaders/t2v_dataloader.py  (+65, -0)
  15. datasets/activity-net/test_ids.json  (+1, -0)
  16. datasets/activity-net/train.json  (+257170, -0)
  17. datasets/activity-net/train_ids.json  (+1, -0)
  18. datasets/activity-net/val_1.json  (+121946, -0)
  19. datasets/activity-net/val_ids.json  (+1, -0)
  20. datasets/msrvtt/MSRVTT_JSFUSION_test.csv  (+1001, -0)
  21. datasets/msrvtt/MSRVTT_data.json  (+1, -0)
  22. datasets/msrvtt/MSRVTT_train.9k.csv  (+9001, -0)
  23. datasets/msrvtt/test_list_full.txt  (+2990, -0)
  24. datasets/msrvtt/test_list_miech.txt  (+1000, -0)
  25. datasets/msrvtt/train_list_full.txt  (+6513, -0)
  26. datasets/msrvtt/train_list_jsfusion.txt  (+9000, -0)
  27. datasets/msrvtt/train_list_miech.txt  (+6656, -0)
  28. datasets/msrvtt/val_list_full.txt  (+497, -0)
  29. datasets/msrvtt/val_list_jsfusion.txt  (+1000, -0)
  30. main.py  (+63, -0)
  31. requirements.txt  (+7, -0)
  32. src/__init__.py  (+0, -0)
  33. src/__pycache__/model.cpython-38.pyc  (BIN)
  34. src/__pycache__/train.cpython-38.pyc  (BIN)
  35. src/gumbel_softmax.py  (+50, -0)
  36. src/model.py  (+131, -0)
  37. src/train.py  (+378, -0)
  38. utils/__init__.py  (+0, -0)
  39. utils/__pycache__/__init__.cpython-38.pyc  (BIN)
  40. utils/__pycache__/metrics.cpython-310.pyc  (BIN)
  41. utils/__pycache__/metrics.cpython-37.pyc  (BIN)
  42. utils/__pycache__/metrics.cpython-38.pyc  (BIN)
  43. utils/__pycache__/utils.cpython-37.pyc  (BIN)
  44. utils/__pycache__/video_loader.cpython-310.pyc  (BIN)
  45. utils/cluster_frames.py  (+73, -0)
  46. utils/metrics.py  (+69, -0)

config.py  (+89, -0)

@@ -0,0 +1,89 @@
import argparse
import os
import torch

media_path_root = "path/to/media"

class Config:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description='Train a student network.')
        self.add_arguments()
        self.args = self.parse()
        self.update_paths()

    def add_arguments(self):
        # General settings
        self.parser.add_argument('--seed', type=int, default=12345, help='Init seed')
        self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
        self.parser.add_argument('--do_train', type=int, default=1, help='Whether to run training.')
        self.parser.add_argument('--do_eval', type=int, default=1, help='Whether to run evaluation.')
        self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use')
        self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use')
        self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
        self.parser.add_argument('--project_name', type=str, default='default', help='Project name, used to group log files under logs/')

        # Model configurations
        self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model')
        self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer hidden size')
        self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer hidden size')
        self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer hidden size')
        self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model')
        self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers')
        self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function')
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for the text-to-video retrieval pipeline)')
        self.parser.add_argument('--init_tau', default=5.0, type=float, help='Annealing init temperature')
        self.parser.add_argument('--min_tau', default=0.5, type=float, help='Min temperature to anneal to')
        self.parser.add_argument('--decay_factor', default=0.045, type=float, help='Exp decay factor per epoch')

        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='1: Load')
        self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Load the trained model from which epoch')
        self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just the backbone')
        self.parser.add_argument('--load_backbone', type=int, default=1, help='1: Load')
        self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Load a trained backbone model from which epoch')

        # Dataloader configurations
        self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for dataloader')
        self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Binary percentage for dataloader')

        # Training configurations
        self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
        self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training')
        self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse', 'kl'], help='Frame-to-frame loss')
        self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether to use weighted BCE, to handle unbalanced classes')
        self.parser.add_argument('--alpha', type=float, default=1, help='Weight for f2f loss')
        self.parser.add_argument('--beta', type=float, default=1, help='Weight for f2t loss')
        self.parser.add_argument('--gamma', type=float, default=1, help='Weight for v2t loss')
        self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
        self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler')
        self.parser.add_argument('--scheduler_step_size', type=float, default=5, help='After how many epochs the learning rate will be adjusted')
        self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='The factor by which the learning rate is reduced at each step')
        self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it')
        self.parser.add_argument('--do_cluster', type=int, default=1, help='Whether to also report results with clustering')

    def parse(self):
        return self.parser.parse_args()

    def update_paths(self):
        self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else "cpu")

        self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth"
        self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json"
        self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}"
        self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/"
        self.args.log_dir = f'logs/{self.args.project_name}/'

        os.makedirs(os.path.dirname(self.args.st_scores_path), exist_ok=True)
        os.makedirs(self.args.trained_backbone_path, exist_ok=True)
        os.makedirs(self.args.trained_model_path, exist_ok=True)
        os.makedirs(self.args.log_dir, exist_ok=True)

        self.args.paths = {'train': {}, 'valid': {}}

        for data_split in ['train', 'valid']:
            split = 'val_1' if data_split == 'valid' else data_split
            self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}"
            self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}"
            self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}"
            self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json"

dataloaders/__init__.py  (+0, -0)

dataloaders/__pycache__/__init__.cpython-310.pyc  (BIN)
dataloaders/__pycache__/__init__.cpython-36.pyc  (BIN)
dataloaders/__pycache__/__init__.cpython-37.pyc  (BIN)
dataloaders/__pycache__/__init__.cpython-38.pyc  (BIN)
dataloaders/__pycache__/dataloader_activitynet.cpython-38.pyc  (BIN)
dataloaders/__pycache__/dataloader_activitynet_for_BCE.cpython-310.pyc  (BIN)
dataloaders/__pycache__/dataloader_activitynet_for_BCE.cpython-37.pyc  (BIN)
dataloaders/__pycache__/dataloader_activitynet_for_BCE.cpython-38.pyc  (BIN)
dataloaders/__pycache__/rawvideo_util.cpython-36.pyc  (BIN)
dataloaders/__pycache__/rawvideo_util.cpython-37.pyc  (BIN)


dataloaders/rawvideo_util.py  (+62, -0)

@@ -0,0 +1,62 @@
import cv2
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode

class RawVideoExtractor():
    def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True):
        self.centercrop = centercrop
        self.framerate = framerate
        self.to_tensor = to_tensor
        self.transform = self._transform(size)

    def _transform(self, n_px):
        if self.to_tensor:
            return Compose([
                Resize(n_px, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(n_px),
                ToTensor(),
                Normalize((0.48145466, 0.4578275, 0.40821073),
                          (0.26862954, 0.26130258, 0.27577711))])
        else:
            return Compose([Resize(n_px, interpolation=InterpolationMode.BICUBIC), CenterCrop(n_px)])

    def get_video_data(self, video_path, start_time=None, end_time=None):
        if start_time is not None or end_time is not None:
            assert start_time > -1 and end_time > start_time
            assert self.framerate > -1

        cap = cv2.VideoCapture(video_path)

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS)

        start_frame = int(start_time * video_fps) if start_time else 0
        end_frame = int(end_time * video_fps) if end_time else frame_count - 1

        interval = 1
        if self.framerate > 0:
            interval = video_fps / self.framerate
        else:
            self.framerate = video_fps
        if interval == 0:
            interval = 1

        images = []

        for i in range(frame_count):
            ret, frame = cap.read()

            if not ret:
                break

            if i >= start_frame and i <= end_frame:
                if len(images) * interval < i - start_frame:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image = Image.fromarray(frame_rgb)
                    image = self.transform(image)
                    images.append(image)

        cap.release()

        return images
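
A minimal usage sketch (the clip path and frame rate are placeholders, not part of the repository):

import torch
from dataloaders.rawvideo_util import RawVideoExtractor

# Sample roughly one frame per second, resized, center-cropped and CLIP-normalized.
extractor = RawVideoExtractor(framerate=1, size=224, to_tensor=True)
frames = extractor.get_video_data("path/to/video.mp4")   # list of 3x224x224 tensors

if frames:
    frames_tensor = torch.stack(frames)   # (num_sampled_frames, 3, 224, 224)
    print(frames_tensor.shape)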

dataloaders/t2v_dataloader.py  (+65, -0)

@@ -0,0 +1,65 @@
import os
import json
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision import transforms

class VideoDescriptionDataloader(Dataset):
    def __init__(self, cfg, data_type):
        self.frames_dir = cfg.paths[data_type]['frames']
        self.frame_features_dir = cfg.paths[data_type]['frame_features']
        self.text_features_dir = cfg.paths[data_type]['text_features']
        self.gt_scores_file_path = cfg.paths[data_type]['gt_scores']
        self.transform = self._get_transforms()
        self.binary_percentage = cfg.binary_percentage
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")

        with open(self.gt_scores_file_path, "r") as f:
            self.gt_scores = json.load(f)
        self.video_ids = list(self.gt_scores.keys())

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        video_id = self.video_ids[idx]
        video_frames_dir = os.path.join(self.frames_dir, video_id)
        frames_tensor = torch.stack([self.transform(read_image(os.path.join(video_frames_dir, frame_file)).float())
                                     for frame_file in sorted(os.listdir(video_frames_dir))])

        text_features = torch.load(os.path.join(self.text_features_dir, f'{video_id}.pt'))
        frame_features = torch.load(os.path.join(self.frame_features_dir, f'{video_id}.pt'))
        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
        gt_score = self.normalize_scores(gt_score)

        return frames_tensor, text_features, frame_features, gt_score

    @staticmethod
    def _get_transforms():
        return transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])

    @staticmethod
    def _min_max_normalize_scores(score):
        min_val, max_val = score.min(), score.max()

        if min_val != max_val:
            return (score - min_val) / (max_val - min_val)
        return torch.full_like(score, 0.5)

    def _to_binary_labels(self, score):
        num_top_elements = max(int(len(score) * self.binary_percentage), 1)
        sorted_indices = score.argsort(descending=True)
        binary_labels = torch.zeros_like(score)
        binary_labels[sorted_indices[:num_top_elements]] = 1
        return binary_labels
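
To illustrate the binary-label path with made-up scores, a standalone sketch of the same logic as _to_binary_labels with binary_percentage = 0.1: the top fraction of frames, and at least one frame, is marked as salient.

import torch

scores = torch.tensor([0.2, 0.9, 0.1, 0.4, 0.8, 0.3, 0.5, 0.6, 0.7, 0.05])
num_top = max(int(len(scores) * 0.1), 1)              # at least one positive frame
labels = torch.zeros_like(scores)
labels[scores.argsort(descending=True)[:num_top]] = 1
print(labels)   # tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])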

datasets/activity-net/test_ids.json  (+1, -0, diff suppressed because it is too large)
datasets/activity-net/train.json  (+257170, -0, diff suppressed because it is too large)
datasets/activity-net/train_ids.json  (+1, -0, diff suppressed because it is too large)
datasets/activity-net/val_1.json  (+121946, -0, diff suppressed because it is too large)
datasets/activity-net/val_ids.json  (+1, -0, diff suppressed because it is too large)
datasets/msrvtt/MSRVTT_JSFUSION_test.csv  (+1001, -0, diff suppressed because it is too large)
datasets/msrvtt/MSRVTT_data.json  (+1, -0, diff suppressed because it is too large)
datasets/msrvtt/MSRVTT_train.9k.csv  (+9001, -0, diff suppressed because it is too large)
datasets/msrvtt/test_list_full.txt  (+2990, -0, diff suppressed because it is too large)
datasets/msrvtt/test_list_miech.txt  (+1000, -0, diff suppressed because it is too large)
datasets/msrvtt/train_list_full.txt  (+6513, -0, diff suppressed because it is too large)
datasets/msrvtt/train_list_jsfusion.txt  (+9000, -0, diff suppressed because it is too large)
datasets/msrvtt/train_list_miech.txt  (+6656, -0, diff suppressed because it is too large)

datasets/msrvtt/val_list_full.txt  (+497, -0)

@@ -0,0 +1,497 @@
video6513
video6514
video6515
video6516
video6517
video6518
video6519
video6520
video6521
video6522
video6523
video6524
video6525
video6526
video6527
video6528
video6529
video6530
video6531
video6532
video6533
video6534
video6535
video6536
video6537
video6538
video6539
video6540
video6541
video6542
video6543
video6544
video6545
video6546
video6547
video6548
video6549
video6550
video6551
video6552
video6553
video6554
video6555
video6556
video6557
video6558
video6559
video6560
video6561
video6562
video6563
video6564
video6565
video6566
video6567
video6568
video6569
video6570
video6571
video6572
video6573
video6574
video6575
video6576
video6577
video6578
video6579
video6580
video6581
video6582
video6583
video6584
video6585
video6586
video6587
video6588
video6589
video6590
video6591
video6592
video6593
video6594
video6595
video6596
video6597
video6598
video6599
video6600
video6601
video6602
video6603
video6604
video6605
video6606
video6607
video6608
video6609
video6610
video6611
video6612
video6613
video6614
video6615
video6616
video6617
video6618
video6619
video6620
video6621
video6622
video6623
video6624
video6625
video6626
video6627
video6628
video6629
video6630
video6631
video6632
video6633
video6634
video6635
video6636
video6637
video6638
video6639
video6640
video6641
video6642
video6643
video6644
video6645
video6646
video6647
video6648
video6649
video6650
video6651
video6652
video6653
video6654
video6655
video6656
video6657
video6658
video6659
video6660
video6661
video6662
video6663
video6664
video6665
video6666
video6667
video6668
video6669
video6670
video6671
video6672
video6673
video6674
video6675
video6676
video6677
video6678
video6679
video6680
video6681
video6682
video6683
video6684
video6685
video6686
video6687
video6688
video6689
video6690
video6691
video6692
video6693
video6694
video6695
video6696
video6697
video6698
video6699
video6700
video6701
video6702
video6703
video6704
video6705
video6706
video6707
video6708
video6709
video6710
video6711
video6712
video6713
video6714
video6715
video6716
video6717
video6718
video6719
video6720
video6721
video6722
video6723
video6724
video6725
video6726
video6727
video6728
video6729
video6730
video6731
video6732
video6733
video6734
video6735
video6736
video6737
video6738
video6739
video6740
video6741
video6742
video6743
video6744
video6745
video6746
video6747
video6748
video6749
video6750
video6751
video6752
video6753
video6754
video6755
video6756
video6757
video6758
video6759
video6760
video6761
video6762
video6763
video6764
video6765
video6766
video6767
video6768
video6769
video6770
video6771
video6772
video6773
video6774
video6775
video6776
video6777
video6778
video6779
video6780
video6781
video6782
video6783
video6784
video6785
video6786
video6787
video6788
video6789
video6790
video6791
video6792
video6793
video6794
video6795
video6796
video6797
video6798
video6799
video6800
video6801
video6802
video6803
video6804
video6805
video6806
video6807
video6808
video6809
video6810
video6811
video6812
video6813
video6814
video6815
video6816
video6817
video6818
video6819
video6820
video6821
video6822
video6823
video6824
video6825
video6826
video6827
video6828
video6829
video6830
video6831
video6832
video6833
video6834
video6835
video6836
video6837
video6838
video6839
video6840
video6841
video6842
video6843
video6844
video6845
video6846
video6847
video6848
video6849
video6850
video6851
video6852
video6853
video6854
video6855
video6856
video6857
video6858
video6859
video6860
video6861
video6862
video6863
video6864
video6865
video6866
video6867
video6868
video6869
video6870
video6871
video6872
video6873
video6874
video6875
video6876
video6877
video6878
video6879
video6880
video6881
video6882
video6883
video6884
video6885
video6886
video6887
video6888
video6889
video6890
video6891
video6892
video6893
video6894
video6895
video6896
video6897
video6898
video6899
video6900
video6901
video6902
video6903
video6904
video6905
video6906
video6907
video6908
video6909
video6910
video6911
video6912
video6913
video6914
video6915
video6916
video6917
video6918
video6919
video6920
video6921
video6922
video6923
video6924
video6925
video6926
video6927
video6928
video6929
video6930
video6931
video6932
video6933
video6934
video6935
video6936
video6937
video6938
video6939
video6940
video6941
video6942
video6943
video6944
video6945
video6946
video6947
video6948
video6949
video6950
video6951
video6952
video6953
video6954
video6955
video6956
video6957
video6958
video6959
video6960
video6961
video6962
video6963
video6964
video6965
video6966
video6967
video6968
video6969
video6970
video6971
video6972
video6973
video6974
video6975
video6976
video6977
video6978
video6979
video6980
video6981
video6982
video6983
video6984
video6985
video6986
video6987
video6988
video6989
video6990
video6991
video6992
video6993
video6994
video6995
video6996
video6997
video6998
video6999
video7000
video7001
video7002
video7003
video7004
video7005
video7006
video7007
video7008
video7009

datasets/msrvtt/val_list_jsfusion.txt  (+1000, -0, diff suppressed because it is too large)


main.py  (+63, -0)

@@ -0,0 +1,63 @@
import os
import torch
from config import Config
from src.model import SaliencyNet
from torch.utils.data import DataLoader
from src.train import SalientFrameSamplerTrainer
from dataloaders.t2v_dataloader import VideoDescriptionDataloader

def load_weights(path, which_epoch):
    weights, last_epoch = None, None
    available_models = [name for name in os.listdir(path) if name.endswith(".pt")]

    if available_models:
        last_epoch = max([int(name[6:-3]) for name in available_models])
        last_epoch = min(last_epoch, which_epoch)
        weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt'))

    return weights, last_epoch

def set_seeds(seed: int):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def main(cfg):
    set_seeds(cfg.seed)

    model = SaliencyNet(cfg)
    model.to(cfg.device)

    if cfg.load_complete_model:
        weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch)
        if weights:
            model.load_state_dict(weights)
            print(f'Complete {cfg.backbone} model trained with the {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
    elif cfg.load_backbone:
        weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch)
        if weights:
            model.pretrained.load_state_dict(weights)
            print(f'{cfg.backbone} backbone trained on {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
        else:
            print(f'{cfg.backbone} backbone loaded from scratch.')

    train_dataset = VideoDescriptionDataloader(cfg, data_type='train')
    val_dataset = VideoDescriptionDataloader(cfg, data_type='valid')

    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers)

    trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg)
    trainer.train()

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    cfg = Config().args

    main(cfg)

requirements.txt  (+7, -0)

@@ -0,0 +1,7 @@
torch
torchvision
opencv-python
Pillow
timm
transformers
scikit-learn
numpy
tqdm
tabulate

src/__init__.py  (+0, -0)

src/__pycache__/model.cpython-38.pyc  (BIN)
src/__pycache__/train.cpython-38.pyc  (BIN)


src/gumbel_softmax.py  (+50, -0)

@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelSoftmaxTopK(nn.Module):
    def __init__(self, k, tau):
        super(GumbelSoftmaxTopK, self).__init__()
        self.k = k
        self.tau = tau

    def forward(self, logits):
        """
        Adapted from: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html#Subset-Sampler-Class
        and https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
        """
        logits = torch.log(nn.Softmax(dim=0)(logits).clamp(min=1e-8))

        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format),
                                              torch.ones_like(logits, memory_format=torch.legacy_contiguous_format))
        gumbels = m.sample()
        y = logits + gumbels
        _EPS = torch.tensor(1e-40).to(logits.device)
        khot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        onehot_approx = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)

        for i in range(self.k):
            khot_mask = torch.maximum(1.0 - onehot_approx, _EPS)
            y += khot_mask.log()
            onehot_approx = nn.Softmax(dim=0)(y / self.tau)
            khot = torch.add(khot, onehot_approx)

        return khot

class GumbelSoftmax(nn.Module):
    def __init__(self, tau, hard=False):
        super(GumbelSoftmax, self).__init__()
        self.tau = tau
        self.hard = hard

    def forward(self, logits):
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format),
                                              torch.ones_like(logits, memory_format=torch.legacy_contiguous_format))
        gumbels = m.sample()

        y = logits + gumbels
        y_soft = nn.Softmax(dim=0)(y / self.tau)

        return y_soft
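
A quick sketch of how the two relaxations behave on a vector of per-frame logits (random values here): the top-k module returns a relaxed k-hot vector whose entries sum to exactly k, while the plain Gumbel-Softmax returns a single perturbed softmax draw summing to 1. Lowering tau makes both outputs sharper.

import torch
from src.gumbel_softmax import GumbelSoftmax, GumbelSoftmaxTopK

torch.manual_seed(0)
frame_logits = torch.randn(16)            # one logit per frame

top_k = GumbelSoftmaxTopK(k=4, tau=0.5)
weights = top_k(frame_logits)             # relaxed k-hot vector, shape (16,)
print(weights.sum())                      # 4.0: the sum of k softmax draws

soft = GumbelSoftmax(tau=0.5)
print(soft(frame_logits).sum())           # 1.0: a single perturbed softmax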

src/model.py  (+131, -0)

@@ -0,0 +1,131 @@
import timm
import torch
import torch.nn as nn
from src.gumbel_softmax import GumbelSoftmax, GumbelSoftmaxTopK

class SaliencyNet(nn.Module):
    def __init__(self, cfg):
        super(SaliencyNet, self).__init__()
        self.k = 8
        self.tau = cfg.init_tau
        self.min_tau = cfg.min_tau
        self.decay_factor = cfg.decay_factor
        self.raw_backbone_path = cfg.raw_backbone_path
        self.configure_pretrained(cfg.backbone, cfg.n_trainable_backbone_layers)
        embed_dim = self.pretrained.num_features

        self.fc1 = nn.Linear(embed_dim, cfg.fc1_size)
        self.fc2 = nn.Linear(cfg.fc1_size, cfg.fc2_size)
        self.fc3 = nn.Linear(cfg.fc2_size, cfg.fc3_size)
        self.fc4 = nn.Linear(cfg.fc3_size, 1)
        # self.bn1 = nn.BatchNorm1d(cfg.fc1_size)  # Batch normalization layer
        self.dropout = nn.Dropout(cfg.dropout_ratio)
        self.sigmoid = nn.Sigmoid()
        if cfg.activation == 'relu':
            self.activation = nn.ReLU()
        elif cfg.activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()
        elif cfg.activation == 'gelu':
            self.activation = nn.GELU()
        else:
            raise ValueError("Invalid activation type. Choose 'relu', 'leakyrelu', or 'gelu'.")

        self.normalization_layer = self.init_normalization_layer(cfg.normalization_layer, self.tau, self.k)

        self.initialize_weights()

    def init_normalization_layer(self, normalization_type, tau, k):
        if normalization_type == 'raw':
            normalization_layer = lambda x: x
        elif normalization_type == 'sigmoid':
            normalization_layer = nn.Sigmoid()
        elif normalization_type == 'softmax':
            normalization_layer = nn.Softmax(dim=1)
        elif normalization_type == 'gumbel':
            normalization_layer = GumbelSoftmax(tau)
        elif normalization_type == 'k_top':
            normalization_layer = GumbelSoftmaxTopK(k, tau)
        else:
            raise ValueError(f"Invalid normalization layer type: {normalization_type}")

        return normalization_layer

    def initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, preprocessed_images):
        features = self.pretrained(preprocessed_images)

        x = self.activation(self.fc1(features))
        x = self.dropout(x)
        x = self.activation(self.fc2(x))
        x = self.dropout(x)
        x = self.activation(self.fc3(x))
        x = self.fc4(x)

        frame_scores = torch.flatten(x)
        frame_weights = torch.flatten(self.normalization_layer(x))
        return frame_scores, frame_weights, features

    def anneal_tau(self):
        """Anneals the temperature parameter τ using exponential decay."""
        self.tau = self.tau * torch.exp(-self.decay_factor * torch.tensor(1.0))
        self.tau = max(self.tau, self.min_tau)  # Prevent τ from becoming too small

    def configure_pretrained(self, backbone, n_trainable_backbone_layers=1):
        # Supported backbones: 'resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'
        self.pretrained = timm.create_model(backbone, pretrained=False, checkpoint_path=self.raw_backbone_path)

        # Freeze the early layers of the model
        for param in self.pretrained.parameters():
            param.requires_grad = False

        # Replace the classification head with an identity so the backbone returns features
        if backbone == 'resnet18':
            self.pretrained.fc = nn.Identity()
            sequential_modules = [child for child in self.pretrained.children() if isinstance(child, nn.Sequential)]
        elif backbone in ['efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100']:
            self.pretrained.classifier = nn.Identity()
            sequential_modules = [child for child in self.pretrained.blocks.children()]
        else:
            raise ValueError('Unsupported backbone model!')

        num_seq_blocks = len(sequential_modules)
        if n_trainable_backbone_layers == 0:
            pass
        elif n_trainable_backbone_layers >= num_seq_blocks:
            for param in self.pretrained.parameters():
                param.requires_grad = True
        else:
            # Select the last `n_trainable_backbone_layers` of sequential modules to be trainable
            layers_to_train = sequential_modules[-n_trainable_backbone_layers:]
            for layer in layers_to_train:
                for param in layer.parameters():
                    param.requires_grad = True

        if n_trainable_backbone_layers > 0:
            last_layers = []
            if hasattr(self.pretrained, 'conv_head'):
                last_layers.append(self.pretrained.conv_head.weight)
            if hasattr(self.pretrained, 'bn2'):
                last_layers.append(self.pretrained.bn2.weight)
                last_layers.append(self.pretrained.bn2.bias)
            for param in last_layers:
                param.requires_grad = True

src/train.py  (+378, -0)

@@ -0,0 +1,378 @@
import os
import torch
from tqdm import tqdm, trange
import torch.nn as nn
import logging
import torch.nn.functional as F
from tabulate import tabulate
from torch.utils.data import DataLoader
from utils.cluster_frames import VideoClusterer
from utils.metrics import compute_metrics

class SalientFrameSamplerTrainer:
    def __init__(self, model, train_dataloader: DataLoader, val_dataloader: DataLoader, cfg):
        self.cfg = cfg
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.init_logging()
        self.criterion1 = kl_divergence_loss()
        self.criterion2 = nn.BCEWithLogitsLoss()
        self.criterion3 = CrossEn()
        self.optimizer = self.init_optimizer()

        if self.cfg.use_scheduler:
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.cfg.scheduler_step_size, gamma=self.cfg.scheduler_gamma)

    def train(self):
        for epoch in trange(self.cfg.epochs, desc="Epoch", ncols=100):
            log_data = {}

            if self.cfg.do_train:
                train_loss = self.run_training_epoch()
                log_data.update({
                    "train_loss(f2f)": train_loss['f2f'],
                    "train_loss(f2t)": train_loss['f2t'],
                    "train_loss(v2t)": train_loss['v2t'],
                })

            if self.cfg.do_eval:
                valid_loss, t2v_metrics, v2t_metrics = self.run_validation_epoch()
                log_data.update({
                    "valid_loss(f2f)": valid_loss['f2f'],
                    "valid_loss(f2t)": valid_loss['f2t'],
                    "valid_loss(v2t)": valid_loss['v2t'],
                })

                for key in t2v_metrics.keys():
                    for r in [1, 5, 10]:
                        log_data[f"[t2v] R@{r} (k={key})"] = t2v_metrics[key][f'R@{r}']
                        log_data[f"[v2t] R@{r} (k={key})"] = v2t_metrics[key][f'R@{r}']

            self.log_metrics(epoch, log_data)

            self.model.anneal_tau()

            if hasattr(self, 'scheduler'):
                self.scheduler.step()

            if self.cfg.store_backbone:
                torch.save(self.model.pretrained.state_dict(), os.path.join(self.cfg.trained_backbone_path, f"epoch_{epoch+1}.pt"))

            if self.cfg.store_complete_model:
                torch.save(self.model.state_dict(), os.path.join(self.cfg.trained_model_path, f"epoch_{epoch+1}.pt"))

    def run_training_epoch(self):
        self.model.train()

        epoch_train_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0}

        video_embeddings = []
        text_embeddings = []
        f2f_loss_accum, f2t_loss_accum = 0.0, 0.0

        for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.train_dataloader, desc='Train', leave=False, ncols=100)):
            torch.cuda.empty_cache()
            frames = frames[0].to(self.cfg.device)
            teacher_frame_features = teacher_frame_features[0].to(self.cfg.device)
            teacher_text_features = teacher_text_features[0].to(self.cfg.device)
            gt_score = gt_score[0].to(self.cfg.device)

            frame_scores, frame_weights, student_frame_features = self.model(frames)

            # Frame-to-Frame loss
            f2f_loss = self.criterion1(student_frame_features, teacher_frame_features)

            # Frame-to-Text loss
            f2t_loss = self.criterion2(frame_scores, gt_score)

            # Update accumulators
            f2f_loss_accum += f2f_loss.item() / self.cfg.batch_size
            f2t_loss_accum += f2t_loss.item() / self.cfg.batch_size

            # Calculating the in-pipeline loss for v2t
            video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights)
            video_embeddings.append(video_features)
            text_embeddings.append(teacher_text_features)

            if (i + 1) % self.cfg.batch_size == 0 or i == len(self.train_dataloader) - 1:
                sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings)
                sim_loss1 = self.criterion3(sim_matrix)
                sim_loss2 = self.criterion3(sim_matrix.T)
                v2t_loss = (sim_loss1 + sim_loss2) / 2

                # Aggregate the total loss with the configured weights
                total_loss = self.cfg.alpha * f2f_loss_accum + self.cfg.beta * f2t_loss_accum + self.cfg.gamma * v2t_loss
                self.optimizer.zero_grad()
                total_loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

                # Update epoch losses
                epoch_train_loss['f2f'] += f2f_loss_accum
                epoch_train_loss['f2t'] += f2t_loss_accum
                epoch_train_loss['v2t'] += v2t_loss.item()

                # Reset accumulators
                f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
                video_embeddings = []
                text_embeddings = []

        # Normalize epoch losses by the number of batches
        for key in epoch_train_loss.keys():
            epoch_train_loss[key] /= len(self.train_dataloader)

        return epoch_train_loss

    def run_validation_epoch(self):
        self.model.eval()
        epoch_valid_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0}

        video_embeddings = []
        text_embeddings = []
        all_video_embeddings = {str(k): {'without_clustering': [],
                                         'uniform': [], 'spectral': [], 'agglomerative': [], 'kmeans': []} for k in [1, 2, 4, 8, 16]}
        all_text_embeddings = []
        all_st_scores = []

        f2f_loss_accum, f2t_loss_accum = 0.0, 0.0

        for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.val_dataloader, desc='Valid', leave=False, ncols=100)):
            torch.cuda.empty_cache()
            frames = frames[0].to(self.cfg.device)
            teacher_frame_features = teacher_frame_features[0].to(self.cfg.device)
            teacher_text_features = teacher_text_features[0].to(self.cfg.device)
            gt_score = gt_score[0].to(self.cfg.device)

            with torch.no_grad():
                frame_scores, frame_weights, student_frame_features = self.model(frames)

            f2f_loss = self.criterion1(student_frame_features, teacher_frame_features)
            f2t_loss = self.criterion2(frame_scores, gt_score)

            f2f_loss_accum += f2f_loss.item()
            f2t_loss_accum += f2t_loss.item()

            # Collect all embeddings, used for calculating the t2v retrieval metrics
            all_text_embeddings.append(teacher_text_features)

            n_frames = teacher_frame_features.size(0)
            sorted_indices = sorted(range(n_frames), key=lambda i: frame_scores[i], reverse=True).copy()

            for k in [1, 2, 4, 8, 16]:
                n_selected_frames = min(n_frames, k)
                indices = sorted_indices[:n_selected_frames]
                all_video_embeddings[str(k)]['without_clustering'].append(teacher_frame_features[indices].mean(dim=0))

                if self.cfg.do_cluster and k > 2:
                    for clustering_method in ['uniform', 'spectral', 'agglomerative', 'kmeans']:
                        if len(frame_scores) <= k:  # mean of all frames
                            all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features.mean(dim=0))
                        else:
                            clusterer = VideoClusterer(clustering_method, k)
                            clusters = clusterer.get_clusters(student_frame_features)
                            selected_indices = self.select_representative_frame_per_cluster(clusters, frame_scores)
                            all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features[selected_indices].mean(dim=0))

            # Calculating the in-pipeline loss
            video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights)
            video_embeddings.append(video_features)
            text_embeddings.append(teacher_text_features)

            if (i + 1) % self.cfg.batch_size == 0 or i == len(self.val_dataloader) - 1:
                sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings)
                sim_loss1 = self.criterion3(sim_matrix)
                sim_loss2 = self.criterion3(sim_matrix.T)
                v2t_loss = (sim_loss1 + sim_loss2) / 2

                # Update epoch losses
                epoch_valid_loss['f2f'] += f2f_loss_accum
                epoch_valid_loss['f2t'] += f2t_loss_accum
                epoch_valid_loss['v2t'] += v2t_loss.item()

                # Reset accumulators
                f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
                video_embeddings = []
                text_embeddings = []

        # Normalize epoch losses by the number of batches
        for key in epoch_valid_loss.keys():
            epoch_valid_loss[key] /= len(self.val_dataloader)

        # Compute t2v and v2t metrics
        t2v_metrics = {}
        v2t_metrics = {}

        for k, embeddings_dict in all_video_embeddings.items():
            t2v_metrics[str(k)] = {}
            for scenario, embeddings_list in embeddings_dict.items():
                if embeddings_list:  # Check if there are any embeddings for this scenario
                    sim_matrix = self.compute_sim_matrix(embeddings_list, all_text_embeddings)
                    _key = f"{k}" if scenario == 'without_clustering' else f"{k}-{scenario}"
                    t2v_metrics[_key] = compute_metrics(sim_matrix.detach().cpu().numpy())
                    v2t_metrics[_key] = compute_metrics(sim_matrix.T.detach().cpu().numpy())

        return epoch_valid_loss, t2v_metrics, v2t_metrics

    def compute_sim_matrix(self, video_embeddings, text_embeddings):
        video_embeddings_tensor = torch.stack(video_embeddings, dim=0)
        text_embeddings_tensor = torch.stack(text_embeddings, dim=0)
        video_embeddings_norm = video_embeddings_tensor / video_embeddings_tensor.norm(dim=-1, keepdim=True)
        text_embeddings_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True)
        sim_matrix = torch.matmul(video_embeddings_norm, text_embeddings_norm.T)
        return sim_matrix

    def init_optimizer(self):
        if not hasattr(self.cfg, 'backbone_lr_scale'):
            return torch.optim.Adam(params=self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)

        backbone_params = []
        last_params = []
        for name, param in self.model.named_parameters():
            if 'pretrain' in name:
                backbone_params.append(param)
            else:
                last_params.append(param)

        return torch.optim.Adam([
            {'params': backbone_params, 'lr': self.cfg.lr},
            {'params': last_params, 'lr': self.cfg.lr * self.cfg.backbone_lr_scale}
        ], weight_decay=self.cfg.weight_decay)

    def init_logging(self):
        available_logs = os.listdir(self.cfg.log_dir)
        last_run = 0
        if available_logs:
            last_run = max([int(name.split('-')[0]) for name in available_logs])
        log_file_name = f"{last_run + 1}-local-run.log"

        log_file = os.path.join(self.cfg.log_dir, log_file_name)

        logging.basicConfig(filename=log_file, level=logging.INFO,
                            format='%(asctime)s [%(levelname)s] - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S')

        hyperparameters = {key: value for key, value in vars(self.cfg).items()}
        hyperparameters_str = "\n".join([f"{key}: {value}" for key, value in hyperparameters.items() if ('path' not in key and 'log_dir' not in key)])
        logging.info("Hyperparameters:\n" + hyperparameters_str)

    def log_metrics(self, epoch, log_data):
        # Loss Metrics Table
        loss_header = ["", "f2f", "f2t", "v2t"]

        # Check if train_loss exists, otherwise assign "N/A"
        train_loss_f2f = log_data.get("train_loss(f2f)", "N/A")
        train_loss_f2t = log_data.get("train_loss(f2t)", "N/A")
        train_loss_v2t = log_data.get("train_loss(v2t)", "N/A")

        # Check if valid_loss exists, otherwise assign "N/A"
        valid_loss_f2f = log_data.get("valid_loss(f2f)", "N/A")
        valid_loss_f2t = log_data.get("valid_loss(f2t)", "N/A")
        valid_loss_v2t = log_data.get("valid_loss(v2t)", "N/A")

        loss_rows = [
            ["train", train_loss_f2f, train_loss_f2t, train_loss_v2t],
            ["val", valid_loss_f2f, valid_loss_f2t, valid_loss_v2t],
        ]
        loss_table = tabulate(loss_rows, headers=loss_header, tablefmt="grid")

        # Recall Metrics Table
        recall_header = ["", "R@1", "R@5", "R@10"]

        clustering_names = ['uniform', 'spectral', 'agglomerative', 'kmeans']

        recall_rows_t2v = []
        recall_rows_v2t = []
        for k in [1, 2, 4, 8, 16]:
            # Without clustering
            base_key = f"k={k}"
            if any(f"[t2v] R@{r} ({base_key})" in log_data for r in [1, 5, 10]):
                row = [f"{k}"]
                for r in [1, 5, 10]:
                    metric_key = f"[t2v] R@{r} ({base_key})"
                    row.append(log_data.get(metric_key, "N/A"))
                recall_rows_t2v.append(row)

            if any(f"[v2t] R@{r} ({base_key})" in log_data for r in [1, 5, 10]):
                row = [f"{k}"]
                for r in [1, 5, 10]:
                    metric_key = f"[v2t] R@{r} ({base_key})"
                    row.append(log_data.get(metric_key, "N/A"))
                recall_rows_v2t.append(row)

            # With clustering
            for clustering_name in clustering_names:
                cluster_key = f"{base_key}-{clustering_name}"
                if any(f"[t2v] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]):
                    row = [f"{k}-{clustering_name}"]
                    for r in [1, 5, 10]:
                        metric_key = f"[t2v] R@{r} ({cluster_key})"
                        row.append(log_data.get(metric_key, "N/A"))
                    recall_rows_t2v.append(row)
                if any(f"[v2t] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]):
                    row = [f"{k}-{clustering_name}"]
                    for r in [1, 5, 10]:
                        metric_key = f"[v2t] R@{r} ({cluster_key})"
                        row.append(log_data.get(metric_key, "N/A"))
                    recall_rows_v2t.append(row)

        recall_table_t2v = tabulate(recall_rows_t2v, headers=recall_header, tablefmt="grid")
        recall_table_v2t = tabulate(recall_rows_v2t, headers=recall_header, tablefmt="grid")

        logging.info(f"Epoch {epoch+1} - Loss and Recall Metrics:")
        logging.info("\n" + loss_table)
        logging.info("\nText-to-Video Recall Metrics:")
        logging.info("\n" + recall_table_t2v)
        logging.info("\nVideo-to-Text Recall Metrics:")
        logging.info("\n" + recall_table_v2t)

    def select_representative_frame_per_cluster(self, clusters, frame_scores):
        representative_frames = []

        for cluster in clusters:
            best_frame_index = max(cluster, key=lambda index: frame_scores[index])
            representative_frames.append(int(best_frame_index))

        return representative_frames

class kl_divergence_loss(nn.Module):
    def __init__(self):
        super(kl_divergence_loss, self).__init__()

    def forward(self, features1, features2):
        features1 = F.normalize(features1, p=2, dim=-1)
        features2 = F.normalize(features2, p=2, dim=-1)

        cos_sim_features1 = torch.mm(features1, features1.t())
        cos_sim_features2 = torch.mm(features2, features2.t())

        probs_features1 = F.softmax(cos_sim_features1, dim=-1)
        probs_features2 = F.softmax(cos_sim_features2, dim=-1)

        loss = F.kl_div(probs_features1.log(), probs_features2, reduction='batchmean')

        return loss

class CrossEn(nn.Module):
    def __init__(self):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
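
For intuition, a small sketch of the two similarity-based losses on random stand-in features (the 512-dimensional feature size is only an assumption here): kl_divergence_loss compares the frame-to-frame similarity distributions of student and teacher features, and CrossEn applied to a square video-text similarity matrix and its transpose gives the symmetric contrastive (v2t) term.

import torch
from src.train import CrossEn, kl_divergence_loss

torch.manual_seed(0)
student = torch.randn(16, 512)   # student per-frame features
teacher = torch.randn(16, 512)   # teacher per-frame features

f2f = kl_divergence_loss()
print(f2f(student, teacher))     # KL between the two 16x16 similarity distributions

sim = torch.randn(8, 8)          # video x text similarities for 8 matched pairs
v2t = CrossEn()
print((v2t(sim) + v2t(sim.T)) / 2)   # symmetric cross-entropy over the diagonal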

utils/__init__.py  (+0, -0)

utils/__pycache__/__init__.cpython-38.pyc  (BIN)
utils/__pycache__/metrics.cpython-310.pyc  (BIN)
utils/__pycache__/metrics.cpython-37.pyc  (BIN)
utils/__pycache__/metrics.cpython-38.pyc  (BIN)
utils/__pycache__/utils.cpython-37.pyc  (BIN)
utils/__pycache__/video_loader.cpython-310.pyc  (BIN)


utils/cluster_frames.py  (+73, -0)

@@ -0,0 +1,73 @@
from sklearn.cluster import SpectralClustering, AgglomerativeClustering, KMeans
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class VideoClusterer:
    def __init__(self, clustering_method='uniform', n_clusters=2, similarity_threshold=0.8):
        self.n_clusters = n_clusters
        self.similarity_threshold = similarity_threshold
        self.clustering_method = clustering_method

        # Decide on the clustering method to use
        if clustering_method == 'uniform':
            self.clusterer = self.uniform_clustering
        elif clustering_method == 'spectral':
            self.clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed')
        elif clustering_method == 'agglomerative':
            self.clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
        elif clustering_method == 'kmeans':
            self.clusterer = KMeans(n_clusters=n_clusters, n_init=1)
        else:
            raise ValueError(f"Invalid clustering method: {clustering_method}")

    def uniform_clustering(self, features):
        n = len(features)
        clusters = []
        cluster_size = n // self.n_clusters
        remainder = n % self.n_clusters

        start = 0
        for i in range(self.n_clusters):
            if i < remainder:
                end = start + cluster_size + 1
            else:
                end = start + cluster_size
            clusters.append(list(range(start, end)))
            start = end

        return clusters

    def detect_outliers(self, features):
        dot_product_matrix = features.dot(features.T)
        average_similarities = np.mean(dot_product_matrix, axis=0)
        # Adding a small constant epsilon to the standard deviation to prevent division by zero
        epsilon = 1e-8
        normal = (average_similarities - np.mean(average_similarities)) / (np.std(average_similarities) + epsilon)
        outlier_mask = np.logical_or(normal > 1.5, normal < -1.5)
        return outlier_mask

    def get_clusters(self, features):
        features = features.cpu().numpy()

        if self.clustering_method == 'uniform':
            return self.uniform_clustering(features)
        else:
            # For non-uniform methods, follow the original procedure
            outlier_mask = self.detect_outliers(features)

            if np.sum(~outlier_mask) > self.n_clusters:
                features = features[~outlier_mask]

            # Compute cosine similarity matrix for spectral clustering
            if self.clustering_method == 'spectral':
                similarity_matrix = cosine_similarity(features)
                labels = self.clusterer.fit_predict(similarity_matrix)
            else:
                # For agglomerative, k-means, and other clustering methods that don't require a precomputed matrix
                labels = self.clusterer.fit_predict(features)

            # Organize frames into clusters based on labels
            clusters = [[] for _ in range(self.n_clusters)]
            for idx, label in enumerate(labels):
                clusters[label].append(idx)
            return clusters
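
A small usage sketch with random tensors standing in for real frame features and scores: group 16 frames into 4 clusters, then keep the highest-scoring frame of each cluster, as the trainer does.

import torch
from utils.cluster_frames import VideoClusterer

torch.manual_seed(0)
frame_features = torch.randn(16, 512)   # stand-in for student frame features
frame_scores = torch.randn(16)          # stand-in for predicted saliency scores

clusterer = VideoClusterer(clustering_method='kmeans', n_clusters=4)
# List of index lists; for the learned methods the indices refer to the outlier-filtered frames.
clusters = clusterer.get_clusters(frame_features)

representatives = [max(c, key=lambda i: frame_scores[i]) for c in clusters if c]
print(clusters)
print(representatives)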

utils/metrics.py  (+69, -0)

@@ -0,0 +1,69 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import numpy as np
import torch

def compute_metrics(x):
    sx = np.sort(-x, axis=1)
    d = np.diag(-x)
    d = d[:, np.newaxis]
    ind = sx - d
    ind = np.where(ind == 0)
    ind = ind[1]
    metrics = {}
    metrics['R@1'] = float(np.sum(ind == 0)) * 100 / len(ind)
    metrics['R@5'] = float(np.sum(ind < 5)) * 100 / len(ind)
    metrics['R@10'] = float(np.sum(ind < 10)) * 100 / len(ind)
    metrics["MedianR"] = np.median(ind) + 1
    metrics["MeanR"] = np.mean(ind) + 1
    # metrics["cols"] = [int(i) for i in list(ind)]
    return metrics

def print_computed_metrics(metrics):
    r1 = metrics['R@1']
    r5 = metrics['R@5']
    r10 = metrics['R@10']
    mr = metrics['MR']
    print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr))

# The two functions below come directly from: https://github.com/Deferf/Experiments
def tensor_text_to_video_metrics(sim_tensor, top_k=[1, 5, 10]):
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)

    # Permute sim_tensor so it represents a sequence of text-video similarity matrices,
    # then obtain the double argsort to position the rank on the diagonal.
    stacked_sim_matrices = sim_tensor.permute(1, 0, 2)
    first_argsort = torch.argsort(stacked_sim_matrices, dim=-1, descending=True)
    second_argsort = torch.argsort(first_argsort, dim=-1, descending=False)

    # Extract the ranks, i.e. the diagonals
    ranks = torch.flatten(torch.diagonal(second_argsort, dim1=1, dim2=2))

    # Now we need to extract the valid ranks, as some belong to inf padding values
    permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1=0, dim2=2))
    mask = ~torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data))
    valid_ranks = ranks[mask]
    # A quick dimension check validates the result; other correctness tests,
    # such as dot product localization, are left for another time.
    # assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict])
    if not torch.is_tensor(valid_ranks):
        valid_ranks = torch.tensor(valid_ranks)
    results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k}
    results["MedianR"] = float(torch.median(valid_ranks + 1))
    results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1))
    results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1))
    results['MR'] = results["MedianR"]
    return results

def tensor_video_to_text_sim(sim_tensor):
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # Replace NaNs so they cannot win the max below
    sim_tensor[sim_tensor != sim_tensor] = float('-inf')
    # Forms a similarity matrix for use with rank at k
    values, _ = torch.max(sim_tensor, dim=1, keepdim=True)
    return torch.squeeze(values).T
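
As a quick sanity check, compute_metrics on a toy 4x4 similarity matrix (queries on rows, candidates on columns, correct match on the diagonal) where one query ranks its match second:

import numpy as np
from utils.metrics import compute_metrics

sim = np.array([
    [0.9, 0.1, 0.2, 0.0],
    [0.3, 0.8, 0.1, 0.2],
    [0.2, 0.7, 0.6, 0.1],   # this query's correct match is ranked second
    [0.1, 0.0, 0.2, 0.5],
])
print(compute_metrics(sim))   # R@1 = 75.0, R@5 = R@10 = 100.0, MedianR = 1.0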
