import argparse
import os
import torch

media_path_root = "path/to/media"

class Config:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description='Train a student network.')
        self.add_arguments()
        self.args = self.parse()
        self.update_paths()

    def add_arguments(self):
        # General settings
        self.parser.add_argument('--seed', type=int, default=12345, help='Initialization seed')
        self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
        self.parser.add_argument('--do_train', type=int, default=1, help='Whether to run training.')
        self.parser.add_argument('--do_eval', type=int, default=1, help='Whether to run evaluation.')
        self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use')
        self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use')
        # `--project_name` is used by update_paths() for log_dir but was missing from the
        # original argument list; added here so the config is self-consistent.
        self.parser.add_argument('--project_name', type=str, default='default_project', help='Project name used for the log directory')
        self.parser.add_argument('--workers', type=int, default=12, help='Number of dataloader workers')
        # Model configurations
        self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model')
        self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer hidden size')
        self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer hidden size')
        self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer hidden size')
        self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model')
        self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers')
        self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function')
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for the text-to-video retrieval pipeline)')
        self.parser.add_argument('--init_tau', default=5.0, type=float, help='Initial annealing temperature')
        self.parser.add_argument('--min_tau', default=0.5, type=float, help='Minimum temperature to anneal to')
        self.parser.add_argument('--decay_factor', default=0.045, type=float, help='Exponential decay factor per epoch')
        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='Load a previously trained complete model (1) or not (0)')
        self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Epoch from which to load the trained model')
        self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just the backbone')
        self.parser.add_argument('--load_backbone', type=int, default=1, help='Load a previously trained backbone (1) or not (0)')
        self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Epoch from which to load the trained backbone')
        # Dataloader configurations
        self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for the dataloader')
        self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Fraction of frames labeled positive for binary normalization')
        # Training configurations
        self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
        self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training')
        self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse', 'kl'], help='Frame-to-frame loss')
        self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether to use weighted BCE to handle unbalanced classes')
        self.parser.add_argument('--alpha', type=float, default=1, help='Weight for the f2f loss')
        self.parser.add_argument('--beta', type=float, default=1, help='Weight for the f2t loss')
        self.parser.add_argument('--gamma', type=float, default=1, help='Weight for the v2t loss')
        self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
        self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler')
        self.parser.add_argument('--scheduler_step_size', type=int, default=5, help='After how many epochs the learning rate is adjusted')
        self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='The factor by which the learning rate is reduced at each step')
        self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set to 0 to disable it')
        self.parser.add_argument('--do_cluster', type=int, default=1, help='Whether to also compute results with clustering')
    def parse(self):
        return self.parser.parse_args()

    def update_paths(self):
        self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else 'cpu')
        self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth"
        self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json"
        self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}"
        self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/"
        self.args.log_dir = f'logs/{self.args.project_name}/'
        # st_scores_path points at a JSON file, so create its parent directory, not the path itself
        os.makedirs(os.path.dirname(self.args.st_scores_path), exist_ok=True)
        os.makedirs(self.args.trained_backbone_path, exist_ok=True)
        os.makedirs(self.args.trained_model_path, exist_ok=True)
        os.makedirs(self.args.log_dir, exist_ok=True)
        self.args.paths = {'train': {}, 'valid': {}}
        for data_split in ['train', 'valid']:
            split = 'val_1' if data_split == 'valid' else data_split
            self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}"
            self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}"
            self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}"
            self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json"
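
# A quick sketch (assumed, not part of the original file) of the paths that
# update_paths() derives for the default flags above:
#
#   cfg = Config().args
#   cfg.raw_backbone_path            # "path/to/media/saved_models/student_pretrained_backbones/resnet18_pretrained.pth"
#   cfg.paths['train']['frames']     # "path/to/media/activity-net/sampled_frames/train"
#   cfg.paths['valid']['gt_scores']  # "path/to/media/activity-net/CLIP/gt_scores/gt_val_1.json"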
import cv2
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode

class RawVideoExtractor():
    def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True):
        self.centercrop = centercrop
        self.framerate = framerate
        self.to_tensor = to_tensor
        self.transform = self._transform(size)

    def _transform(self, n_px):
        if self.to_tensor:
            return Compose([
                Resize(n_px, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(n_px),
                ToTensor(),
                Normalize((0.48145466, 0.4578275, 0.40821073),
                          (0.26862954, 0.26130258, 0.27577711))])
        else:
            return Compose([Resize(n_px, interpolation=InterpolationMode.BICUBIC), CenterCrop(n_px)])
    def get_video_data(self, video_path, start_time=None, end_time=None):
        # If a time window is given, both endpoints must be provided and valid,
        # and a target framerate is required for the window to be meaningful.
        if start_time is not None or end_time is not None:
            assert start_time is not None and end_time is not None and start_time > -1 and end_time > start_time
            assert self.framerate > -1
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        start_frame = int(start_time * video_fps) if start_time else 0
        end_frame = int(end_time * video_fps) if end_time else frame_count - 1
        # Sampling interval in source frames per extracted frame
        interval = 1
        if self.framerate > 0:
            interval = video_fps / self.framerate
        else:
            self.framerate = video_fps
        if interval == 0:
            interval = 1
        images = []
        for i in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                break
            if start_frame <= i <= end_frame:
                # Take a frame whenever we have fallen behind the target sampling rate
                if len(images) * interval < i - start_frame:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image = Image.fromarray(frame_rgb)
                    image = self.transform(image)
                    images.append(image)
        cap.release()
        return images
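
# A minimal usage sketch (assumed, not part of the original file; the video path is a
# placeholder): extract frames from a clip at roughly 1 fps as normalized CHW tensors.
#
#   extractor = RawVideoExtractor(framerate=1, size=224)
#   frames = extractor.get_video_data("path/to/video.mp4")
#   print(len(frames), frames[0].shape if frames else None)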
import os
import json
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision import transforms

class VideoDescriptionDataloader(Dataset):
    def __init__(self, cfg, data_type):
        self.frames_dir = cfg.paths[data_type]['frames']
        self.frame_features_dir = cfg.paths[data_type]['frame_features']
        self.text_features_dir = cfg.paths[data_type]['text_features']
        self.gt_scores_file_path = cfg.paths[data_type]['gt_scores']
        self.transform = self._get_transforms()
        self.binary_percentage = cfg.binary_percentage
        # Validation always uses min-max normalization, regardless of the configured type
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        elif cfg.normalization == 'raw':
            # 'raw' is a valid config choice: pass the scores through unchanged
            self.normalize_scores = lambda score: score
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")
        with open(self.gt_scores_file_path, "r") as f:
            self.gt_scores = json.load(f)
        self.video_ids = list(self.gt_scores.keys())

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        video_id = self.video_ids[idx]
        video_frames_dir = os.path.join(self.frames_dir, video_id)
        # read_image returns uint8 in [0, 255]; scale to [0, 1] before the ImageNet normalization
        frames_tensor = torch.stack([self.transform(read_image(os.path.join(video_frames_dir, frame_file)).float() / 255.0)
                                     for frame_file in sorted(os.listdir(video_frames_dir))])
        text_features = torch.load(os.path.join(self.text_features_dir, f'{video_id}.pt'))
        frame_features = torch.load(os.path.join(self.frame_features_dir, f'{video_id}.pt'))
        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
        gt_score = self.normalize_scores(gt_score)
        return frames_tensor, text_features, frame_features, gt_score
    @staticmethod
    def _get_transforms():
        return transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])

    @staticmethod
    def _min_max_normalize_scores(score):
        min_val, max_val = score.min(), score.max()
        if min_val != max_val:
            return (score - min_val) / (max_val - min_val)
        return torch.full_like(score, 0.5)

    def _to_binary_labels(self, score):
        # Mark the top `binary_percentage` fraction of frames (at least one) as positive
        num_top_elements = max(int(len(score) * self.binary_percentage), 1)
        sorted_indices = score.argsort(descending=True)
        binary_labels = torch.zeros_like(score)
        binary_labels[sorted_indices[:num_top_elements]] = 1
        return binary_labels
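
# A minimal usage sketch (assumed; requires the directory layout that Config derives):
#
#   cfg = Config().args
#   dataset = VideoDescriptionDataloader(cfg, data_type='train')
#   frames, text_feats, frame_feats, gt_score = dataset[0]
#   # frames: (n_frames, 3, H, W), gt_score: (n_frames,)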
video6513
video6514
video6515
video6516
video6517
video6518
video6519
video6520
video6521
video6522
video6523
video6524
video6525
video6526
video6527
video6528
video6529
video6530
video6531
video6532
video6533
video6534
video6535
video6536
video6537
video6538
video6539
video6540
video6541
video6542
video6543
video6544
video6545
video6546
video6547
video6548
video6549
video6550
video6551
video6552
video6553
video6554
video6555
video6556
video6557
video6558
video6559
video6560
video6561
video6562
video6563
video6564
video6565
video6566
video6567
video6568
video6569
video6570
video6571
video6572
video6573
video6574
video6575
video6576
video6577
video6578
video6579
video6580
video6581
video6582
video6583
video6584
video6585
video6586
video6587
video6588
video6589
video6590
video6591
video6592
video6593
video6594
video6595
video6596
video6597
video6598
video6599
video6600
video6601
video6602
video6603
video6604
video6605
video6606
video6607
video6608
video6609
video6610
video6611
video6612
video6613
video6614
video6615
video6616
video6617
video6618
video6619
video6620
video6621
video6622
video6623
video6624
video6625
video6626
video6627
video6628
video6629
video6630
video6631
video6632
video6633
video6634
video6635
video6636
video6637
video6638
video6639
video6640
video6641
video6642
video6643
video6644
video6645
video6646
video6647
video6648
video6649
video6650
video6651
video6652
video6653
video6654
video6655
video6656
video6657
video6658
video6659
video6660
video6661
video6662
video6663
video6664
video6665
video6666
video6667
video6668
video6669
video6670
video6671
video6672
video6673
video6674
video6675
video6676
video6677
video6678
video6679
video6680
video6681
video6682
video6683
video6684
video6685
video6686
video6687
video6688
video6689
video6690
video6691
video6692
video6693
video6694
video6695
video6696
video6697
video6698
video6699
video6700
video6701
video6702
video6703
video6704
video6705
video6706
video6707
video6708
video6709
video6710
video6711
video6712
video6713
video6714
video6715
video6716
video6717
video6718
video6719
video6720
video6721
video6722
video6723
video6724
video6725
video6726
video6727
video6728
video6729
video6730
video6731
video6732
video6733
video6734
video6735
video6736
video6737
video6738
video6739
video6740
video6741
video6742
video6743
video6744
video6745
video6746
video6747
video6748
video6749
video6750
video6751
video6752
video6753
video6754
video6755
video6756
video6757
video6758
video6759
video6760
video6761
video6762
video6763
video6764
video6765
video6766
video6767
video6768
video6769
video6770
video6771
video6772
video6773
video6774
video6775
video6776
video6777
video6778
video6779
video6780
video6781
video6782
video6783
video6784
video6785
video6786
video6787
video6788
video6789
video6790
video6791
video6792
video6793
video6794
video6795
video6796
video6797
video6798
video6799
video6800
video6801
video6802
video6803
video6804
video6805
video6806
video6807
video6808
video6809
video6810
video6811
video6812
video6813
video6814
video6815
video6816
video6817
video6818
video6819
video6820
video6821
video6822
video6823
video6824
video6825
video6826
video6827
video6828
video6829
video6830
video6831
video6832
video6833
video6834
video6835
video6836
video6837
video6838
video6839
video6840
video6841
video6842
video6843
video6844
video6845
video6846
video6847
video6848
video6849
video6850
video6851
video6852
video6853
video6854
video6855
video6856
video6857
video6858
video6859
video6860
video6861
video6862
video6863
video6864
video6865
video6866
video6867
video6868
video6869
video6870
video6871
video6872
video6873
video6874
video6875
video6876
video6877
video6878
video6879
video6880
video6881
video6882
video6883
video6884
video6885
video6886
video6887
video6888
video6889
video6890
video6891
video6892
video6893
video6894
video6895
video6896
video6897
video6898
video6899
video6900
video6901
video6902
video6903
video6904
video6905
video6906
video6907
video6908
video6909
video6910
video6911
video6912
video6913
video6914
video6915
video6916
video6917
video6918
video6919
video6920
video6921
video6922
video6923
video6924
video6925
video6926
video6927
video6928
video6929
video6930
video6931
video6932
video6933
video6934
video6935
video6936
video6937
video6938
video6939
video6940
video6941
video6942
video6943
video6944
video6945
video6946
video6947
video6948
video6949
video6950
video6951
video6952
video6953
video6954
video6955
video6956
video6957
video6958
video6959
video6960
video6961
video6962
video6963
video6964
video6965
video6966
video6967
video6968
video6969
video6970
video6971
video6972
video6973
video6974
video6975
video6976
video6977
video6978
video6979
video6980
video6981
video6982
video6983
video6984
video6985
video6986
video6987
video6988
video6989
video6990
video6991
video6992
video6993
video6994
video6995
video6996
video6997
video6998
video6999
video7000
video7001
video7002
video7003
video7004
video7005
video7006
video7007
video7008
video7009
import os
import torch
from config import Config
from src.model import SaliencyNet
from torch.utils.data import DataLoader
from src.train import SalientFrameSamplerTrainer
from dataloaders.t2v_dataloader import VideoDescriptionDataloader

def load_weights(path, which_epoch):
    weights, last_epoch = None, None
    available_models = [name for name in os.listdir(path) if name.endswith(".pt")]
    if available_models:
        # Checkpoints are named 'epoch_<N>.pt'; pick the latest, capped at the requested epoch
        last_epoch = max([int(name[6:-3]) for name in available_models])
        last_epoch = min(last_epoch, which_epoch)
        weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt'), map_location='cpu')
    return weights, last_epoch

def set_seeds(seed: int):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def main(cfg):
    set_seeds(cfg.seed)
    model = SaliencyNet(cfg)
    model.to(cfg.device)
    if cfg.load_complete_model:
        weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch)
        if weights:
            model.load_state_dict(weights)
            print(f'Complete model ({cfg.backbone}, trained with the {cfg.similarity_matrix_loss} loss) loaded from epoch #{last_epoch}')
    elif cfg.load_backbone:
        weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch)
        if weights:
            model.pretrained.load_state_dict(weights)
            print(f'{cfg.backbone} backbone trained with the {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
    else:
        print(f'{cfg.backbone} backbone loaded from scratch.')
    train_dataset = VideoDescriptionDataloader(cfg, data_type='train')
    val_dataset = VideoDescriptionDataloader(cfg, data_type='valid')
    # The DataLoader batch size defaults to 1 (one whole video per item); the trainer
    # accumulates losses over cfg.batch_size videos before each optimizer step.
    train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers)
    val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers)
    trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg)
    trainer.train()

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    cfg = Config().args
    main(cfg)
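
# Example invocation (assumed flag values; adjust to your setup):
#   python main.py --dataset msrvtt --teacher_model CLIP --backbone resnet18 \
#       --epochs 20 --batch_size 16 --normalization_layer gumbel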
torch
opencv-python
Pillow
timm
transformers
torchvision
scikit-learn
numpy
tqdm
tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelSoftmaxTopK(nn.Module):
    """
    Differentiable top-k subset sampling via successive Gumbel-softmax rounds.
    Adapted from: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html#Subset-Sampler-Class
    and https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    """
    def __init__(self, k, tau):
        super(GumbelSoftmaxTopK, self).__init__()
        self.k = k
        self.tau = tau

    def forward(self, logits):
        # Convert raw scores to log-probabilities; clamp to avoid log(0)
        logits = torch.log(F.softmax(logits, dim=0).clamp(min=1e-8))
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format),
                                              torch.ones_like(logits, memory_format=torch.legacy_contiguous_format))
        gumbels = m.sample()
        y = logits + gumbels
        _EPS = torch.tensor(1e-40).to(logits.device)
        khot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        onehot_approx = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        for _ in range(self.k):
            # Mask out (in log-space) the mass already assigned in previous rounds
            khot_mask = torch.maximum(1.0 - onehot_approx, _EPS)
            y += khot_mask.log()
            onehot_approx = F.softmax(y / self.tau, dim=0)
            khot = torch.add(khot, onehot_approx)
        return khot

class GumbelSoftmax(nn.Module):
    def __init__(self, tau, hard=False):
        super(GumbelSoftmax, self).__init__()
        self.tau = tau
        self.hard = hard

    def forward(self, logits):
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format),
                                              torch.ones_like(logits, memory_format=torch.legacy_contiguous_format))
        gumbels = m.sample()
        y = logits + gumbels
        y_soft = F.softmax(y / self.tau, dim=0)
        return y_soft
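
# A quick sanity-check sketch (assumed, not part of the original file): the relaxed
# k-hot output has one weight per input score and sums to approximately k.
if __name__ == "__main__":
    torch.manual_seed(0)
    sampler = GumbelSoftmaxTopK(k=3, tau=0.5)
    weights = sampler(torch.randn(10))
    print(weights.shape, float(weights.sum()))  # torch.Size([10]) ~3.0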
import math
import timm
import torch
import torch.nn as nn
from src.gumbel_softmax import GumbelSoftmax, GumbelSoftmaxTopK

class SaliencyNet(nn.Module):
    def __init__(self, cfg):
        super(SaliencyNet, self).__init__()
        self.k = 8
        self.tau = cfg.init_tau
        self.min_tau = cfg.min_tau
        self.decay_factor = cfg.decay_factor
        self.raw_backbone_path = cfg.raw_backbone_path
        self.configure_pretrained(cfg.backbone, cfg.n_trainable_backbone_layers)
        embed_dim = self.pretrained.num_features
        self.fc1 = nn.Linear(embed_dim, cfg.fc1_size)
        self.fc2 = nn.Linear(cfg.fc1_size, cfg.fc2_size)
        self.fc3 = nn.Linear(cfg.fc2_size, cfg.fc3_size)
        self.fc4 = nn.Linear(cfg.fc3_size, 1)
        self.dropout = nn.Dropout(cfg.dropout_ratio)
        self.sigmoid = nn.Sigmoid()
        if cfg.activation == 'relu':
            self.activation = nn.ReLU()
        elif cfg.activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()
        elif cfg.activation == 'gelu':
            self.activation = nn.GELU()
        else:
            raise ValueError("Invalid activation type. Choose 'relu', 'leakyrelu', or 'gelu'.")
        self.normalization_layer = self.init_normalization_layer(cfg.normalization_layer, self.tau, self.k)
        self.initialize_weights()

    def init_normalization_layer(self, normalization_type, tau, k):
        if normalization_type == 'raw':
            normalization_layer = lambda x: x
        elif normalization_type == 'sigmoid':
            normalization_layer = nn.Sigmoid()
        elif normalization_type == 'softmax':
            # The scores arrive as an (n_frames, 1) column, so normalize across frames (dim=0);
            # dim=1 would make every weight 1.0
            normalization_layer = nn.Softmax(dim=0)
        elif normalization_type == 'gumbel':
            normalization_layer = GumbelSoftmax(tau)
        elif normalization_type == 'k_top':
            normalization_layer = GumbelSoftmaxTopK(k, tau)
        else:
            raise ValueError(f"Invalid normalization layer type: {normalization_type}")
        return normalization_layer
    def initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, preprocessed_images):
        features = self.pretrained(preprocessed_images)
        x = self.activation(self.fc1(features))
        x = self.dropout(x)
        x = self.activation(self.fc2(x))
        x = self.dropout(x)
        x = self.activation(self.fc3(x))
        x = self.fc4(x)
        frame_scores = torch.flatten(x)
        frame_weights = torch.flatten(self.normalization_layer(x))
        return frame_scores, frame_weights, features

    def anneal_tau(self):
        """Anneals the temperature parameter τ using exponential decay."""
        self.tau = max(self.tau * math.exp(-self.decay_factor), self.min_tau)  # prevent τ from becoming too small
        # Propagate the new temperature to the normalization layer, which holds its own copy
        if hasattr(self.normalization_layer, 'tau'):
            self.normalization_layer.tau = self.tau
    def configure_pretrained(self, backbone, n_trainable_backbone_layers=1):
        # Supported backbones: 'resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'
        self.pretrained = timm.create_model(backbone, pretrained=False, checkpoint_path=self.raw_backbone_path)
        # Freeze the whole backbone first; selected layers are unfrozen below
        for param in self.pretrained.parameters():
            param.requires_grad = False
        # Replace the classification head with an identity so the backbone emits features
        if backbone == 'resnet18':
            self.pretrained.fc = nn.Identity()
            sequential_modules = [child for child in self.pretrained.children() if isinstance(child, nn.Sequential)]
        elif backbone in ['efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100']:
            self.pretrained.classifier = nn.Identity()
            sequential_modules = [child for child in self.pretrained.blocks.children()]
        else:
            raise ValueError('Unsupported backbone model!')
        num_seq_blocks = len(sequential_modules)
        if n_trainable_backbone_layers == 0:
            pass
        elif n_trainable_backbone_layers >= num_seq_blocks:
            for param in self.pretrained.parameters():
                param.requires_grad = True
        else:
            # Make only the last `n_trainable_backbone_layers` sequential blocks trainable
            layers_to_train = sequential_modules[-n_trainable_backbone_layers:]
            for layer in layers_to_train:
                for param in layer.parameters():
                    param.requires_grad = True
        if n_trainable_backbone_layers > 0:
            # EfficientNet/MobileNet variants have a conv head and final BN after the blocks;
            # keep those trainable whenever any block is trainable
            last_layers = []
            if hasattr(self.pretrained, 'conv_head'):
                last_layers.append(self.pretrained.conv_head.weight)
            if hasattr(self.pretrained, 'bn2'):
                last_layers.append(self.pretrained.bn2.weight)
                last_layers.append(self.pretrained.bn2.bias)
            for param in last_layers:
                param.requires_grad = True
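
# A shape sketch (assumed, not part of the original file): for one video's batch of
# preprocessed frames, the model returns per-frame saliency logits, normalized frame
# weights, and the backbone features.
#
#   frames = torch.randn(16, 3, 224, 224)     # 16 frames of a single video
#   scores, weights, feats = model(frames)    # model = SaliencyNet(cfg)
#   # scores: (16,), weights: (16,), feats: (16, model.pretrained.num_features)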
import os
import torch
from tqdm import tqdm, trange
import torch.nn as nn
import logging
import torch.nn.functional as F
from tabulate import tabulate
from torch.utils.data import DataLoader
from utils.cluster_frames import VideoClusterer
from utils.metrics import compute_metrics

class SalientFrameSamplerTrainer:
    def __init__(self, model, train_dataloader: DataLoader, val_dataloader: DataLoader, cfg):
        self.cfg = cfg
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.init_logging()
        # NOTE: only the 'kl' choice of --similarity_matrix_loss is wired up here;
        # the 'mse' config choice would need an analogous criterion.
        self.criterion1 = kl_divergence_loss()
        self.criterion2 = nn.BCEWithLogitsLoss()
        self.criterion3 = CrossEn()
        self.optimizer = self.init_optimizer()
        if self.cfg.use_scheduler:
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.cfg.scheduler_step_size, gamma=self.cfg.scheduler_gamma)
    def train(self):
        for epoch in trange(self.cfg.epochs, desc="Epoch", ncols=100):
            log_data = {}
            if self.cfg.do_train:
                train_loss = self.run_training_epoch()
                log_data.update({
                    "train_loss(f2f)": train_loss['f2f'],
                    "train_loss(f2t)": train_loss['f2t'],
                    "train_loss(v2t)": train_loss['v2t'],
                })
            if self.cfg.do_eval:
                valid_loss, t2v_metrics, v2t_metrics = self.run_validation_epoch()
                log_data.update({
                    "valid_loss(f2f)": valid_loss['f2f'],
                    "valid_loss(f2t)": valid_loss['f2t'],
                    "valid_loss(v2t)": valid_loss['v2t'],
                })
                for key in t2v_metrics.keys():
                    for r in [1, 5, 10]:
                        log_data[f"[t2v] R@{r} (k={key})"] = t2v_metrics[key][f'R@{r}']
                        log_data[f"[v2t] R@{r} (k={key})"] = v2t_metrics[key][f'R@{r}']
            self.log_metrics(epoch, log_data)
            self.model.anneal_tau()
            if hasattr(self, 'scheduler'):
                self.scheduler.step()
            if self.cfg.store_backbone:
                torch.save(self.model.pretrained.state_dict(), os.path.join(self.cfg.trained_backbone_path, f"epoch_{epoch+1}.pt"))
            if self.cfg.store_complete_model:
                torch.save(self.model.state_dict(), os.path.join(self.cfg.trained_model_path, f"epoch_{epoch+1}.pt"))
    def run_training_epoch(self):
        self.model.train()
        epoch_train_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0}
        video_embeddings = []
        text_embeddings = []
        f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
        for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.train_dataloader, desc='Train', leave=False, ncols=100)):
            torch.cuda.empty_cache()
            # The DataLoader batch size is 1 (one whole video); drop the batch dimension
            frames = frames[0].to(self.cfg.device)
            teacher_frame_features = teacher_frame_features[0].to(self.cfg.device)
            teacher_text_features = teacher_text_features[0].to(self.cfg.device)
            gt_score = gt_score[0].to(self.cfg.device)
            frame_scores, frame_weights, student_frame_features = self.model(frames)
            # Frame-to-frame loss
            f2f_loss = self.criterion1(student_frame_features, teacher_frame_features)
            # Frame-to-text loss
            f2t_loss = self.criterion2(frame_scores, gt_score)
            # Accumulate the per-video losses as tensors (not via .item()) so their
            # gradients survive until the accumulated backward pass below
            f2f_loss_accum = f2f_loss_accum + f2f_loss / self.cfg.batch_size
            f2t_loss_accum = f2t_loss_accum + f2t_loss / self.cfg.batch_size
            # Video embedding for the in-pipeline v2t loss: weight-averaged teacher frame features
            video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights)
            video_embeddings.append(video_features)
            text_embeddings.append(teacher_text_features)
            if (i + 1) % self.cfg.batch_size == 0 or i == len(self.train_dataloader) - 1:
                sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings)
                sim_loss1 = self.criterion3(sim_matrix)
                sim_loss2 = self.criterion3(sim_matrix.T)
                v2t_loss = (sim_loss1 + sim_loss2) / 2
                # Aggregate the weighted total loss and take one optimizer step per accumulation window
                total_loss = self.cfg.alpha * f2f_loss_accum + self.cfg.beta * f2t_loss_accum + self.cfg.gamma * v2t_loss
                self.optimizer.zero_grad()
                total_loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                # Update epoch losses
                epoch_train_loss['f2f'] += f2f_loss_accum.item()
                epoch_train_loss['f2t'] += f2t_loss_accum.item()
                epoch_train_loss['v2t'] += v2t_loss.item()
                # Reset accumulators
                f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
                video_embeddings = []
                text_embeddings = []
        # Normalize epoch losses by the number of batches
        for key in epoch_train_loss.keys():
            epoch_train_loss[key] /= len(self.train_dataloader)
        return epoch_train_loss
    def run_validation_epoch(self):
        self.model.eval()
        epoch_valid_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0}
        video_embeddings = []
        text_embeddings = []
        all_video_embeddings = {str(k): {'without_clustering': [],
                                         'uniform': [], 'spectral': [], 'agglomerative': [], 'kmeans': []} for k in [1, 2, 4, 8, 16]}
        all_text_embeddings = []
        f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
        for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.val_dataloader, desc='Valid', leave=False, ncols=100)):
            torch.cuda.empty_cache()
            frames = frames[0].to(self.cfg.device)
            teacher_frame_features = teacher_frame_features[0].to(self.cfg.device)
            teacher_text_features = teacher_text_features[0].to(self.cfg.device)
            gt_score = gt_score[0].to(self.cfg.device)
            with torch.no_grad():
                frame_scores, frame_weights, student_frame_features = self.model(frames)
                f2f_loss = self.criterion1(student_frame_features, teacher_frame_features)
                f2t_loss = self.criterion2(frame_scores, gt_score)
                f2f_loss_accum += f2f_loss.item()
                f2t_loss_accum += f2t_loss.item()
                # Collect embeddings used for the t2v/v2t retrieval metrics
                all_text_embeddings.append(teacher_text_features)
                n_frames = teacher_frame_features.size(0)
                sorted_indices = sorted(range(n_frames), key=lambda idx: frame_scores[idx], reverse=True)
                for k in [1, 2, 4, 8, 16]:
                    n_selected_frames = min(n_frames, k)
                    indices = sorted_indices[:n_selected_frames]
                    all_video_embeddings[str(k)]['without_clustering'].append(teacher_frame_features[indices].mean(dim=0))
                    if self.cfg.do_cluster and k > 2:
                        for clustering_method in ['uniform', 'spectral', 'agglomerative', 'kmeans']:
                            if len(frame_scores) <= k:
                                # Fewer frames than clusters: fall back to the mean of all frames
                                all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features.mean(dim=0))
                            else:
                                clusterer = VideoClusterer(clustering_method, k)
                                clusters = clusterer.get_clusters(student_frame_features)
                                selected_indices = self.select_representative_frame_per_cluster(clusters, frame_scores)
                                all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features[selected_indices].mean(dim=0))
                # In-pipeline v2t loss
                video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights)
                video_embeddings.append(video_features)
                text_embeddings.append(teacher_text_features)
                if (i + 1) % self.cfg.batch_size == 0 or i == len(self.val_dataloader) - 1:
                    sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings)
                    sim_loss1 = self.criterion3(sim_matrix)
                    sim_loss2 = self.criterion3(sim_matrix.T)
                    v2t_loss = (sim_loss1 + sim_loss2) / 2
                    # Update epoch losses
                    epoch_valid_loss['f2f'] += f2f_loss_accum
                    epoch_valid_loss['f2t'] += f2t_loss_accum
                    epoch_valid_loss['v2t'] += v2t_loss.item()
                    # Reset accumulators
                    f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
                    video_embeddings = []
                    text_embeddings = []
        # Normalize epoch losses by the number of batches
        for key in epoch_valid_loss.keys():
            epoch_valid_loss[key] /= len(self.val_dataloader)
        # Compute retrieval metrics for every k and clustering scenario
        t2v_metrics = {}
        v2t_metrics = {}
        for k, embeddings_dict in all_video_embeddings.items():
            for scenario, embeddings_list in embeddings_dict.items():
                if embeddings_list:  # Check if there are any embeddings for this scenario
                    sim_matrix = self.compute_sim_matrix(embeddings_list, all_text_embeddings)
                    _key = f"{k}" if scenario == 'without_clustering' else f"{k}-{scenario}"
                    t2v_metrics[_key] = compute_metrics(sim_matrix.detach().cpu().numpy())
                    v2t_metrics[_key] = compute_metrics(sim_matrix.T.detach().cpu().numpy())
        return epoch_valid_loss, t2v_metrics, v2t_metrics
    def compute_sim_matrix(self, video_embeddings, text_embeddings):
        video_embeddings_tensor = torch.stack(video_embeddings, dim=0)
        text_embeddings_tensor = torch.stack(text_embeddings, dim=0)
        # Cosine similarity: L2-normalize both sides, then take dot products
        video_embeddings_norm = video_embeddings_tensor / video_embeddings_tensor.norm(dim=-1, keepdim=True)
        text_embeddings_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True)
        sim_matrix = torch.matmul(video_embeddings_norm, text_embeddings_norm.T)
        return sim_matrix

    def init_optimizer(self):
        if not hasattr(self.cfg, 'backbone_lr_scale'):
            return torch.optim.Adam(params=self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)
        backbone_params = []
        last_params = []
        for name, param in self.model.named_parameters():
            if 'pretrain' in name:
                backbone_params.append(param)
            else:
                last_params.append(param)
        # Scale the backbone learning rate; the head layers use the base learning rate
        return torch.optim.Adam([
            {'params': backbone_params, 'lr': self.cfg.lr * self.cfg.backbone_lr_scale},
            {'params': last_params, 'lr': self.cfg.lr}
        ], weight_decay=self.cfg.weight_decay)
    def init_logging(self):
        available_logs = os.listdir(self.cfg.log_dir)
        last_run = 0
        if available_logs:
            last_run = max([int(name.split('-')[0]) for name in available_logs])
        log_file_name = f"{last_run + 1}-local-run.log"
        log_file = os.path.join(self.cfg.log_dir, log_file_name)
        logging.basicConfig(filename=log_file, level=logging.INFO,
                            format='%(asctime)s [%(levelname)s] - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S')
        hyperparameters = vars(self.cfg)
        hyperparameters_str = "\n".join([f"{key}: {value}" for key, value in hyperparameters.items() if ('path' not in key and 'log_dir' not in key)])
        logging.info("Hyperparameters:\n" + hyperparameters_str)
    def log_metrics(self, epoch, log_data):
        # Loss metrics table
        loss_header = ["", "f2f", "f2t", "v2t"]
        # Use "N/A" for whichever of the train/valid losses is missing
        train_loss_f2f = log_data.get("train_loss(f2f)", "N/A")
        train_loss_f2t = log_data.get("train_loss(f2t)", "N/A")
        train_loss_v2t = log_data.get("train_loss(v2t)", "N/A")
        valid_loss_f2f = log_data.get("valid_loss(f2f)", "N/A")
        valid_loss_f2t = log_data.get("valid_loss(f2t)", "N/A")
        valid_loss_v2t = log_data.get("valid_loss(v2t)", "N/A")
        loss_rows = [
            ["train", train_loss_f2f, train_loss_f2t, train_loss_v2t],
            ["val", valid_loss_f2f, valid_loss_f2t, valid_loss_v2t],
        ]
        loss_table = tabulate(loss_rows, headers=loss_header, tablefmt="grid")
        # Recall metrics tables
        recall_header = ["", "R@1", "R@5", "R@10"]
        clustering_names = ['uniform', 'spectral', 'agglomerative', 'kmeans']
        recall_rows_t2v = []
        recall_rows_v2t = []
        for k in [1, 2, 4, 8, 16]:
            # Without clustering
            base_key = f"k={k}"
            if any(f"[t2v] R@{r} ({base_key})" in log_data for r in [1, 5, 10]):
                row = [f"{k}"]
                for r in [1, 5, 10]:
                    row.append(log_data.get(f"[t2v] R@{r} ({base_key})", "N/A"))
                recall_rows_t2v.append(row)
            if any(f"[v2t] R@{r} ({base_key})" in log_data for r in [1, 5, 10]):
                row = [f"{k}"]
                for r in [1, 5, 10]:
                    row.append(log_data.get(f"[v2t] R@{r} ({base_key})", "N/A"))
                recall_rows_v2t.append(row)
            # With clustering
            for clustering_name in clustering_names:
                cluster_key = f"{base_key}-{clustering_name}"
                if any(f"[t2v] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]):
                    row = [f"{k}-{clustering_name}"]
                    for r in [1, 5, 10]:
                        row.append(log_data.get(f"[t2v] R@{r} ({cluster_key})", "N/A"))
                    recall_rows_t2v.append(row)
                if any(f"[v2t] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]):
                    row = [f"{k}-{clustering_name}"]
                    for r in [1, 5, 10]:
                        row.append(log_data.get(f"[v2t] R@{r} ({cluster_key})", "N/A"))
                    recall_rows_v2t.append(row)
        recall_table_t2v = tabulate(recall_rows_t2v, headers=recall_header, tablefmt="grid")
        recall_table_v2t = tabulate(recall_rows_v2t, headers=recall_header, tablefmt="grid")
        logging.info(f"Epoch {epoch+1} - Loss and Recall Metrics:")
        logging.info("\n" + loss_table)
        logging.info("\nText-to-Video Recall Metrics:")
        logging.info("\n" + recall_table_t2v)
        logging.info("\nVideo-to-Text Recall Metrics:")
        logging.info("\n" + recall_table_v2t)

    def select_representative_frame_per_cluster(self, clusters, frame_scores):
        # For each cluster, pick the frame with the highest saliency score
        representative_frames = []
        for cluster in clusters:
            best_frame_index = max(cluster, key=lambda index: frame_scores[index])
            representative_frames.append(int(best_frame_index))
        return representative_frames
class kl_divergence_loss(nn.Module):
    """KL divergence between the self-similarity distributions of two feature sets."""
    def __init__(self):
        super(kl_divergence_loss, self).__init__()

    def forward(self, features1, features2):
        features1 = F.normalize(features1, p=2, dim=-1)
        features2 = F.normalize(features2, p=2, dim=-1)
        # Row-wise softmax over the cosine self-similarity matrices
        cos_sim_features1 = torch.mm(features1, features1.t())
        cos_sim_features2 = torch.mm(features2, features2.t())
        probs_features1 = F.softmax(cos_sim_features1, dim=-1)
        probs_features2 = F.softmax(cos_sim_features2, dim=-1)
        loss = F.kl_div(probs_features1.log(), probs_features2, reduction='batchmean')
        return loss

class CrossEn(nn.Module):
    """Retrieval cross-entropy: the diagonal of the similarity matrix holds the positives."""
    def __init__(self):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
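
# A small sanity-check sketch (assumed, not part of the original file): CrossEn is near
# zero for a strongly diagonal similarity matrix, and kl_divergence_loss is zero when
# both feature sets are identical.
if __name__ == "__main__":
    feats = torch.randn(8, 32)
    print(float(kl_divergence_loss()(feats, feats)))   # ~0.0
    print(float(CrossEn()(torch.eye(4) * 100)))        # ~0.0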
from sklearn.cluster import SpectralClustering, AgglomerativeClustering, KMeans
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class VideoClusterer:
    def __init__(self, clustering_method='uniform', n_clusters=2, similarity_threshold=0.8):
        self.n_clusters = n_clusters
        self.similarity_threshold = similarity_threshold
        self.clustering_method = clustering_method
        # Decide on the clustering method to use
        if clustering_method == 'uniform':
            self.clusterer = self.uniform_clustering
        elif clustering_method == 'spectral':
            self.clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed')
        elif clustering_method == 'agglomerative':
            self.clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
        elif clustering_method == 'kmeans':
            self.clusterer = KMeans(n_clusters=n_clusters, n_init=1)
        else:
            raise ValueError(f"Invalid clustering method: {clustering_method}")

    def uniform_clustering(self, features):
        # Split the frame indices into n_clusters contiguous, near-equal chunks
        n = len(features)
        clusters = []
        cluster_size = n // self.n_clusters
        remainder = n % self.n_clusters
        start = 0
        for i in range(self.n_clusters):
            end = start + cluster_size + (1 if i < remainder else 0)
            clusters.append(list(range(start, end)))
            start = end
        return clusters

    def detect_outliers(self, features):
        dot_product_matrix = features.dot(features.T)
        average_similarities = np.mean(dot_product_matrix, axis=0)
        # Add a small constant epsilon to the standard deviation to prevent division by zero
        epsilon = 1e-8
        normal = (average_similarities - np.mean(average_similarities)) / (np.std(average_similarities) + epsilon)
        outlier_mask = np.logical_or(normal > 1.5, normal < -1.5)
        return outlier_mask

    def get_clusters(self, features):
        features = features.cpu().numpy()
        if self.clustering_method == 'uniform':
            return self.uniform_clustering(features)
        # For non-uniform methods, optionally drop outlier frames first. Keep a map back
        # to the original frame indices so the returned clusters index the unfiltered video.
        original_indices = np.arange(len(features))
        outlier_mask = self.detect_outliers(features)
        if np.sum(~outlier_mask) > self.n_clusters:
            features = features[~outlier_mask]
            original_indices = original_indices[~outlier_mask]
        if self.clustering_method == 'spectral':
            # Spectral clustering expects a precomputed (cosine) similarity matrix
            similarity_matrix = cosine_similarity(features)
            labels = self.clusterer.fit_predict(similarity_matrix)
        else:
            # Agglomerative and k-means cluster the raw feature vectors directly
            labels = self.clusterer.fit_predict(features)
        # Organize frames into clusters based on labels, dropping any empty clusters
        clusters = [[] for _ in range(self.n_clusters)]
        for idx, label in enumerate(labels):
            clusters[label].append(int(original_indices[idx]))
        return [cluster for cluster in clusters if cluster]
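
# A minimal usage sketch (assumed, not part of the original file): cluster 32 random
# frame features into 4 groups; the returned clusters hold original frame indices.
if __name__ == "__main__":
    import torch
    feats = torch.randn(32, 64)
    clusterer = VideoClusterer('kmeans', n_clusters=4)
    print(clusterer.get_clusters(feats))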
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import numpy as np
import torch

def compute_metrics(x):
    # Rank of the diagonal (ground-truth) entry within each row of the similarity matrix
    sx = np.sort(-x, axis=1)
    d = np.diag(-x)
    d = d[:, np.newaxis]
    ind = sx - d
    ind = np.where(ind == 0)
    ind = ind[1]
    metrics = {}
    metrics['R@1'] = float(np.sum(ind == 0)) * 100 / len(ind)
    metrics['R@5'] = float(np.sum(ind < 5)) * 100 / len(ind)
    metrics['R@10'] = float(np.sum(ind < 10)) * 100 / len(ind)
    metrics["MedianR"] = np.median(ind) + 1
    metrics["MeanR"] = np.mean(ind) + 1
    # metrics["cols"] = [int(i) for i in list(ind)]
    # Alias used by print_computed_metrics
    metrics["MR"] = metrics["MedianR"]
    return metrics

def print_computed_metrics(metrics):
    r1 = metrics['R@1']
    r5 = metrics['R@5']
    r10 = metrics['R@10']
    mr = metrics['MR']
    print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr))
# The two functions below come directly from: https://github.com/Deferf/Experiments
def tensor_text_to_video_metrics(sim_tensor, top_k=[1, 5, 10]):
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # Permute sim_tensor so it represents a sequence of text-video similarity matrices,
    # then take the double argsort to position the rank on the diagonal
    stacked_sim_matrices = sim_tensor.permute(1, 0, 2)
    first_argsort = torch.argsort(stacked_sim_matrices, dim=-1, descending=True)
    second_argsort = torch.argsort(first_argsort, dim=-1, descending=False)
    # Extract the ranks, i.e. the diagonals
    ranks = torch.flatten(torch.diagonal(second_argsort, dim1=1, dim2=2))
    # Keep only valid ranks; some positions belong to inf padding values
    permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1=0, dim2=2))
    mask = ~torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data))
    valid_ranks = ranks[mask]
    # A quick dimension check validates the results; other correctness tests
    # (such as dot-product localization) are left for another time.
    # assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict])
    if not torch.is_tensor(valid_ranks):
        valid_ranks = torch.tensor(valid_ranks)
    results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k}
    results["MedianR"] = float(torch.median(valid_ranks + 1))
    results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1))
    results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1))
    results['MR'] = results["MedianR"]
    return results

def tensor_video_to_text_sim(sim_tensor):
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # Replace NaNs with -inf so they cannot win the max below
    sim_tensor[sim_tensor != sim_tensor] = float('-inf')
    # Form a similarity matrix for use with rank-at-k
    values, _ = torch.max(sim_tensor, dim=1, keepdim=True)
    return torch.squeeze(values).T
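
# A quick sanity-check sketch (assumed, not part of the original file): with an identity
# similarity matrix every query ranks its own match first, so all recalls are 100 and
# MedianR is 1.
if __name__ == "__main__":
    print(compute_metrics(np.eye(10)))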