@@ -0,0 +1,89 @@ | |||
import argparse | |||
import os | |||
import torch | |||
media_path_root = "path/to/media" | |||
class Config: | |||
def __init__(self): | |||
self.parser = argparse.ArgumentParser(description='Train a student network.') | |||
self.add_arguments() | |||
self.args = self.parse() | |||
self.update_paths() | |||
def add_arguments(self): | |||
# General settings | |||
self.parser.add_argument('--seed', type=int, default=12345, help='Init Seed') | |||
self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training') | |||
self.parser.add_argument("--do_train", type=int, default=1, help="Whether to run training.") | |||
self.parser.add_argument("--do_eval", type=int, default=1, help="Whether to run evaluation.") | |||
self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use') | |||
self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use') | |||
        self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
        self.parser.add_argument('--project_name', type=str, default='saliency_student', help='Project name used to group log files (referenced in update_paths); the default is a placeholder')
# Model configurations | |||
self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model') | |||
self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer Hidden size') | |||
self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer Hidden size') | |||
self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer Hidden size') | |||
self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model') | |||
self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers') | |||
self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function') | |||
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for the text-to-video retrieval pipeline)')
self.parser.add_argument('--init_tau', default=5.0, type=float, help="annealing init temperature") | |||
self.parser.add_argument('--min_tau', default=0.5, type=float, help="min temperature to anneal to") | |||
self.parser.add_argument('--decay_factor', default=0.045, type=float, help="exp decay factor per epoch") | |||
        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='1: load a previously trained complete model, 0: do not')
self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Load the trained model from which epoch') | |||
self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just backbone') | |||
        self.parser.add_argument('--load_backbone', type=int, default=1, help='1: load a previously fine-tuned backbone checkpoint, 0: keep the raw pretrained backbone')
self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Load a trained backbone model from which epoch') | |||
# Dataloader configurations | |||
self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for dataloader') | |||
self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Binary percentage for dataloader') | |||
# Training configurations | |||
self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training') | |||
self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training') | |||
self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse','kl'], help='frame2frame loss') | |||
self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether use weighted BCE, to handle unbalanced classes') | |||
self.parser.add_argument('--alpha', type=float, default=1, help='Weight for f2f loss') | |||
self.parser.add_argument('--beta', type=float, default=1, help='Weight for f2t loss') | |||
self.parser.add_argument('--gamma', type=float, default=1, help='Weight for v2t loss') | |||
self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') | |||
self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler') | |||
        self.parser.add_argument('--scheduler_step_size', type=int, default=5, help='After how many epochs the learning rate will be adjusted')
self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='The factor by which the learning rate will be reduced at each step.') | |||
self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it') | |||
self.parser.add_argument('--do_cluster', type=int, default=1, help='Calculate results with cluster as well or not.') | |||
def parse(self): | |||
return self.parser.parse_args() | |||
def update_paths(self): | |||
self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else "cpu") | |||
self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth" | |||
self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json" | |||
self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}" | |||
self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/" | |||
self.args.log_dir = f'logs/{self.args.project_name}/' | |||
        os.makedirs(os.path.dirname(self.args.st_scores_path), exist_ok=True)
os.makedirs(self.args.trained_backbone_path, exist_ok=True) | |||
os.makedirs(self.args.trained_model_path, exist_ok=True) | |||
os.makedirs(self.args.log_dir, exist_ok=True) | |||
self.args.paths = {'train':{}, 'valid':{}} | |||
for data_split in ['train', 'valid']: | |||
split = 'val_1' if data_split=='valid' else data_split | |||
self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}" | |||
self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}" | |||
self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}" | |||
self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json" |
@@ -0,0 +1,62 @@ | |||
import cv2 | |||
from PIL import Image | |||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode | |||
class RawVideoExtractor(): | |||
def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True): | |||
self.centercrop = centercrop | |||
self.framerate = framerate | |||
self.to_tensor = to_tensor | |||
self.transform = self._transform(size) | |||
def _transform(self, n_px): | |||
if self.to_tensor: | |||
return Compose([ | |||
Resize(n_px, interpolation=InterpolationMode.BICUBIC), | |||
CenterCrop(n_px), ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711))]) | |||
else: | |||
return Compose([Resize(n_px, interpolation=InterpolationMode.BICUBIC),CenterCrop(n_px)]) | |||
def get_video_data(self, video_path, start_time=None, end_time=None): | |||
        if start_time is not None or end_time is not None:
            assert start_time is not None and end_time is not None, "start_time and end_time must be given together"
            assert start_time > -1 and end_time > start_time
            assert self.framerate > -1
cap = cv2.VideoCapture(video_path) | |||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||
video_fps = cap.get(cv2.CAP_PROP_FPS) | |||
start_frame = int(start_time * video_fps) if start_time else 0 | |||
end_frame = int(end_time * video_fps) if end_time else frame_count - 1 | |||
interval = 1 | |||
if self.framerate > 0: | |||
interval = video_fps / self.framerate | |||
else: | |||
self.framerate = video_fps | |||
if interval == 0: | |||
interval = 1 | |||
images = [] | |||
for i in range(frame_count): | |||
ret, frame = cap.read() | |||
if not ret: | |||
break | |||
            if start_frame <= i <= end_frame:
                # Keep a frame only once enough source frames have elapsed to honour the
                # requested output framerate (roughly one kept frame every `interval` source frames)
                if len(images) * interval < i - start_frame:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||
image = Image.fromarray(frame_rgb) | |||
image = self.transform(image) | |||
images.append(image) | |||
cap.release() | |||
return images |
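if __name__ == "__main__":
    # Minimal usage sketch; the video path below is a placeholder, not a file shipped with
    # this repository. framerate=1 keeps roughly one frame per second of video.
    extractor = RawVideoExtractor(framerate=1, size=224)
    frames = extractor.get_video_data("path/to/video.mp4")
    print(f"Extracted {len(frames)} frames")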
@@ -0,0 +1,65 @@ | |||
import os | |||
import json | |||
import torch | |||
from torchvision.io import read_image | |||
from torch.utils.data import Dataset | |||
from torchvision import transforms | |||
class VideoDescriptionDataloader(Dataset): | |||
def __init__(self, cfg, data_type): | |||
self.frames_dir = cfg.paths[data_type]['frames'] | |||
self.frame_features_dir = cfg.paths[data_type]['frame_features'] | |||
self.text_features_dir = cfg.paths[data_type]['text_features'] | |||
self.gt_scores_file_path = cfg.paths[data_type]['gt_scores'] | |||
self.transform = self._get_transforms() | |||
self.binary_percentage = cfg.binary_percentage | |||
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        elif cfg.normalization == 'raw':
            self.normalize_scores = lambda score: score
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")
with open(self.gt_scores_file_path, "r") as f: | |||
self.gt_scores = json.load(f) | |||
self.video_ids = list(self.gt_scores.keys()) | |||
def __len__(self): | |||
return len(self.video_ids) | |||
def __getitem__(self, idx): | |||
        video_id = self.video_ids[idx]
        video_frames_dir = os.path.join(self.frames_dir, video_id)
        # read_image returns uint8 values in [0, 255]; scale to [0, 1] before ImageNet normalization
        frames_tensor = torch.stack([self.transform(read_image(os.path.join(video_frames_dir, frame_file)).float() / 255.0)
                                     for frame_file in sorted(os.listdir(video_frames_dir))])
        text_features = torch.load(os.path.join(self.text_features_dir, f'{video_id}.pt'))
        frame_features = torch.load(os.path.join(self.frame_features_dir, f'{video_id}.pt'))
        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
        gt_score = self.normalize_scores(gt_score)
return frames_tensor, text_features, frame_features, gt_score | |||
@staticmethod | |||
def _get_transforms(): | |||
return transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406], | |||
[0.229, 0.224, 0.225])]) | |||
@staticmethod | |||
def _min_max_normalize_scores(score): | |||
min_val, max_val = score.min(), score.max() | |||
if min_val != max_val: | |||
return (score - min_val) / (max_val - min_val) | |||
return torch.full_like(score, 0.5) | |||
def _to_binary_labels(self, score): | |||
num_top_elements = max(int(len(score) * self.binary_percentage) , 1) | |||
sorted_indices = score.argsort(descending=True) | |||
binary_labels = torch.zeros_like(score) | |||
binary_labels[sorted_indices[:num_top_elements]] = 1 | |||
return binary_labels |
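if __name__ == "__main__":
    # Tiny sanity check of the min-max normalization helper with made-up scores
    # (no dataset files are needed or touched).
    dummy_scores = torch.tensor([0.2, 0.9, 0.4, 0.7])
    # Maps the smallest score to 0 and the largest to 1.
    print(VideoDescriptionDataloader._min_max_normalize_scores(dummy_scores))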
@@ -0,0 +1,497 @@ | |||
video6513 | |||
video6514 | |||
video6515 | |||
video6516 | |||
video6517 | |||
video6518 | |||
video6519 | |||
video6520 | |||
video6521 | |||
video6522 | |||
video6523 | |||
video6524 | |||
video6525 | |||
video6526 | |||
video6527 | |||
video6528 | |||
video6529 | |||
video6530 | |||
video6531 | |||
video6532 | |||
video6533 | |||
video6534 | |||
video6535 | |||
video6536 | |||
video6537 | |||
video6538 | |||
video6539 | |||
video6540 | |||
video6541 | |||
video6542 | |||
video6543 | |||
video6544 | |||
video6545 | |||
video6546 | |||
video6547 | |||
video6548 | |||
video6549 | |||
video6550 | |||
video6551 | |||
video6552 | |||
video6553 | |||
video6554 | |||
video6555 | |||
video6556 | |||
video6557 | |||
video6558 | |||
video6559 | |||
video6560 | |||
video6561 | |||
video6562 | |||
video6563 | |||
video6564 | |||
video6565 | |||
video6566 | |||
video6567 | |||
video6568 | |||
video6569 | |||
video6570 | |||
video6571 | |||
video6572 | |||
video6573 | |||
video6574 | |||
video6575 | |||
video6576 | |||
video6577 | |||
video6578 | |||
video6579 | |||
video6580 | |||
video6581 | |||
video6582 | |||
video6583 | |||
video6584 | |||
video6585 | |||
video6586 | |||
video6587 | |||
video6588 | |||
video6589 | |||
video6590 | |||
video6591 | |||
video6592 | |||
video6593 | |||
video6594 | |||
video6595 | |||
video6596 | |||
video6597 | |||
video6598 | |||
video6599 | |||
video6600 | |||
video6601 | |||
video6602 | |||
video6603 | |||
video6604 | |||
video6605 | |||
video6606 | |||
video6607 | |||
video6608 | |||
video6609 | |||
video6610 | |||
video6611 | |||
video6612 | |||
video6613 | |||
video6614 | |||
video6615 | |||
video6616 | |||
video6617 | |||
video6618 | |||
video6619 | |||
video6620 | |||
video6621 | |||
video6622 | |||
video6623 | |||
video6624 | |||
video6625 | |||
video6626 | |||
video6627 | |||
video6628 | |||
video6629 | |||
video6630 | |||
video6631 | |||
video6632 | |||
video6633 | |||
video6634 | |||
video6635 | |||
video6636 | |||
video6637 | |||
video6638 | |||
video6639 | |||
video6640 | |||
video6641 | |||
video6642 | |||
video6643 | |||
video6644 | |||
video6645 | |||
video6646 | |||
video6647 | |||
video6648 | |||
video6649 | |||
video6650 | |||
video6651 | |||
video6652 | |||
video6653 | |||
video6654 | |||
video6655 | |||
video6656 | |||
video6657 | |||
video6658 | |||
video6659 | |||
video6660 | |||
video6661 | |||
video6662 | |||
video6663 | |||
video6664 | |||
video6665 | |||
video6666 | |||
video6667 | |||
video6668 | |||
video6669 | |||
video6670 | |||
video6671 | |||
video6672 | |||
video6673 | |||
video6674 | |||
video6675 | |||
video6676 | |||
video6677 | |||
video6678 | |||
video6679 | |||
video6680 | |||
video6681 | |||
video6682 | |||
video6683 | |||
video6684 | |||
video6685 | |||
video6686 | |||
video6687 | |||
video6688 | |||
video6689 | |||
video6690 | |||
video6691 | |||
video6692 | |||
video6693 | |||
video6694 | |||
video6695 | |||
video6696 | |||
video6697 | |||
video6698 | |||
video6699 | |||
video6700 | |||
video6701 | |||
video6702 | |||
video6703 | |||
video6704 | |||
video6705 | |||
video6706 | |||
video6707 | |||
video6708 | |||
video6709 | |||
video6710 | |||
video6711 | |||
video6712 | |||
video6713 | |||
video6714 | |||
video6715 | |||
video6716 | |||
video6717 | |||
video6718 | |||
video6719 | |||
video6720 | |||
video6721 | |||
video6722 | |||
video6723 | |||
video6724 | |||
video6725 | |||
video6726 | |||
video6727 | |||
video6728 | |||
video6729 | |||
video6730 | |||
video6731 | |||
video6732 | |||
video6733 | |||
video6734 | |||
video6735 | |||
video6736 | |||
video6737 | |||
video6738 | |||
video6739 | |||
video6740 | |||
video6741 | |||
video6742 | |||
video6743 | |||
video6744 | |||
video6745 | |||
video6746 | |||
video6747 | |||
video6748 | |||
video6749 | |||
video6750 | |||
video6751 | |||
video6752 | |||
video6753 | |||
video6754 | |||
video6755 | |||
video6756 | |||
video6757 | |||
video6758 | |||
video6759 | |||
video6760 | |||
video6761 | |||
video6762 | |||
video6763 | |||
video6764 | |||
video6765 | |||
video6766 | |||
video6767 | |||
video6768 | |||
video6769 | |||
video6770 | |||
video6771 | |||
video6772 | |||
video6773 | |||
video6774 | |||
video6775 | |||
video6776 | |||
video6777 | |||
video6778 | |||
video6779 | |||
video6780 | |||
video6781 | |||
video6782 | |||
video6783 | |||
video6784 | |||
video6785 | |||
video6786 | |||
video6787 | |||
video6788 | |||
video6789 | |||
video6790 | |||
video6791 | |||
video6792 | |||
video6793 | |||
video6794 | |||
video6795 | |||
video6796 | |||
video6797 | |||
video6798 | |||
video6799 | |||
video6800 | |||
video6801 | |||
video6802 | |||
video6803 | |||
video6804 | |||
video6805 | |||
video6806 | |||
video6807 | |||
video6808 | |||
video6809 | |||
video6810 | |||
video6811 | |||
video6812 | |||
video6813 | |||
video6814 | |||
video6815 | |||
video6816 | |||
video6817 | |||
video6818 | |||
video6819 | |||
video6820 | |||
video6821 | |||
video6822 | |||
video6823 | |||
video6824 | |||
video6825 | |||
video6826 | |||
video6827 | |||
video6828 | |||
video6829 | |||
video6830 | |||
video6831 | |||
video6832 | |||
video6833 | |||
video6834 | |||
video6835 | |||
video6836 | |||
video6837 | |||
video6838 | |||
video6839 | |||
video6840 | |||
video6841 | |||
video6842 | |||
video6843 | |||
video6844 | |||
video6845 | |||
video6846 | |||
video6847 | |||
video6848 | |||
video6849 | |||
video6850 | |||
video6851 | |||
video6852 | |||
video6853 | |||
video6854 | |||
video6855 | |||
video6856 | |||
video6857 | |||
video6858 | |||
video6859 | |||
video6860 | |||
video6861 | |||
video6862 | |||
video6863 | |||
video6864 | |||
video6865 | |||
video6866 | |||
video6867 | |||
video6868 | |||
video6869 | |||
video6870 | |||
video6871 | |||
video6872 | |||
video6873 | |||
video6874 | |||
video6875 | |||
video6876 | |||
video6877 | |||
video6878 | |||
video6879 | |||
video6880 | |||
video6881 | |||
video6882 | |||
video6883 | |||
video6884 | |||
video6885 | |||
video6886 | |||
video6887 | |||
video6888 | |||
video6889 | |||
video6890 | |||
video6891 | |||
video6892 | |||
video6893 | |||
video6894 | |||
video6895 | |||
video6896 | |||
video6897 | |||
video6898 | |||
video6899 | |||
video6900 | |||
video6901 | |||
video6902 | |||
video6903 | |||
video6904 | |||
video6905 | |||
video6906 | |||
video6907 | |||
video6908 | |||
video6909 | |||
video6910 | |||
video6911 | |||
video6912 | |||
video6913 | |||
video6914 | |||
video6915 | |||
video6916 | |||
video6917 | |||
video6918 | |||
video6919 | |||
video6920 | |||
video6921 | |||
video6922 | |||
video6923 | |||
video6924 | |||
video6925 | |||
video6926 | |||
video6927 | |||
video6928 | |||
video6929 | |||
video6930 | |||
video6931 | |||
video6932 | |||
video6933 | |||
video6934 | |||
video6935 | |||
video6936 | |||
video6937 | |||
video6938 | |||
video6939 | |||
video6940 | |||
video6941 | |||
video6942 | |||
video6943 | |||
video6944 | |||
video6945 | |||
video6946 | |||
video6947 | |||
video6948 | |||
video6949 | |||
video6950 | |||
video6951 | |||
video6952 | |||
video6953 | |||
video6954 | |||
video6955 | |||
video6956 | |||
video6957 | |||
video6958 | |||
video6959 | |||
video6960 | |||
video6961 | |||
video6962 | |||
video6963 | |||
video6964 | |||
video6965 | |||
video6966 | |||
video6967 | |||
video6968 | |||
video6969 | |||
video6970 | |||
video6971 | |||
video6972 | |||
video6973 | |||
video6974 | |||
video6975 | |||
video6976 | |||
video6977 | |||
video6978 | |||
video6979 | |||
video6980 | |||
video6981 | |||
video6982 | |||
video6983 | |||
video6984 | |||
video6985 | |||
video6986 | |||
video6987 | |||
video6988 | |||
video6989 | |||
video6990 | |||
video6991 | |||
video6992 | |||
video6993 | |||
video6994 | |||
video6995 | |||
video6996 | |||
video6997 | |||
video6998 | |||
video6999 | |||
video7000 | |||
video7001 | |||
video7002 | |||
video7003 | |||
video7004 | |||
video7005 | |||
video7006 | |||
video7007 | |||
video7008 | |||
video7009 |
@@ -0,0 +1,63 @@ | |||
import os | |||
import torch | |||
from config import Config | |||
from src.model import SaliencyNet | |||
from torch.utils.data import DataLoader | |||
from src.train import SalientFrameSamplerTrainer | |||
from dataloaders.t2v_dataloader import VideoDescriptionDataloader | |||
def load_weights(path, which_epoch): | |||
weights, last_epoch = None, None | |||
available_models = [name for name in os.listdir(path) if name.endswith(".pt")] | |||
if available_models: | |||
last_epoch = max([int(name[6:-3]) for name in available_models]) | |||
last_epoch = min(last_epoch, which_epoch) | |||
weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt')) | |||
return weights, last_epoch | |||
def set_seeds(seed: int): | |||
os.environ['PYTHONHASHSEED'] = str(seed) | |||
torch.manual_seed(seed) | |||
torch.cuda.manual_seed(seed) | |||
torch.backends.cudnn.benchmark = False | |||
torch.backends.cudnn.deterministic = True | |||
def main(cfg): | |||
set_seeds(cfg.seed) | |||
model = SaliencyNet(cfg) | |||
model.to(cfg.device) | |||
if cfg.load_complete_model: | |||
weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch) | |||
if weights: | |||
model.load_state_dict(weights) | |||
            print(f'Complete {cfg.backbone} model trained with {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
elif cfg.load_backbone: | |||
weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch) | |||
if weights: | |||
model.pretrained.load_state_dict(weights) | |||
print(f'{cfg.backbone} backbone trained on {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}') | |||
else: | |||
            print(f'No fine-tuned checkpoint found; the {cfg.backbone} backbone keeps its raw pretrained weights.')
train_dataset = VideoDescriptionDataloader(cfg, data_type='train') | |||
val_dataset = VideoDescriptionDataloader(cfg, data_type='valid') | |||
train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers) | |||
val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers) | |||
trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg) | |||
trainer.train() | |||
if __name__ == "__main__": | |||
torch.multiprocessing.set_start_method("spawn") | |||
cfg = Config().args | |||
main(cfg) |
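# Example invocation (assuming this file is saved as main.py; the argument values are just
# illustrative picks from the choices defined in config.py):
#   python main.py --dataset msrvtt --teacher_model CLIP --backbone efficientnet_b0 --batch_size 32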
@@ -0,0 +1,7 @@ | |||
torch | |||
opencv-python | |||
Pillow | |||
timm | |||
transformers | |||
torchvision | |||
scikit-learn
numpy
tqdm
tabulate
@@ -0,0 +1,50 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class GumbelSoftmaxTopK(nn.Module): | |||
def __init__(self, k, tau): | |||
super(GumbelSoftmaxTopK, self).__init__() | |||
self.k = k | |||
self.tau = tau | |||
def forward(self, logits): | |||
logits = torch.log(nn.Softmax(dim=0)(logits).clamp(min=1e-8)) | |||
""" | |||
Adapted from: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html#Subset-Sampler-Class | |||
and https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax | |||
""" | |||
m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format), | |||
torch.ones_like(logits, memory_format= torch.legacy_contiguous_format)) | |||
gumbels = m.sample() | |||
y = logits + gumbels | |||
_EPS = torch.tensor(1e-40).to(logits.device) | |||
khot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) | |||
onehot_approx = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) | |||
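        # Draw k relaxed "one-hot" samples without replacement: after each draw, khot_mask
        # suppresses the soft probability mass already assigned to the selected positions.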
for i in range(self.k): | |||
khot_mask = torch.maximum(1.0 - onehot_approx, _EPS) | |||
y += khot_mask.log() | |||
onehot_approx = nn.Softmax(dim=0)(y / self.tau) | |||
khot = torch.add(khot, onehot_approx) | |||
return khot | |||
class GumbelSoftmax(nn.Module): | |||
def __init__(self, tau, hard=False): | |||
super(GumbelSoftmax, self).__init__() | |||
self.tau = tau | |||
self.hard = hard | |||
def forward(self, logits): | |||
m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format), | |||
torch.ones_like(logits, memory_format= torch.legacy_contiguous_format)) | |||
gumbels = m.sample() | |||
y = logits + gumbels | |||
y_soft = nn.Softmax(dim=0)(y / self.tau) | |||
return y_soft |
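if __name__ == "__main__":
    # Illustrative sketch with made-up logits for 6 frames: the relaxed top-k layer returns a
    # differentiable, approximately 3-hot weight vector whose entries sum to ~k.
    torch.manual_seed(0)
    frame_logits = torch.randn(6, 1)
    topk = GumbelSoftmaxTopK(k=3, tau=1.0)
    weights = topk(frame_logits)
    print(weights.flatten(), weights.sum())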
@@ -0,0 +1,131 @@ | |||
import timm | |||
import torch | |||
import torch.nn as nn | |||
from src.gumbel_softmax import GumbelSoftmax, GumbelSoftmaxTopK | |||
class SaliencyNet(nn.Module): | |||
def __init__(self, cfg): | |||
super(SaliencyNet, self).__init__() | |||
self.k = 8 | |||
self.tau = cfg.init_tau | |||
self.min_tau = cfg.min_tau | |||
self.decay_factor = cfg.decay_factor | |||
self.raw_backbone_path = cfg.raw_backbone_path | |||
self.configure_pretrained(cfg.backbone, cfg.n_trainable_backbone_layers) | |||
embed_dim = self.pretrained.num_features | |||
self.fc1 = nn.Linear(embed_dim, cfg.fc1_size) | |||
self.fc2 = nn.Linear(cfg.fc1_size, cfg.fc2_size) | |||
self.fc3 = nn.Linear(cfg.fc2_size, cfg.fc3_size) | |||
self.fc4 = nn.Linear(cfg.fc3_size, 1) | |||
# self.bn1 = nn.BatchNorm1d(fc1_size) # Batch normalization layer | |||
self.dropout = nn.Dropout(cfg.dropout_ratio) | |||
self.sigmoid = nn.Sigmoid() | |||
if cfg.activation == 'relu': | |||
self.activation = nn.ReLU() | |||
elif cfg.activation == 'leakyrelu': | |||
self.activation = nn.LeakyReLU() | |||
elif cfg.activation == 'gelu': | |||
self.activation = nn.GELU() | |||
else: | |||
raise ValueError("Invalid activation type. Choose 'relu', 'leakyrelu', or 'gelu'.") | |||
self.normalization_layer = self.init_normalization_layer(cfg.normalization_layer, self.tau, self.k) | |||
self.initialize_weights() | |||
def init_normalization_layer(self, normalization_type, tau, k): | |||
if normalization_type == 'raw': | |||
normalization_layer = lambda x: x | |||
elif normalization_type == 'sigmoid': | |||
normalization_layer = nn.Sigmoid() | |||
        elif normalization_type == 'softmax':
            # The saliency head emits one logit per frame (shape: n_frames x 1), so the softmax
            # must run over the frame dimension, mirroring the Gumbel variants below
            normalization_layer = nn.Softmax(dim=0)
        elif normalization_type == 'gumbel':
            normalization_layer = GumbelSoftmax(tau)
        elif normalization_type == 'k_top':
            normalization_layer = GumbelSoftmaxTopK(k, tau)
        else:
            raise ValueError(f"Invalid normalization layer type: {normalization_type}")
        return normalization_layer
def initialize_weights(self): | |||
for module in self.modules(): | |||
if isinstance(module, nn.Linear): | |||
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') | |||
if module.bias is not None: | |||
nn.init.constant_(module.bias, 0) | |||
def forward(self, preprocessed_images): | |||
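        # preprocessed_images: (n_frames, 3, H, W) frames of a single video. Returns raw per-frame
        # saliency scores, normalized frame weights, and the backbone features used for distillation.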
features = self.pretrained(preprocessed_images) | |||
x = self.activation(self.fc1(features)) | |||
x = self.dropout(x) | |||
x = self.activation(self.fc2(x)) | |||
x = self.dropout(x) | |||
x = self.activation(self.fc3(x)) | |||
x = self.fc4(x) | |||
frame_scores = torch.flatten(x) | |||
frame_weights = torch.flatten(self.normalization_layer(x)) | |||
return frame_scores, frame_weights, features | |||
    def anneal_tau(self):
        """Anneals the temperature parameter τ using exponential decay, clamped at min_tau."""
        self.tau = float(self.tau * torch.exp(torch.tensor(-self.decay_factor)))
        self.tau = max(self.tau, self.min_tau)  # To prevent τ from becoming too small
        # The Gumbel normalization layers keep their own copy of tau, so propagate the new value
        if hasattr(self.normalization_layer, 'tau'):
            self.normalization_layer.tau = self.tau
def configure_pretrained(self, backbone, n_trainable_backbone_layers=1): | |||
# backbones = ['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'] | |||
self.pretrained = timm.create_model(backbone, pretrained=False, checkpoint_path=self.raw_backbone_path) | |||
# Freeze the early layers of the model | |||
for param in self.pretrained.parameters(): | |||
param.requires_grad = False | |||
# disable last layer of resnet18 | |||
if backbone =='resnet18': | |||
self.pretrained.fc = nn.Identity() | |||
sequential_modules = [child for child in self.pretrained.children() if isinstance(child, nn.Sequential)] | |||
elif backbone in ['efficientnet_b0', 'mobilenetv2_100','mobilenetv3_large_100'] : | |||
self.pretrained.classifier = nn.Identity() | |||
sequential_modules = [child for child in self.pretrained.blocks.children()] | |||
        else:
            raise ValueError(f'Backbone model {backbone} is not supported!')
num_seq_blocks = len(sequential_modules) | |||
if n_trainable_backbone_layers ==0: | |||
pass | |||
elif n_trainable_backbone_layers >= num_seq_blocks: | |||
for param in self.pretrained.parameters(): | |||
param.requires_grad = True | |||
else: | |||
# Select the last `n_trainable_backbone_layers` of sequential modules to be trainable | |||
layers_to_train = sequential_modules[-n_trainable_backbone_layers:] | |||
for layer in layers_to_train: | |||
for param in layer.parameters(): | |||
param.requires_grad = True | |||
if n_trainable_backbone_layers > 0: | |||
last_layers= [] | |||
if hasattr(self.pretrained, 'conv_head'): | |||
last_layers.append(self.pretrained.conv_head.weight) | |||
if hasattr(self.pretrained, 'bn2'): | |||
last_layers.append(self.pretrained.bn2.weight) | |||
last_layers.append(self.pretrained.bn2.bias) | |||
for param in last_layers: | |||
param.requires_grad = True |
@@ -0,0 +1,378 @@ | |||
import os | |||
import torch | |||
from tqdm import tqdm, trange | |||
import torch.nn as nn | |||
import logging | |||
import torch.nn.functional as F | |||
from tabulate import tabulate | |||
from torch.utils.data import DataLoader | |||
from utils.cluster_frames import VideoClusterer | |||
from utils.metrics import compute_metrics | |||
class SalientFrameSamplerTrainer: | |||
def __init__(self, model, train_dataloader: DataLoader, val_dataloader: DataLoader, cfg): | |||
self.cfg = cfg | |||
self.model = model | |||
self.train_dataloader = train_dataloader | |||
self.val_dataloader = val_dataloader | |||
self.init_logging() | |||
self.criterion1 = kl_divergence_loss() | |||
self.criterion2 = nn.BCEWithLogitsLoss() | |||
self.criterion3 = CrossEn() | |||
self.optimizer = self.init_optimizer() | |||
if self.cfg.use_scheduler: | |||
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.cfg.scheduler_step_size, gamma=self.cfg.scheduler_gamma) | |||
def train(self): | |||
for epoch in trange(self.cfg.epochs, desc="Epoch", ncols=100): | |||
log_data = {} | |||
if self.cfg.do_train: | |||
train_loss = self.run_training_epoch() | |||
log_data.update({ | |||
"train_loss(f2f)": train_loss['f2f'], | |||
"train_loss(f2t)": train_loss['f2t'], | |||
"train_loss(v2t)": train_loss['v2t'], | |||
}) | |||
if self.cfg.do_eval: | |||
valid_loss, t2v_metrics, v2t_metrics = self.run_validation_epoch() | |||
log_data.update({ | |||
"valid_loss(f2f)": valid_loss['f2f'], | |||
"valid_loss(f2t)": valid_loss['f2t'], | |||
"valid_loss(v2t)": valid_loss['v2t'], | |||
}) | |||
for key in t2v_metrics.keys(): | |||
for r in [1, 5, 10]: | |||
log_data[f"[t2v] R@{r} (k={key})"] = t2v_metrics[key][f'R@{r}'] | |||
log_data[f"[v2t] R@{r} (k={key})"] = v2t_metrics[key][f'R@{r}'] | |||
self.log_metrics(epoch, log_data) | |||
self.model.anneal_tau() | |||
if hasattr(self,'scheduler'): | |||
self.scheduler.step() | |||
if self.cfg.store_backbone: | |||
torch.save(self.model.pretrained.state_dict(), os.path.join(self.cfg.trained_backbone_path, f"epoch_{epoch+1}.pt")) | |||
if self.cfg.store_complete_model: | |||
                torch.save(self.model.state_dict(), os.path.join(self.cfg.trained_model_path, f"epoch_{epoch+1}.pt"))
def run_training_epoch(self): | |||
self.model.train() | |||
epoch_train_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0} | |||
video_embeddings = [] | |||
text_embeddings = [] | |||
f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||
for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.train_dataloader, desc='Train', leave=False, ncols=100)): | |||
torch.cuda.empty_cache() | |||
frames = frames[0].to(self.cfg.device) | |||
teacher_frame_features = teacher_frame_features[0].to(self.cfg.device) | |||
teacher_text_features = teacher_text_features[0].to(self.cfg.device) | |||
gt_score = gt_score[0].to(self.cfg.device) | |||
frame_scores, frame_weights, student_frame_features = self.model(frames) | |||
# Frame-to-Frame loss | |||
f2f_loss = self.criterion1(student_frame_features, teacher_frame_features) | |||
# Frame-to-Text loss | |||
f2t_loss = self.criterion2(frame_scores, gt_score) | |||
            # Accumulate per-sample losses as tensors (not .item()) so their gradients survive
            # until the accumulated backward pass below
            f2f_loss_accum = f2f_loss_accum + f2f_loss / self.cfg.batch_size
            f2t_loss_accum = f2t_loss_accum + f2t_loss / self.cfg.batch_size
# Calculating the in-pipeline loss for v2t | |||
video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights) | |||
video_embeddings.append(video_features) | |||
text_embeddings.append(teacher_text_features) | |||
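            # The dataloader yields one video per step, so cfg.batch_size acts as an accumulation
            # window: embeddings and losses are buffered until batch_size videos have been seen
            # (or the epoch ends), and only then is a single optimizer step taken.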
if (i + 1) % self.cfg.batch_size == 0 or i == len(self.train_dataloader) - 1: | |||
sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings) | |||
sim_loss1 = self.criterion3(sim_matrix) | |||
sim_loss2 = self.criterion3(sim_matrix.T) | |||
v2t_loss = (sim_loss1 + sim_loss2) / 2 | |||
                # Aggregate the total loss with the configured loss weights
                total_loss = self.cfg.alpha * f2f_loss_accum + self.cfg.beta * f2t_loss_accum + self.cfg.gamma * v2t_loss
self.optimizer.zero_grad() | |||
total_loss.backward() | |||
# Gradient clipping | |||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) | |||
self.optimizer.step() | |||
# Update epoch losses | |||
                epoch_train_loss['f2f'] += float(f2f_loss_accum)
                epoch_train_loss['f2t'] += float(f2t_loss_accum)
epoch_train_loss['v2t'] += v2t_loss.item() | |||
# Reset accumulators | |||
f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||
video_embeddings = [] | |||
text_embeddings = [] | |||
# Normalize epoch losses by the number of batches | |||
for key in epoch_train_loss.keys(): | |||
epoch_train_loss[key] /= len(self.train_dataloader) | |||
return epoch_train_loss | |||
def run_validation_epoch(self): | |||
self.model.eval() | |||
epoch_valid_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0} | |||
video_embeddings = [] | |||
text_embeddings = [] | |||
all_video_embeddings = {str(k): {'without_clustering': [], | |||
'uniform': [], 'spectral': [], 'agglomerative': [], 'kmeans': []} for k in [1,2,4,8,16]} | |||
all_text_embeddings = [] | |||
all_st_scores = [] | |||
f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||
for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.val_dataloader, desc='Valid', leave=False, ncols=100)): | |||
torch.cuda.empty_cache() | |||
frames = frames[0].to(self.cfg.device) | |||
teacher_frame_features = teacher_frame_features[0].to(self.cfg.device) | |||
teacher_text_features = teacher_text_features[0].to(self.cfg.device) | |||
gt_score = gt_score[0].to(self.cfg.device) | |||
with torch.no_grad(): | |||
frame_scores, frame_weights, student_frame_features = self.model(frames) | |||
f2f_loss = self.criterion1(student_frame_features, teacher_frame_features) | |||
f2t_loss = self.criterion2(frame_scores, gt_score) | |||
f2f_loss_accum += f2f_loss.item() | |||
f2t_loss_accum += f2t_loss.item() | |||
# Calculating the all embeddings, used for calculating t2v_retrieval metrics | |||
all_text_embeddings.append(teacher_text_features) | |||
n_frames = teacher_frame_features.size(0) | |||
sorted_indices = sorted(range(n_frames), key=lambda i: frame_scores[i], reverse=True).copy() | |||
for k in [1,2,4,8,16]: | |||
n_selected_frames = min(n_frames, k) | |||
indices = sorted_indices[:n_selected_frames] | |||
all_video_embeddings[str(k)]['without_clustering'].append(teacher_frame_features[indices].mean(dim=0)) | |||
if self.cfg.do_cluster and k>2: | |||
for clustering_method in ['uniform', 'spectral', 'agglomerative', 'kmeans']: | |||
if len(frame_scores) <= k: #mean of all frames | |||
all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features.mean(dim=0)) | |||
else: | |||
clusterer = VideoClusterer(clustering_method, k) | |||
clusters = clusterer.get_clusters(student_frame_features) | |||
selected_indices = self.select_representative_frame_per_cluster(clusters, frame_scores) | |||
all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features[selected_indices].mean(dim=0)) | |||
# Calculating the in-pipeline loss | |||
video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights) | |||
video_embeddings.append(video_features) | |||
text_embeddings.append(teacher_text_features) | |||
if (i + 1) % self.cfg.batch_size == 0 or i == len(self.val_dataloader) - 1: | |||
sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings) | |||
sim_loss1 = self.criterion3(sim_matrix) | |||
sim_loss2 = self.criterion3(sim_matrix.T) | |||
                v2t_loss = (sim_loss1 + sim_loss2) / 2
# Update epoch losses | |||
epoch_valid_loss['f2f'] += f2f_loss_accum | |||
epoch_valid_loss['f2t'] += f2t_loss_accum | |||
epoch_valid_loss['v2t'] += v2t_loss.item() | |||
# Reset accumulators | |||
f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||
video_embeddings = [] | |||
text_embeddings = [] | |||
# Normalize epoch losses by the number of batches | |||
for key in epoch_valid_loss.keys(): | |||
epoch_valid_loss[key] /= len(self.val_dataloader) | |||
# Compute t2v metrics | |||
t2v_metrics = {} | |||
v2t_metrics = {} | |||
for k, embeddings_dict in all_video_embeddings.items(): | |||
t2v_metrics[str(k)]={} | |||
for scenario, embeddings_list in embeddings_dict.items(): | |||
if embeddings_list: # Check if there are any embeddings for this scenario | |||
sim_matrix = self.compute_sim_matrix(embeddings_list, all_text_embeddings) | |||
_key = f"{k}" if scenario=='without_clustering' else f"{k}-{scenario}" | |||
t2v_metrics[_key] = compute_metrics(sim_matrix.detach().cpu().numpy()) | |||
v2t_metrics[_key] = compute_metrics(sim_matrix.T.detach().cpu().numpy()) | |||
return epoch_valid_loss, t2v_metrics, v2t_metrics | |||
def compute_sim_matrix(self, video_embeddings, text_embeddings): | |||
video_embeddings_tensor = torch.stack(video_embeddings, dim=0) | |||
text_embeddings_tensor = torch.stack(text_embeddings, dim=0) | |||
video_embeddings_norm = video_embeddings_tensor / video_embeddings_tensor.norm(dim=-1, keepdim=True) | |||
text_embeddings_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True) | |||
sim_matrix = torch.matmul(video_embeddings_norm, text_embeddings_norm.T) | |||
return sim_matrix | |||
def init_optimizer(self): | |||
if not hasattr(self.cfg, 'backbone_lr_scale'): | |||
return torch.optim.Adam(params=self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay) | |||
backbone_params = [] | |||
last_params = [] | |||
for name, param in self.model.named_parameters(): | |||
if 'pretrain' in name: | |||
backbone_params.append(param) | |||
else: | |||
last_params.append(param) | |||
return torch.optim.Adam([ | |||
{'params': backbone_params, 'lr': self.cfg.lr}, | |||
{'params': last_params, 'lr': self.cfg.lr * self.cfg.backbone_lr_scale} | |||
], weight_decay=self.cfg.weight_decay) | |||
def init_logging(self): | |||
available_logs = os.listdir(self.cfg.log_dir) | |||
last_run = 0 | |||
if available_logs: | |||
last_run = max([int(name.split('-')[0]) for name in available_logs]) | |||
log_file_name = f"{last_run + 1}-local-run.log" | |||
log_file = os.path.join(self.cfg.log_dir, log_file_name) | |||
logging.basicConfig(filename=log_file, level=logging.INFO, | |||
format='%(asctime)s [%(levelname)s] - %(message)s', | |||
datefmt='%Y-%m-%d %H:%M:%S') | |||
hyperparameters = {key: value for key, value in vars(self.cfg).items()} | |||
hyperparameters_str = "\n".join([f"{key}: {value}" for key, value in hyperparameters.items() if ('path' not in key and 'log_dir' not in key)]) | |||
logging.info("Hyperparameters:\n" + hyperparameters_str) | |||
def log_metrics(self, epoch, log_data): | |||
# Loss Metrics Table | |||
loss_header = ["", "f2f", "f2t", "v2t"] | |||
# Check if train_loss exists, otherwise assign "N/A" | |||
train_loss_f2f = log_data.get("train_loss(f2f)", "N/A") | |||
train_loss_f2t = log_data.get("train_loss(f2t)", "N/A") | |||
train_loss_v2t = log_data.get("train_loss(v2t)", "N/A") | |||
# Check if valid_loss exists, otherwise assign "N/A" | |||
valid_loss_f2f = log_data.get("valid_loss(f2f)", "N/A") | |||
valid_loss_f2t = log_data.get("valid_loss(f2t)", "N/A") | |||
valid_loss_v2t = log_data.get("valid_loss(v2t)", "N/A") | |||
loss_rows = [ | |||
["train", train_loss_f2f, train_loss_f2t, train_loss_v2t], | |||
["val", valid_loss_f2f, valid_loss_f2t, valid_loss_v2t], | |||
] | |||
loss_table = tabulate(loss_rows, headers=loss_header, tablefmt="grid") | |||
# Recall Metrics Table | |||
recall_header = ["", "R@1", "R@5", "R@10"] | |||
clustering_names = ['uniform', 'spectral', 'agglomerative', 'kmeans'] | |||
recall_rows_t2v = [] | |||
recall_rows_v2t = [] | |||
for k in [1, 2, 4, 8, 16]: | |||
# Without clustering | |||
base_key = f"k={k}" | |||
if any(f"[t2v] R@{r} ({base_key})" in log_data for r in [1, 5, 10]): | |||
row = [f"{k}"] | |||
for r in [1, 5, 10]: | |||
metric_key = f"[t2v] R@{r} ({base_key})" | |||
row.append(log_data.get(metric_key, "N/A")) | |||
recall_rows_t2v.append(row) | |||
if any(f"[v2t] R@{r} ({base_key})" in log_data for r in [1, 5, 10]): | |||
row = [f"{k}"] | |||
for r in [1, 5, 10]: | |||
metric_key = f"[v2t] R@{r} ({base_key})" | |||
row.append(log_data.get(metric_key, "N/A")) | |||
recall_rows_v2t.append(row) | |||
# With clustering | |||
for clustering_name in clustering_names: | |||
cluster_key = f"{base_key}-{clustering_name}" | |||
if any(f"[t2v] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]): | |||
row = [f"{k}-{clustering_name}"] | |||
for r in [1, 5, 10]: | |||
metric_key = f"[t2v] R@{r} ({cluster_key})" | |||
row.append(log_data.get(metric_key, "N/A")) | |||
recall_rows_t2v.append(row) | |||
if any(f"[v2t] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]): | |||
row = [f"{k}-{clustering_name}"] | |||
for r in [1, 5, 10]: | |||
metric_key = f"[v2t] R@{r} ({cluster_key})" | |||
row.append(log_data.get(metric_key, "N/A")) | |||
recall_rows_v2t.append(row) | |||
recall_table_t2v = tabulate(recall_rows_t2v, headers=recall_header, tablefmt="grid") | |||
recall_table_v2t = tabulate(recall_rows_v2t, headers=recall_header, tablefmt="grid") | |||
logging.info(f"Epoch {epoch+1} - Loss and Recall Metrics:") | |||
logging.info("\n" + loss_table) | |||
logging.info("\nText-to-Video Recall Metrics:") | |||
logging.info("\n" + recall_table_t2v) | |||
logging.info("\nVideo-to-Text Recall Metrics:") | |||
logging.info("\n" + recall_table_v2t) | |||
def select_representative_frame_per_cluster(self, clusters, frame_scores): | |||
representative_frames = [] | |||
for cluster in clusters: | |||
best_frame_index = max(cluster, key=lambda index: frame_scores[index]) | |||
representative_frames.append(int(best_frame_index)) | |||
return representative_frames | |||
class kl_divergence_loss(nn.Module): | |||
def __init__(self,): | |||
super(kl_divergence_loss, self).__init__() | |||
def forward(self, features1, features2): | |||
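        # Distill the teacher's frame-to-frame similarity structure: build a row-wise softmax over
        # each feature set's cosine-similarity matrix and minimize KL(teacher || student).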
features1 = F.normalize(features1, p=2, dim=-1) | |||
features2 = F.normalize(features2, p=2, dim=-1) | |||
cos_sim_features1 = torch.mm(features1, features1.t()) | |||
cos_sim_features2 = torch.mm(features2, features2.t()) | |||
probs_features1 = F.softmax(cos_sim_features1, dim=-1) | |||
probs_features2 = F.softmax(cos_sim_features2, dim=-1) | |||
loss = F.kl_div(probs_features1.log(), probs_features2, reduction='batchmean') | |||
return loss | |||
class CrossEn(nn.Module): | |||
def __init__(self,): | |||
super(CrossEn, self).__init__() | |||
def forward(self, sim_matrix): | |||
logpt = F.log_softmax(sim_matrix, dim=-1) | |||
logpt = torch.diag(logpt) | |||
nce_loss = -logpt | |||
sim_loss = nce_loss.mean() | |||
return sim_loss |
@@ -0,0 +1,73 @@ | |||
from sklearn.cluster import SpectralClustering, AgglomerativeClustering, KMeans | |||
import numpy as np | |||
from sklearn.metrics.pairwise import cosine_similarity | |||
class VideoClusterer: | |||
def __init__(self, clustering_method='uniform', n_clusters=2, similarity_threshold=0.8): | |||
self.n_clusters = n_clusters | |||
self.similarity_threshold = similarity_threshold | |||
self.clustering_method = clustering_method | |||
# Decide on the clustering method to use | |||
if clustering_method == 'uniform': | |||
self.clusterer = self.uniform_clustering | |||
elif clustering_method == 'spectral': | |||
self.clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed') | |||
elif clustering_method == 'agglomerative': | |||
self.clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward') | |||
elif clustering_method == 'kmeans': | |||
self.clusterer = KMeans(n_clusters=n_clusters, n_init=1) | |||
else: | |||
raise ValueError(f"Invalid clustering method: {clustering_method}") | |||
def uniform_clustering(self, features): | |||
n = len(features) | |||
clusters = [] | |||
cluster_size = n // self.n_clusters | |||
remainder = n % self.n_clusters | |||
start = 0 | |||
for i in range(self.n_clusters): | |||
if i < remainder: | |||
end = start + cluster_size + 1 | |||
else: | |||
end = start + cluster_size | |||
clusters.append(list(range(start, end))) | |||
start = end | |||
return clusters | |||
def detect_outliers(self, features): | |||
dot_product_matrix = features.dot(features.T) | |||
average_similarities = np.mean(dot_product_matrix, axis=0) | |||
# Adding a small constant epsilon to the standard deviation to prevent division by zero | |||
epsilon = 1e-8 | |||
normal = (average_similarities - np.mean(average_similarities)) / (np.std(average_similarities) + epsilon) | |||
outlier_mask = np.logical_or(normal > 1.5, normal < -1.5) | |||
return outlier_mask | |||
def get_clusters(self, features): | |||
features = features.cpu().numpy() | |||
if self.clustering_method == 'uniform': | |||
return self.uniform_clustering(features) | |||
        else:
            # For non-uniform methods: optionally drop outlier frames, cluster, then map the
            # cluster members back to their original frame indices
            outlier_mask = self.detect_outliers(features)
            kept_indices = np.arange(len(features))
            if np.sum(~outlier_mask) > self.n_clusters:
                kept_indices = np.where(~outlier_mask)[0]
                features = features[~outlier_mask]
            # Compute cosine similarity matrix for spectral clustering
            if self.clustering_method == 'spectral':
                similarity_matrix = cosine_similarity(features)
                labels = self.clusterer.fit_predict(similarity_matrix)
            else:
                # Agglomerative, k-means and other methods cluster the raw features directly
                labels = self.clusterer.fit_predict(features)
            # Organize frames into clusters using original frame indices, so downstream code can
            # safely index the full frame/feature tensors
            clusters = [[] for _ in range(self.n_clusters)]
            for idx, label in enumerate(labels):
                clusters[label].append(int(kept_indices[idx]))
return clusters |
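if __name__ == "__main__":
    # Illustrative check with a made-up feature matrix: uniform clustering simply splits the
    # 10 frame indices into 3 contiguous, near-equal groups.
    dummy_features = np.random.rand(10, 4)
    clusterer = VideoClusterer('uniform', n_clusters=3)
    print(clusterer.uniform_clustering(dummy_features))  # -> [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]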
@@ -0,0 +1,69 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import unicode_literals | |||
from __future__ import print_function | |||
import numpy as np | |||
import torch | |||
def compute_metrics(x): | |||
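    # x is a (num_queries, num_candidates) similarity matrix with the ground-truth pairs on the
    # diagonal. Sorting each row in descending order and finding where the diagonal value lands
    # gives, for each query, the rank of its correct candidate.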
sx = np.sort(-x, axis=1) | |||
d = np.diag(-x) | |||
d = d[:, np.newaxis] | |||
ind = sx - d | |||
ind = np.where(ind == 0) | |||
ind = ind[1] | |||
metrics = {} | |||
metrics['R@1'] = float(np.sum(ind == 0)) * 100 / len(ind) | |||
metrics['R@5'] = float(np.sum(ind < 5)) * 100 / len(ind) | |||
metrics['R@10'] = float(np.sum(ind < 10)) * 100 / len(ind) | |||
metrics["MedianR"] = np.median(ind) + 1 | |||
metrics["MeanR"] = np.mean(ind) + 1 | |||
# metrics["cols"] = [int(i) for i in list(ind)] | |||
return metrics | |||
def print_computed_metrics(metrics): | |||
r1 = metrics['R@1'] | |||
r5 = metrics['R@5'] | |||
r10 = metrics['R@10'] | |||
mr = metrics['MR'] | |||
print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) | |||
# below two functions directly come from: https://github.com/Deferf/Experiments | |||
def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]): | |||
if not torch.is_tensor(sim_tensor): | |||
sim_tensor = torch.tensor(sim_tensor) | |||
# Permute sim_tensor so it represents a sequence of text-video similarity matrices. | |||
# Then obtain the double argsort to position the rank on the diagonal | |||
stacked_sim_matrices = sim_tensor.permute(1, 0, 2) | |||
first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) | |||
second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) | |||
# Extracts ranks i.e diagonals | |||
ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) | |||
# Now we need to extract valid ranks, as some belong to inf padding values | |||
permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) | |||
mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) | |||
valid_ranks = ranks[mask] | |||
    # A quick dimension check validates our results; there may be other correctness tests pending,
    # such as dot-product localization, but that is for another time.
#assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) | |||
if not torch.is_tensor(valid_ranks): | |||
valid_ranks = torch.tensor(valid_ranks) | |||
results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} | |||
results["MedianR"] = float(torch.median(valid_ranks + 1)) | |||
results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) | |||
results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) | |||
results['MR'] = results["MedianR"] | |||
return results | |||
def tensor_video_to_text_sim(sim_tensor): | |||
if not torch.is_tensor(sim_tensor): | |||
sim_tensor = torch.tensor(sim_tensor) | |||
# Code to avoid nans | |||
sim_tensor[sim_tensor != sim_tensor] = float('-inf') | |||
# Forms a similarity matrix for use with rank at k | |||
values, _ = torch.max(sim_tensor, dim=1, keepdim=True) | |||
return torch.squeeze(values).T |
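if __name__ == "__main__":
    # Quick sanity check on a made-up similarity matrix in which each query is most similar to
    # its own candidate, so all recall values should come out as 100.
    sim = np.eye(4) + 0.01 * np.random.rand(4, 4)
    print(compute_metrics(sim))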