@@ -0,0 +1,89 @@
import argparse
import os

import torch

media_path_root = "path/to/media"


class Config:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description='Train a student network.')
        self.add_arguments()
        self.args = self.parse()
        self.update_paths()

    def add_arguments(self):
        # General settings
        self.parser.add_argument('--seed', type=int, default=12345, help='Init seed')
        self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training')
        self.parser.add_argument('--do_train', type=int, default=1, help='Whether to run training.')
        self.parser.add_argument('--do_eval', type=int, default=1, help='Whether to run evaluation.')
        self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use')
        self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use')
        self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
        # Referenced by update_paths() when building log_dir
        self.parser.add_argument('--project_name', type=str, default='salient-frame-sampler', help='Project name, used to group log files')
        # Model configurations
        self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model')
        self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer hidden size')
        self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer hidden size')
        self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer hidden size')
        self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model')
        self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers')
        self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function')
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for the text-to-video retrieval pipeline)')
        self.parser.add_argument('--init_tau', default=5.0, type=float, help='Annealing init temperature')
        self.parser.add_argument('--min_tau', default=0.5, type=float, help='Min temperature to anneal to')
        self.parser.add_argument('--decay_factor', default=0.045, type=float, help='Exponential decay factor per epoch')
        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='Load the complete trained model (1) or not (0)')
        self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Load the trained model from which epoch')
        self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just the backbone')
        self.parser.add_argument('--load_backbone', type=int, default=1, help='Load a trained backbone (1) or not (0)')
        self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Load a trained backbone model from which epoch')
        # Dataloader configurations
        self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for dataloader')
        self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Fraction of top-scored frames labeled 1 in binary normalization')
        # Training configurations
        self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
        self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training')
        self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse', 'kl'], help='Frame-to-frame loss')
        self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether to use weighted BCE, to handle unbalanced classes')
        self.parser.add_argument('--alpha', type=float, default=1, help='Weight for f2f loss')
        self.parser.add_argument('--beta', type=float, default=1, help='Weight for f2t loss')
        self.parser.add_argument('--gamma', type=float, default=1, help='Weight for v2t loss')
        self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
        self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler')
        self.parser.add_argument('--scheduler_step_size', type=int, default=5, help='After how many epochs the learning rate is adjusted')
        self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='Factor by which the learning rate is reduced at each step')
        self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it')
        self.parser.add_argument('--do_cluster', type=int, default=1, help='Whether to also compute retrieval results with clustering')

    def parse(self):
        return self.parser.parse_args()

    def update_paths(self):
        self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else 'cpu')
        self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth"
        self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json"
        self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}"
        self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/"
        self.args.log_dir = f'logs/{self.args.project_name}/'
        # st_scores_path points at a JSON file, so create its parent directory
        os.makedirs(os.path.dirname(self.args.st_scores_path), exist_ok=True)
        os.makedirs(self.args.trained_backbone_path, exist_ok=True)
        os.makedirs(self.args.trained_model_path, exist_ok=True)
        os.makedirs(self.args.log_dir, exist_ok=True)
        self.args.paths = {'train': {}, 'valid': {}}
        for data_split in ['train', 'valid']:
            split = 'val_1' if data_split == 'valid' else data_split
            self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}"
            self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}"
            self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}"
            self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json"
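

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the training pipeline). Note
    # that building a Config creates output directories under media_path_root,
    # so point media_path_root at a writable location before running this.
    cfg = Config().args
    print(f"device={cfg.device}, dataset={cfg.dataset}, backbone={cfg.backbone}")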
@@ -0,0 +1,62 @@
import cv2
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode


class RawVideoExtractor():
    def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True):
        self.centercrop = centercrop
        self.framerate = framerate
        self.to_tensor = to_tensor
        self.transform = self._transform(size)

    def _transform(self, n_px):
        if self.to_tensor:
            # Normalization constants are the CLIP image statistics
            return Compose([
                Resize(n_px, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(n_px), ToTensor(),
                Normalize((0.48145466, 0.4578275, 0.40821073),
                          (0.26862954, 0.26130258, 0.27577711))])
        else:
            return Compose([Resize(n_px, interpolation=InterpolationMode.BICUBIC), CenterCrop(n_px)])

    def get_video_data(self, video_path, start_time=None, end_time=None):
        if start_time is not None or end_time is not None:
            # If a time window is given, both endpoints must be valid
            assert start_time is not None and end_time is not None and start_time > -1 and end_time > start_time
            assert self.framerate > -1
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        start_frame = int(start_time * video_fps) if start_time else 0
        end_frame = int(end_time * video_fps) if end_time else frame_count - 1
        # Sample one frame every `interval` source frames; a non-positive
        # framerate means "keep the native frame rate"
        interval = 1
        if self.framerate > 0:
            interval = video_fps / self.framerate
        if interval == 0:
            interval = 1
        images = []
        for i in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                break
            if start_frame <= i <= end_frame:
                # `<=` (rather than `<`) ensures the first in-range frame is kept
                if len(images) * interval <= i - start_frame:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image = Image.fromarray(frame_rgb)
                    image = self.transform(image)
                    images.append(image)
        cap.release()
        return images
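

if __name__ == '__main__':
    # Example usage (a sketch; "sample.mp4" is a placeholder path, not shipped
    # with the repo). With to_tensor=True each returned element is a normalized
    # 3 x 224 x 224 tensor.
    extractor = RawVideoExtractor(framerate=1, size=224)
    frames = extractor.get_video_data('sample.mp4', start_time=0, end_time=10)
    print(f'Extracted {len(frames)} frames')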
@@ -0,0 +1,65 @@
import os
import json

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_image


class VideoDescriptionDataloader(Dataset):
    def __init__(self, cfg, data_type):
        self.frames_dir = cfg.paths[data_type]['frames']
        self.frame_features_dir = cfg.paths[data_type]['frame_features']
        self.text_features_dir = cfg.paths[data_type]['text_features']
        self.gt_scores_file_path = cfg.paths[data_type]['gt_scores']
        self.transform = self._get_transforms()
        self.binary_percentage = cfg.binary_percentage
        # Validation always uses min-max normalized scores, regardless of cfg.normalization
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        elif cfg.normalization == 'raw':
            # 'raw' keeps the teacher scores unscaled
            self.normalize_scores = lambda score: score
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")
        with open(self.gt_scores_file_path, 'r') as f:
            self.gt_scores = json.load(f)
        self.video_ids = list(self.gt_scores.keys())

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        video_id = self.video_ids[idx]
        video_frames_dir = os.path.join(self.frames_dir, video_id)
        frames_tensor = torch.stack([self.transform(read_image(os.path.join(video_frames_dir, frame_file)).float())
                                     for frame_file in sorted(os.listdir(video_frames_dir))])
        text_features = torch.load(os.path.join(self.text_features_dir, f'{video_id}.pt'))
        frame_features = torch.load(os.path.join(self.frame_features_dir, f'{video_id}.pt'))
        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
        gt_score = self.normalize_scores(gt_score)
        return frames_tensor, text_features, frame_features, gt_score

    @staticmethod
    def _get_transforms():
        # ImageNet statistics, matching the timm backbones used by the student
        return transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])])

    @staticmethod
    def _min_max_normalize_scores(score):
        min_val, max_val = score.min(), score.max()
        if min_val != max_val:
            return (score - min_val) / (max_val - min_val)
        return torch.full_like(score, 0.5)

    def _to_binary_labels(self, score):
        num_top_elements = max(int(len(score) * self.binary_percentage), 1)
        sorted_indices = score.argsort(descending=True)
        binary_labels = torch.zeros_like(score)
        binary_labels[sorted_indices[:num_top_elements]] = 1
        return binary_labels
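

if __name__ == '__main__':
    # Example wiring (a sketch; assumes the paths built in config.py point at
    # pre-extracted frames, teacher embeddings, and ground-truth score files).
    from torch.utils.data import DataLoader
    from config import Config

    cfg = Config().args
    dataset = VideoDescriptionDataloader(cfg, data_type='train')
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    frames, text_feats, frame_feats, gt = next(iter(loader))
    print(frames.shape, text_feats.shape, frame_feats.shape, gt.shape)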
@@ -0,0 +1,497 @@
video6513
video6514
video6515
video6516
video6517
video6518
video6519
video6520
video6521
video6522
video6523
video6524
video6525
video6526
video6527
video6528
video6529
video6530
video6531
video6532
video6533
video6534
video6535
video6536
video6537
video6538
video6539
video6540
video6541
video6542
video6543
video6544
video6545
video6546
video6547
video6548
video6549
video6550
video6551
video6552
video6553
video6554
video6555
video6556
video6557
video6558
video6559
video6560
video6561
video6562
video6563
video6564
video6565
video6566
video6567
video6568
video6569
video6570
video6571
video6572
video6573
video6574
video6575
video6576
video6577
video6578
video6579
video6580
video6581
video6582
video6583
video6584
video6585
video6586
video6587
video6588
video6589
video6590
video6591
video6592
video6593
video6594
video6595
video6596
video6597
video6598
video6599
video6600
video6601
video6602
video6603
video6604
video6605
video6606
video6607
video6608
video6609
video6610
video6611
video6612
video6613
video6614
video6615
video6616
video6617
video6618
video6619
video6620
video6621
video6622
video6623
video6624
video6625
video6626
video6627
video6628
video6629
video6630
video6631
video6632
video6633
video6634
video6635
video6636
video6637
video6638
video6639
video6640
video6641
video6642
video6643
video6644
video6645
video6646
video6647
video6648
video6649
video6650
video6651
video6652
video6653
video6654
video6655
video6656
video6657
video6658
video6659
video6660
video6661
video6662
video6663
video6664
video6665
video6666
video6667
video6668
video6669
video6670
video6671
video6672
video6673
video6674
video6675
video6676
video6677
video6678
video6679
video6680
video6681
video6682
video6683
video6684
video6685
video6686
video6687
video6688
video6689
video6690
video6691
video6692
video6693
video6694
video6695
video6696
video6697
video6698
video6699
video6700
video6701
video6702
video6703
video6704
video6705
video6706
video6707
video6708
video6709
video6710
video6711
video6712
video6713
video6714
video6715
video6716
video6717
video6718
video6719
video6720
video6721
video6722
video6723
video6724
video6725
video6726
video6727
video6728
video6729
video6730
video6731
video6732
video6733
video6734
video6735
video6736
video6737
video6738
video6739
video6740
video6741
video6742
video6743
video6744
video6745
video6746
video6747
video6748
video6749
video6750
video6751
video6752
video6753
video6754
video6755
video6756
video6757
video6758
video6759
video6760
video6761
video6762
video6763
video6764
video6765
video6766
video6767
video6768
video6769
video6770
video6771
video6772
video6773
video6774
video6775
video6776
video6777
video6778
video6779
video6780
video6781
video6782
video6783
video6784
video6785
video6786
video6787
video6788
video6789
video6790
video6791
video6792
video6793
video6794
video6795
video6796
video6797
video6798
video6799
video6800
video6801
video6802
video6803
video6804
video6805
video6806
video6807
video6808
video6809
video6810
video6811
video6812
video6813
video6814
video6815
video6816
video6817
video6818
video6819
video6820
video6821
video6822
video6823
video6824
video6825
video6826
video6827
video6828
video6829
video6830
video6831
video6832
video6833
video6834
video6835
video6836
video6837
video6838
video6839
video6840
video6841
video6842
video6843
video6844
video6845
video6846
video6847
video6848
video6849
video6850
video6851
video6852
video6853
video6854
video6855
video6856
video6857
video6858
video6859
video6860
video6861
video6862
video6863
video6864
video6865
video6866
video6867
video6868
video6869
video6870
video6871
video6872
video6873
video6874
video6875
video6876
video6877
video6878
video6879
video6880
video6881
video6882
video6883
video6884
video6885
video6886
video6887
video6888
video6889
video6890
video6891
video6892
video6893
video6894
video6895
video6896
video6897
video6898
video6899
video6900
video6901
video6902
video6903
video6904
video6905
video6906
video6907
video6908
video6909
video6910
video6911
video6912
video6913
video6914
video6915
video6916
video6917
video6918
video6919
video6920
video6921
video6922
video6923
video6924
video6925
video6926
video6927
video6928
video6929
video6930
video6931
video6932
video6933
video6934
video6935
video6936
video6937
video6938
video6939
video6940
video6941
video6942
video6943
video6944
video6945
video6946
video6947
video6948
video6949
video6950
video6951
video6952
video6953
video6954
video6955
video6956
video6957
video6958
video6959
video6960
video6961
video6962
video6963
video6964
video6965
video6966
video6967
video6968
video6969
video6970
video6971
video6972
video6973
video6974
video6975
video6976
video6977
video6978
video6979
video6980
video6981
video6982
video6983
video6984
video6985
video6986
video6987
video6988
video6989
video6990
video6991
video6992
video6993
video6994
video6995
video6996
video6997
video6998
video6999
video7000
video7001
video7002
video7003
video7004
video7005
video7006
video7007
video7008
video7009
@@ -0,0 +1,63 @@
import os

import torch
from torch.utils.data import DataLoader

from config import Config
from dataloaders.t2v_dataloader import VideoDescriptionDataloader
from src.model import SaliencyNet
from src.train import SalientFrameSamplerTrainer


def load_weights(path, which_epoch):
    weights, last_epoch = None, None
    available_models = [name for name in os.listdir(path) if name.endswith('.pt')]
    if available_models:
        # Checkpoint files are named "epoch_<N>.pt"
        last_epoch = max([int(name[6:-3]) for name in available_models])
        last_epoch = min(last_epoch, which_epoch)
        weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt'), map_location='cpu')
    return weights, last_epoch


def set_seeds(seed: int):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main(cfg):
    set_seeds(cfg.seed)
    model = SaliencyNet(cfg)
    model.to(cfg.device)
    if cfg.load_complete_model:
        weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch)
        if weights:
            model.load_state_dict(weights)
            print(f'Complete model ({cfg.backbone}, {cfg.similarity_matrix_loss} loss) loaded from epoch #{last_epoch}')
    elif cfg.load_backbone:
        weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch)
        if weights:
            model.pretrained.load_state_dict(weights)
            print(f'{cfg.backbone} backbone trained with {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
        else:
            print(f'{cfg.backbone} backbone loaded from scratch.')
    # Videos have variable frame counts, so the dataloaders emit one video at a
    # time (batch_size=1); gradient accumulation in the trainer forms batches
    train_dataset = VideoDescriptionDataloader(cfg, data_type='train')
    val_dataset = VideoDescriptionDataloader(cfg, data_type='valid')
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=cfg.workers)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=cfg.workers)
    trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg)
    trainer.train()


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    cfg = Config().args
    main(cfg)
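
# Typical invocation (a sketch; all flags are defined in config.py):
#   python main.py --dataset msrvtt --teacher_model CLIP --backbone resnet18 \
#       --normalization_layer gumbel --batch_size 16 --epochs 20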
@@ -0,0 +1,10 @@
torch
torchvision
opencv-python
Pillow
timm
transformers
scikit-learn
numpy
tqdm
tabulate
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelSoftmaxTopK(nn.Module):
    """Relaxed top-k subset sampling via successive Gumbel-softmax draws.

    Adapted from: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html#Subset-Sampler-Class
    and https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    """

    def __init__(self, k, tau):
        super(GumbelSoftmaxTopK, self).__init__()
        self.k = k
        self.tau = tau

    def forward(self, logits):
        # Convert raw scores into log-probabilities (clamped for stability)
        logits = torch.log(F.softmax(logits, dim=0).clamp(min=1e-8))
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format),
                                              torch.ones_like(logits, memory_format=torch.legacy_contiguous_format))
        gumbels = m.sample()
        y = logits + gumbels
        _EPS = torch.tensor(1e-40).to(logits.device)
        khot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        onehot_approx = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
        for i in range(self.k):
            # Mask out (approximately) already-selected entries, then resample
            khot_mask = torch.maximum(1.0 - onehot_approx, _EPS)
            y += khot_mask.log()
            onehot_approx = F.softmax(y / self.tau, dim=0)
            khot = torch.add(khot, onehot_approx)
        return khot


class GumbelSoftmax(nn.Module):
    def __init__(self, tau, hard=False):
        super(GumbelSoftmax, self).__init__()
        self.tau = tau
        self.hard = hard

    def forward(self, logits):
        m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format),
                                              torch.ones_like(logits, memory_format=torch.legacy_contiguous_format))
        gumbels = m.sample()
        y = logits + gumbels
        y_soft = F.softmax(y / self.tau, dim=0)
        if self.hard:
            # Straight-through estimator: hard one-hot forward, soft gradients
            index = y_soft.argmax(dim=0, keepdim=True)
            y_hard = torch.zeros_like(logits).scatter_(0, index, 1.0)
            return y_hard - y_soft.detach() + y_soft
        return y_soft
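

if __name__ == '__main__':
    # Quick sanity check (a sketch): the top-k relaxation yields a k-hot-like
    # vector whose entries sum to roughly k, while the plain Gumbel-softmax
    # sample sums to 1.
    torch.manual_seed(0)
    logits = torch.randn(16)
    soft = GumbelSoftmax(tau=1.0)(logits)
    khot = GumbelSoftmaxTopK(k=4, tau=1.0)(logits)
    print(f'softmax sample sum: {soft.sum():.3f}, k-hot sum: {khot.sum():.3f}')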
@@ -0,0 +1,131 @@
import math

import timm
import torch
import torch.nn as nn

from src.gumbel_softmax import GumbelSoftmax, GumbelSoftmaxTopK


class SaliencyNet(nn.Module):
    def __init__(self, cfg):
        super(SaliencyNet, self).__init__()
        self.k = 8  # number of frames selected by the k_top normalization layer
        self.tau = cfg.init_tau
        self.min_tau = cfg.min_tau
        self.decay_factor = cfg.decay_factor
        self.raw_backbone_path = cfg.raw_backbone_path
        self.configure_pretrained(cfg.backbone, cfg.n_trainable_backbone_layers)
        embed_dim = self.pretrained.num_features
        self.fc1 = nn.Linear(embed_dim, cfg.fc1_size)
        self.fc2 = nn.Linear(cfg.fc1_size, cfg.fc2_size)
        self.fc3 = nn.Linear(cfg.fc2_size, cfg.fc3_size)
        self.fc4 = nn.Linear(cfg.fc3_size, 1)
        self.dropout = nn.Dropout(cfg.dropout_ratio)
        self.sigmoid = nn.Sigmoid()
        if cfg.activation == 'relu':
            self.activation = nn.ReLU()
        elif cfg.activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()
        elif cfg.activation == 'gelu':
            self.activation = nn.GELU()
        else:
            raise ValueError("Invalid activation type. Choose 'relu', 'leakyrelu', or 'gelu'.")
        self.normalization_layer = self.init_normalization_layer(cfg.normalization_layer, self.tau, self.k)
        self.initialize_weights()

    def init_normalization_layer(self, normalization_type, tau, k):
        if normalization_type == 'raw':
            normalization_layer = lambda x: x
        elif normalization_type == 'sigmoid':
            normalization_layer = nn.Sigmoid()
        elif normalization_type == 'softmax':
            normalization_layer = nn.Softmax(dim=1)
        elif normalization_type == 'gumbel':
            normalization_layer = GumbelSoftmax(tau)
        elif normalization_type == 'k_top':
            normalization_layer = GumbelSoftmaxTopK(k, tau)
        else:
            raise ValueError(f"Invalid normalization layer: {normalization_type}")
        return normalization_layer

    def initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, preprocessed_images):
        features = self.pretrained(preprocessed_images)
        x = self.activation(self.fc1(features))
        x = self.dropout(x)
        x = self.activation(self.fc2(x))
        x = self.dropout(x)
        x = self.activation(self.fc3(x))
        x = self.fc4(x)
        frame_scores = torch.flatten(x)
        frame_weights = torch.flatten(self.normalization_layer(x))
        return frame_scores, frame_weights, features

    def anneal_tau(self):
        """Anneals the temperature parameter tau using exponential decay."""
        self.tau = self.tau * math.exp(-self.decay_factor)
        self.tau = max(self.tau, self.min_tau)  # prevent tau from becoming too small
        # Keep the sampling layer's temperature in sync with the annealed value
        if isinstance(self.normalization_layer, (GumbelSoftmax, GumbelSoftmaxTopK)):
            self.normalization_layer.tau = self.tau

    def configure_pretrained(self, backbone, n_trainable_backbone_layers=1):
        self.pretrained = timm.create_model(backbone, pretrained=False, checkpoint_path=self.raw_backbone_path)
        # Freeze the whole backbone, then selectively unfreeze below
        for param in self.pretrained.parameters():
            param.requires_grad = False
        # Replace the classification head with an identity so the backbone
        # returns pooled features
        if backbone == 'resnet18':
            self.pretrained.fc = nn.Identity()
            sequential_modules = [child for child in self.pretrained.children() if isinstance(child, nn.Sequential)]
        elif backbone in ['efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100']:
            self.pretrained.classifier = nn.Identity()
            sequential_modules = [child for child in self.pretrained.blocks.children()]
        else:
            raise ValueError(f'Unsupported backbone model: {backbone}')
        num_seq_blocks = len(sequential_modules)
        if n_trainable_backbone_layers == 0:
            pass
        elif n_trainable_backbone_layers >= num_seq_blocks:
            for param in self.pretrained.parameters():
                param.requires_grad = True
        else:
            # Make only the last `n_trainable_backbone_layers` sequential modules trainable
            layers_to_train = sequential_modules[-n_trainable_backbone_layers:]
            for layer in layers_to_train:
                for param in layer.parameters():
                    param.requires_grad = True
        if n_trainable_backbone_layers > 0:
            # Also unfreeze the head conv/bn layers that some timm models place
            # after the last block
            last_layers = []
            if hasattr(self.pretrained, 'conv_head'):
                last_layers.append(self.pretrained.conv_head.weight)
            if hasattr(self.pretrained, 'bn2'):
                last_layers.append(self.pretrained.bn2.weight)
                last_layers.append(self.pretrained.bn2.bias)
            for param in last_layers:
                param.requires_grad = True
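

if __name__ == '__main__':
    # Forward-pass smoke test (a sketch; assumes cfg.raw_backbone_path points
    # at a downloaded timm checkpoint for the chosen backbone).
    from config import Config

    cfg = Config().args
    model = SaliencyNet(cfg)
    model.eval()
    dummy_frames = torch.randn(5, 3, 224, 224)  # five frames of one video
    with torch.no_grad():
        scores, weights, feats = model(dummy_frames)
    print(scores.shape, weights.shape, feats.shape)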
@@ -0,0 +1,378 @@
import os
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from tabulate import tabulate
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from utils.cluster_frames import VideoClusterer
from utils.metrics import compute_metrics


class SalientFrameSamplerTrainer:
    def __init__(self, model, train_dataloader: DataLoader, val_dataloader: DataLoader, cfg):
        self.cfg = cfg
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.init_logging()
        self.criterion1 = KLDivergenceLoss()
        self.criterion2 = nn.BCEWithLogitsLoss()
        self.criterion3 = CrossEn()
        self.optimizer = self.init_optimizer()
        if self.cfg.use_scheduler:
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.cfg.scheduler_step_size, gamma=self.cfg.scheduler_gamma)

    def train(self):
        for epoch in trange(self.cfg.epochs, desc="Epoch", ncols=100):
            log_data = {}
            if self.cfg.do_train:
                train_loss = self.run_training_epoch()
                log_data.update({
                    "train_loss(f2f)": train_loss['f2f'],
                    "train_loss(f2t)": train_loss['f2t'],
                    "train_loss(v2t)": train_loss['v2t'],
                })
            if self.cfg.do_eval:
                valid_loss, t2v_metrics, v2t_metrics = self.run_validation_epoch()
                log_data.update({
                    "valid_loss(f2f)": valid_loss['f2f'],
                    "valid_loss(f2t)": valid_loss['f2t'],
                    "valid_loss(v2t)": valid_loss['v2t'],
                })
                for key in t2v_metrics.keys():
                    for r in [1, 5, 10]:
                        log_data[f"[t2v] R@{r} (k={key})"] = t2v_metrics[key][f'R@{r}']
                        log_data[f"[v2t] R@{r} (k={key})"] = v2t_metrics[key][f'R@{r}']
            self.log_metrics(epoch, log_data)
            self.model.anneal_tau()
            if hasattr(self, 'scheduler'):
                self.scheduler.step()
            if self.cfg.store_backbone:
                torch.save(self.model.pretrained.state_dict(), os.path.join(self.cfg.trained_backbone_path, f"epoch_{epoch+1}.pt"))
            if self.cfg.store_complete_model:
                torch.save(self.model.state_dict(), os.path.join(self.cfg.trained_model_path, f"epoch_{epoch+1}.pt"))

    def run_training_epoch(self):
        self.model.train()
        epoch_train_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0}
        video_embeddings = []
        text_embeddings = []
        f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
        for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.train_dataloader, desc='Train', leave=False, ncols=100)):
            torch.cuda.empty_cache()
            # The dataloader yields one video per step; drop the batch dimension
            frames = frames[0].to(self.cfg.device)
            teacher_frame_features = teacher_frame_features[0].to(self.cfg.device)
            teacher_text_features = teacher_text_features[0].to(self.cfg.device)
            gt_score = gt_score[0].to(self.cfg.device)
            frame_scores, frame_weights, student_frame_features = self.model(frames)
            # Frame-to-frame loss
            f2f_loss = self.criterion1(student_frame_features, teacher_frame_features)
            # Frame-to-text loss
            f2t_loss = self.criterion2(frame_scores, gt_score)
            # Accumulate as tensors (not .item()) so gradients from the f2f and
            # f2t terms survive until the optimization step below
            f2f_loss_accum = f2f_loss_accum + f2f_loss / self.cfg.batch_size
            f2t_loss_accum = f2t_loss_accum + f2t_loss / self.cfg.batch_size
            # In-pipeline video embedding: weighted mean of teacher frame features
            video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights)
            video_embeddings.append(video_features)
            text_embeddings.append(teacher_text_features)
            if (i + 1) % self.cfg.batch_size == 0 or i == len(self.train_dataloader) - 1:
                sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings)
                sim_loss1 = self.criterion3(sim_matrix)
                sim_loss2 = self.criterion3(sim_matrix.T)
                v2t_loss = (sim_loss1 + sim_loss2) / 2
                # Weighted sum of the three loss terms
                total_loss = self.cfg.alpha * f2f_loss_accum + self.cfg.beta * f2t_loss_accum + self.cfg.gamma * v2t_loss
                self.optimizer.zero_grad()
                total_loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                # Update epoch losses
                epoch_train_loss['f2f'] += f2f_loss_accum.item()
                epoch_train_loss['f2t'] += f2t_loss_accum.item()
                epoch_train_loss['v2t'] += v2t_loss.item()
                # Reset accumulators
                f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
                video_embeddings = []
                text_embeddings = []
        # Normalize epoch losses by the number of batches
        for key in epoch_train_loss.keys():
            epoch_train_loss[key] /= len(self.train_dataloader)
        return epoch_train_loss

    def run_validation_epoch(self):
        self.model.eval()
        epoch_valid_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0}
        video_embeddings = []
        text_embeddings = []
        all_video_embeddings = {str(k): {'without_clustering': [],
                                         'uniform': [], 'spectral': [], 'agglomerative': [], 'kmeans': []} for k in [1, 2, 4, 8, 16]}
        all_text_embeddings = []
        f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
        for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.val_dataloader, desc='Valid', leave=False, ncols=100)):
            torch.cuda.empty_cache()
            frames = frames[0].to(self.cfg.device)
            teacher_frame_features = teacher_frame_features[0].to(self.cfg.device)
            teacher_text_features = teacher_text_features[0].to(self.cfg.device)
            gt_score = gt_score[0].to(self.cfg.device)
            with torch.no_grad():
                frame_scores, frame_weights, student_frame_features = self.model(frames)
                f2f_loss = self.criterion1(student_frame_features, teacher_frame_features)
                f2t_loss = self.criterion2(frame_scores, gt_score)
            f2f_loss_accum += f2f_loss.item()
            f2t_loss_accum += f2t_loss.item()
            # Collect embeddings used for the t2v/v2t retrieval metrics
            all_text_embeddings.append(teacher_text_features)
            n_frames = teacher_frame_features.size(0)
            sorted_indices = frame_scores.argsort(descending=True).tolist()
            for k in [1, 2, 4, 8, 16]:
                n_selected_frames = min(n_frames, k)
                indices = sorted_indices[:n_selected_frames]
                all_video_embeddings[str(k)]['without_clustering'].append(teacher_frame_features[indices].mean(dim=0))
                if self.cfg.do_cluster and k > 2:
                    for clustering_method in ['uniform', 'spectral', 'agglomerative', 'kmeans']:
                        if len(frame_scores) <= k:  # fewer frames than clusters: mean of all frames
                            all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features.mean(dim=0))
                        else:
                            clusterer = VideoClusterer(clustering_method, k)
                            clusters = clusterer.get_clusters(student_frame_features)
                            selected_indices = self.select_representative_frame_per_cluster(clusters, frame_scores)
                            all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features[selected_indices].mean(dim=0))
            # Calculating the in-pipeline loss
            video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights)
            video_embeddings.append(video_features)
            text_embeddings.append(teacher_text_features)
            if (i + 1) % self.cfg.batch_size == 0 or i == len(self.val_dataloader) - 1:
                sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings)
                sim_loss1 = self.criterion3(sim_matrix)
                sim_loss2 = self.criterion3(sim_matrix.T)
                v2t_loss = (sim_loss1 + sim_loss2) / 2
                # Update epoch losses
                epoch_valid_loss['f2f'] += f2f_loss_accum
                epoch_valid_loss['f2t'] += f2t_loss_accum
                epoch_valid_loss['v2t'] += v2t_loss.item()
                # Reset accumulators
                f2f_loss_accum, f2t_loss_accum = 0.0, 0.0
                video_embeddings = []
                text_embeddings = []
        # Normalize epoch losses by the number of batches
        for key in epoch_valid_loss.keys():
            epoch_valid_loss[key] /= len(self.val_dataloader)
        # Compute retrieval metrics per k and clustering scenario
        t2v_metrics = {}
        v2t_metrics = {}
        for k, embeddings_dict in all_video_embeddings.items():
            for scenario, embeddings_list in embeddings_dict.items():
                if embeddings_list:  # skip scenarios with no embeddings
                    sim_matrix = self.compute_sim_matrix(embeddings_list, all_text_embeddings)
                    _key = f"{k}" if scenario == 'without_clustering' else f"{k}-{scenario}"
                    t2v_metrics[_key] = compute_metrics(sim_matrix.detach().cpu().numpy())
                    v2t_metrics[_key] = compute_metrics(sim_matrix.T.detach().cpu().numpy())
        return epoch_valid_loss, t2v_metrics, v2t_metrics

    def compute_sim_matrix(self, video_embeddings, text_embeddings):
        video_embeddings_tensor = torch.stack(video_embeddings, dim=0)
        text_embeddings_tensor = torch.stack(text_embeddings, dim=0)
        video_embeddings_norm = video_embeddings_tensor / video_embeddings_tensor.norm(dim=-1, keepdim=True)
        text_embeddings_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True)
        sim_matrix = torch.matmul(video_embeddings_norm, text_embeddings_norm.T)
        return sim_matrix

    def init_optimizer(self):
        if not hasattr(self.cfg, 'backbone_lr_scale'):
            return torch.optim.Adam(params=self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)
        backbone_params = []
        last_params = []
        for name, param in self.model.named_parameters():
            if 'pretrain' in name:
                backbone_params.append(param)
            else:
                last_params.append(param)
        # Scale the backbone learning rate; the freshly initialized head keeps
        # the base learning rate
        return torch.optim.Adam([
            {'params': backbone_params, 'lr': self.cfg.lr * self.cfg.backbone_lr_scale},
            {'params': last_params, 'lr': self.cfg.lr}
        ], weight_decay=self.cfg.weight_decay)

    def init_logging(self):
        available_logs = os.listdir(self.cfg.log_dir)
        last_run = 0
        if available_logs:
            last_run = max([int(name.split('-')[0]) for name in available_logs])
        log_file_name = f"{last_run + 1}-local-run.log"
        log_file = os.path.join(self.cfg.log_dir, log_file_name)
        logging.basicConfig(filename=log_file, level=logging.INFO,
                            format='%(asctime)s [%(levelname)s] - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S')
        hyperparameters = {key: value for key, value in vars(self.cfg).items()}
        hyperparameters_str = "\n".join([f"{key}: {value}" for key, value in hyperparameters.items() if ('path' not in key and 'log_dir' not in key)])
        logging.info("Hyperparameters:\n" + hyperparameters_str)

    def log_metrics(self, epoch, log_data):
        # Loss metrics table
        loss_header = ["", "f2f", "f2t", "v2t"]
        # Fall back to "N/A" when a loss was not computed this epoch
        train_loss_f2f = log_data.get("train_loss(f2f)", "N/A")
        train_loss_f2t = log_data.get("train_loss(f2t)", "N/A")
        train_loss_v2t = log_data.get("train_loss(v2t)", "N/A")
        valid_loss_f2f = log_data.get("valid_loss(f2f)", "N/A")
        valid_loss_f2t = log_data.get("valid_loss(f2t)", "N/A")
        valid_loss_v2t = log_data.get("valid_loss(v2t)", "N/A")
        loss_rows = [
            ["train", train_loss_f2f, train_loss_f2t, train_loss_v2t],
            ["val", valid_loss_f2f, valid_loss_f2t, valid_loss_v2t],
        ]
        loss_table = tabulate(loss_rows, headers=loss_header, tablefmt="grid")
        # Recall metrics tables
        recall_header = ["", "R@1", "R@5", "R@10"]
        clustering_names = ['uniform', 'spectral', 'agglomerative', 'kmeans']
        recall_rows_t2v = []
        recall_rows_v2t = []
        for k in [1, 2, 4, 8, 16]:
            # Without clustering
            base_key = f"k={k}"
            if any(f"[t2v] R@{r} ({base_key})" in log_data for r in [1, 5, 10]):
                row = [f"{k}"]
                for r in [1, 5, 10]:
                    row.append(log_data.get(f"[t2v] R@{r} ({base_key})", "N/A"))
                recall_rows_t2v.append(row)
            if any(f"[v2t] R@{r} ({base_key})" in log_data for r in [1, 5, 10]):
                row = [f"{k}"]
                for r in [1, 5, 10]:
                    row.append(log_data.get(f"[v2t] R@{r} ({base_key})", "N/A"))
                recall_rows_v2t.append(row)
            # With clustering
            for clustering_name in clustering_names:
                cluster_key = f"{base_key}-{clustering_name}"
                if any(f"[t2v] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]):
                    row = [f"{k}-{clustering_name}"]
                    for r in [1, 5, 10]:
                        row.append(log_data.get(f"[t2v] R@{r} ({cluster_key})", "N/A"))
                    recall_rows_t2v.append(row)
                if any(f"[v2t] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]):
                    row = [f"{k}-{clustering_name}"]
                    for r in [1, 5, 10]:
                        row.append(log_data.get(f"[v2t] R@{r} ({cluster_key})", "N/A"))
                    recall_rows_v2t.append(row)
        recall_table_t2v = tabulate(recall_rows_t2v, headers=recall_header, tablefmt="grid")
        recall_table_v2t = tabulate(recall_rows_v2t, headers=recall_header, tablefmt="grid")
        logging.info(f"Epoch {epoch+1} - Loss and Recall Metrics:")
        logging.info("\n" + loss_table)
        logging.info("\nText-to-Video Recall Metrics:")
        logging.info("\n" + recall_table_t2v)
        logging.info("\nVideo-to-Text Recall Metrics:")
        logging.info("\n" + recall_table_v2t)

    def select_representative_frame_per_cluster(self, clusters, frame_scores):
        representative_frames = []
        for cluster in clusters:
            best_frame_index = max(cluster, key=lambda index: frame_scores[index])
            representative_frames.append(int(best_frame_index))
        return representative_frames


class KLDivergenceLoss(nn.Module):
    """KL divergence between the self-similarity distributions of two feature sets."""

    def __init__(self):
        super(KLDivergenceLoss, self).__init__()

    def forward(self, features1, features2):
        features1 = F.normalize(features1, p=2, dim=-1)
        features2 = F.normalize(features2, p=2, dim=-1)
        cos_sim_features1 = torch.mm(features1, features1.t())
        cos_sim_features2 = torch.mm(features2, features2.t())
        probs_features1 = F.softmax(cos_sim_features1, dim=-1)
        probs_features2 = F.softmax(cos_sim_features2, dim=-1)
        loss = F.kl_div(probs_features1.log(), probs_features2, reduction='batchmean')
        return loss


class CrossEn(nn.Module):
    """Contrastive cross entropy whose targets are the diagonal of a similarity matrix."""

    def __init__(self):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
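

if __name__ == '__main__':
    # Loss sanity checks (a sketch, independent of any data on disk).
    torch.manual_seed(0)
    sim = torch.eye(4) * 5 + torch.randn(4, 4) * 0.1  # near-diagonal similarities
    print(f"CrossEn on a near-diagonal matrix: {CrossEn()(sim).item():.4f}")
    f1, f2 = torch.randn(8, 32), torch.randn(8, 32)
    print(f"Frame-to-frame KL loss: {KLDivergenceLoss()(f1, f2).item():.4f}")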
@@ -0,0 +1,73 @@
import numpy as np
from sklearn.cluster import SpectralClustering, AgglomerativeClustering, KMeans
from sklearn.metrics.pairwise import cosine_similarity


class VideoClusterer:
    def __init__(self, clustering_method='uniform', n_clusters=2, similarity_threshold=0.8):
        self.n_clusters = n_clusters
        self.similarity_threshold = similarity_threshold
        self.clustering_method = clustering_method
        # Decide on the clustering method to use
        if clustering_method == 'uniform':
            self.clusterer = self.uniform_clustering
        elif clustering_method == 'spectral':
            self.clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed')
        elif clustering_method == 'agglomerative':
            self.clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
        elif clustering_method == 'kmeans':
            self.clusterer = KMeans(n_clusters=n_clusters, n_init=1)
        else:
            raise ValueError(f"Invalid clustering method: {clustering_method}")

    def uniform_clustering(self, features):
        # Split frame indices into n_clusters contiguous, near-equal segments
        n = len(features)
        clusters = []
        cluster_size = n // self.n_clusters
        remainder = n % self.n_clusters
        start = 0
        for i in range(self.n_clusters):
            if i < remainder:
                end = start + cluster_size + 1
            else:
                end = start + cluster_size
            clusters.append(list(range(start, end)))
            start = end
        return clusters

    def detect_outliers(self, features):
        dot_product_matrix = features.dot(features.T)
        average_similarities = np.mean(dot_product_matrix, axis=0)
        # A small epsilon keeps the standard deviation away from zero
        epsilon = 1e-8
        normal = (average_similarities - np.mean(average_similarities)) / (np.std(average_similarities) + epsilon)
        outlier_mask = np.logical_or(normal > 1.5, normal < -1.5)
        return outlier_mask

    def get_clusters(self, features):
        features = features.cpu().numpy()
        if self.clustering_method == 'uniform':
            return self.uniform_clustering(features)
        else:
            # For non-uniform methods, drop outlier frames first (as long as
            # enough frames remain to form the clusters)
            kept_indices = np.arange(len(features))
            outlier_mask = self.detect_outliers(features)
            if np.sum(~outlier_mask) > self.n_clusters:
                features = features[~outlier_mask]
                kept_indices = kept_indices[~outlier_mask]
            # Spectral clustering expects a precomputed affinity matrix
            if self.clustering_method == 'spectral':
                similarity_matrix = cosine_similarity(features)
                labels = self.clusterer.fit_predict(similarity_matrix)
            else:
                # Agglomerative and k-means cluster the raw features directly
                labels = self.clusterer.fit_predict(features)
            # Organize frames into clusters, mapping positions in the filtered
            # array back to original frame indices
            clusters = [[] for _ in range(self.n_clusters)]
            for idx, label in enumerate(labels):
                clusters[label].append(int(kept_indices[idx]))
            # Drop empty clusters so downstream per-cluster frame selection
            # never sees an empty index list
            return [cluster for cluster in clusters if cluster]
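

if __name__ == '__main__':
    # Example (a sketch): cluster 20 random frame features into 4 groups.
    # Spectral clustering is omitted here because random features can yield a
    # poorly conditioned affinity matrix.
    import torch

    features = torch.randn(20, 64)
    for method in ['uniform', 'kmeans', 'agglomerative']:
        clusters = VideoClusterer(method, n_clusters=4).get_clusters(features)
        print(method, [len(c) for c in clusters])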
@@ -0,0 +1,69 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import numpy as np
import torch


def compute_metrics(x):
    # Rank of each diagonal (ground-truth) entry within its row, 0-indexed
    sx = np.sort(-x, axis=1)
    d = np.diag(-x)
    d = d[:, np.newaxis]
    ind = sx - d
    ind = np.where(ind == 0)
    ind = ind[1]
    metrics = {}
    metrics['R@1'] = float(np.sum(ind == 0)) * 100 / len(ind)
    metrics['R@5'] = float(np.sum(ind < 5)) * 100 / len(ind)
    metrics['R@10'] = float(np.sum(ind < 10)) * 100 / len(ind)
    metrics["MedianR"] = np.median(ind) + 1
    metrics["MeanR"] = np.mean(ind) + 1
    return metrics


def print_computed_metrics(metrics):
    r1 = metrics['R@1']
    r5 = metrics['R@5']
    r10 = metrics['R@10']
    mr = metrics.get('MR', metrics['MedianR'])
    print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr))


# The two functions below come directly from: https://github.com/Deferf/Experiments
def tensor_text_to_video_metrics(sim_tensor, top_k=[1, 5, 10]):
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # Permute sim_tensor so it represents a sequence of text-video similarity
    # matrices, then take the double argsort to place each rank on the diagonal
    stacked_sim_matrices = sim_tensor.permute(1, 0, 2)
    first_argsort = torch.argsort(stacked_sim_matrices, dim=-1, descending=True)
    second_argsort = torch.argsort(first_argsort, dim=-1, descending=False)
    # Extract the ranks, i.e. the diagonals
    ranks = torch.flatten(torch.diagonal(second_argsort, dim1=1, dim2=2))
    # Keep only valid ranks; some positions belong to inf padding values
    permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1=0, dim2=2))
    mask = ~torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data))
    valid_ranks = ranks[mask]
    if not torch.is_tensor(valid_ranks):
        valid_ranks = torch.tensor(valid_ranks)
    results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k}
    results["MedianR"] = float(torch.median(valid_ranks + 1))
    results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1))
    results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1))
    results['MR'] = results["MedianR"]
    return results


def tensor_video_to_text_sim(sim_tensor):
    if not torch.is_tensor(sim_tensor):
        sim_tensor = torch.tensor(sim_tensor)
    # Replace NaNs so they cannot win the max below
    sim_tensor[sim_tensor != sim_tensor] = float('-inf')
    # For each video, keep its best-matching caption score to form a
    # video-to-text similarity matrix for rank-at-k
    values, _ = torch.max(sim_tensor, dim=1, keepdim=True)
    return torch.squeeze(values).T
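

if __name__ == '__main__':
    # Sanity check (a sketch): a strongly diagonal similarity matrix should
    # give perfect R@1 and a median rank of 1.
    rng = np.random.RandomState(0)
    sim = np.eye(10) * 2 + rng.rand(10, 10) * 0.1
    print_computed_metrics(compute_metrics(sim))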