| import argparse | |||||
| import os | |||||
| import torch | |||||
| media_path_root = "path/to/media" | |||||
| class Config: | |||||
| def __init__(self): | |||||
| self.parser = argparse.ArgumentParser(description='Train a student network.') | |||||
| self.add_arguments() | |||||
| self.args = self.parse() | |||||
| self.update_paths() | |||||
| def add_arguments(self): | |||||
| # General settings | |||||
| self.parser.add_argument('--seed', type=int, default=12345, help='Init Seed') | |||||
| self.parser.add_argument('--device', type=int, default=0, help='Device number to use for training') | |||||
| self.parser.add_argument("--do_train", type=int, default=1, help="Whether to run training.") | |||||
| self.parser.add_argument("--do_eval", type=int, default=1, help="Whether to run evaluation.") | |||||
| self.parser.add_argument('--dataset', type=str, default='activity-net', choices=['activity-net', 'msrvtt'], help='Dataset to use') | |||||
| self.parser.add_argument('--teacher_model', type=str, default='CLIP', choices=['CLIP', 'DiffusionRet', 'EMCL', 'HBI', 'DiCoSA'], help='Teacher model to use') | |||||
        self.parser.add_argument('--workers', type=int, default=12, help='Number of workers')
        # 'project_name' is referenced in update_paths() when building the log directory; the default below is an assumed placeholder.
        self.parser.add_argument('--project_name', type=str, default='saliency_student', help='Project name used to group log files')
| # Model configurations | |||||
| self.parser.add_argument('--dropout_ratio', type=float, default=0.3, help='Dropout ratio for the model') | |||||
| self.parser.add_argument('--fc1_size', type=int, default=512, help='First FC layer Hidden size') | |||||
| self.parser.add_argument('--fc2_size', type=int, default=256, help='Second FC layer Hidden size') | |||||
| self.parser.add_argument('--fc3_size', type=int, default=128, help='Third FC layer Hidden size') | |||||
| self.parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'], help='Backbone model') | |||||
| self.parser.add_argument('--n_trainable_backbone_layers', type=int, default=1, help='Number of trainable backbone layers') | |||||
| self.parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'leakyrelu', 'gelu'], help='Activation function') | |||||
        self.parser.add_argument('--normalization_layer', type=str, default='sigmoid', choices=['sigmoid', 'softmax', 'gumbel', 'k_top'], help='Normalization layer type (for the text-to-video retrieval pipeline)')
| self.parser.add_argument('--init_tau', default=5.0, type=float, help="annealing init temperature") | |||||
| self.parser.add_argument('--min_tau', default=0.5, type=float, help="min temperature to anneal to") | |||||
| self.parser.add_argument('--decay_factor', default=0.045, type=float, help="exp decay factor per epoch") | |||||
        self.parser.add_argument('--store_complete_model', type=int, default=0, help='Store the complete model')
        self.parser.add_argument('--load_complete_model', type=int, default=0, help='Load a previously trained complete model (1) or not (0)')
| self.parser.add_argument('--trained_model_epoch', type=int, default=100, help='Load the trained model from which epoch') | |||||
| self.parser.add_argument('--store_backbone', type=int, default=0, help='Store just backbone') | |||||
        self.parser.add_argument('--load_backbone', type=int, default=1, help='Load a previously trained backbone (1) or not (0)')
| self.parser.add_argument('--trained_backbone_epoch', type=int, default=100, help='Load a trained backbone model from which epoch') | |||||
| # Dataloader configurations | |||||
| self.parser.add_argument('--normalization', type=str, default='binary', choices=['raw', 'min_max', 'binary'], help='Normalization type for dataloader') | |||||
| self.parser.add_argument('--binary_percentage', type=float, default=0.1, help='Binary percentage for dataloader') | |||||
| # Training configurations | |||||
| self.parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training') | |||||
| self.parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for training') | |||||
| self.parser.add_argument('--similarity_matrix_loss', type=str, default='kl', choices=['mse','kl'], help='frame2frame loss') | |||||
        self.parser.add_argument('--use_weighted_loss', type=int, default=0, help='Whether to use weighted BCE to handle unbalanced classes')
| self.parser.add_argument('--alpha', type=float, default=1, help='Weight for f2f loss') | |||||
| self.parser.add_argument('--beta', type=float, default=1, help='Weight for f2t loss') | |||||
| self.parser.add_argument('--gamma', type=float, default=1, help='Weight for v2t loss') | |||||
| self.parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') | |||||
| self.parser.add_argument('--use_scheduler', type=int, default=1, help='Use a learning rate scheduler') | |||||
        self.parser.add_argument('--scheduler_step_size', type=int, default=5, help='After how many epochs the learning rate will be adjusted')
| self.parser.add_argument('--scheduler_gamma', type=float, default=0.4, help='The factor by which the learning rate will be reduced at each step.') | |||||
| self.parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay for the optimizer. Set 0 to disable it') | |||||
        self.parser.add_argument('--do_cluster', type=int, default=1, help='Whether to also compute retrieval results with clustering-based frame selection')
| def parse(self): | |||||
| return self.parser.parse_args() | |||||
| def update_paths(self): | |||||
| self.args.device = torch.device(f'cuda:{self.args.device}' if torch.cuda.is_available() else "cpu") | |||||
| self.args.raw_backbone_path = f"{media_path_root}/saved_models/student_pretrained_backbones/{self.args.backbone}_pretrained.pth" | |||||
| self.args.st_scores_path = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/st_scores/st_scores.json" | |||||
| self.args.trained_backbone_path = f"{media_path_root}/experiments/saved_models/{self.args.backbone}/{self.args.similarity_matrix_loss}" | |||||
| self.args.trained_model_path = f"{self.args.trained_backbone_path}/complete_model/" | |||||
| self.args.log_dir = f'logs/{self.args.project_name}/' | |||||
        os.makedirs(os.path.dirname(self.args.st_scores_path), exist_ok=True)
| os.makedirs(self.args.trained_backbone_path, exist_ok=True) | |||||
| os.makedirs(self.args.trained_model_path, exist_ok=True) | |||||
| os.makedirs(self.args.log_dir, exist_ok=True) | |||||
| self.args.paths = {'train':{}, 'valid':{}} | |||||
| for data_split in ['train', 'valid']: | |||||
| split = 'val_1' if data_split=='valid' else data_split | |||||
| self.args.paths[data_split]['frames'] = f"{media_path_root}/{self.args.dataset}/sampled_frames/{split}" | |||||
| self.args.paths[data_split]['frame_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/frames/{split}" | |||||
| self.args.paths[data_split]['text_features'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/embeddings/description/{split}" | |||||
| self.args.paths[data_split]['gt_scores'] = f"{media_path_root}/{self.args.dataset}/{self.args.teacher_model}/gt_scores/gt_{split}.json" |
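
# Illustrative usage sketch only (not part of the training pipeline): build the
# default configuration and inspect a few derived values. Note that update_paths()
# creates directories under media_path_root, so the placeholder above should point
# at a real location before running this.
if __name__ == "__main__":
    cfg = Config().args
    print(f"dataset={cfg.dataset}, teacher={cfg.teacher_model}, backbone={cfg.backbone}")
    print(f"device={cfg.device}")
    print(f"trained backbones will be stored under {cfg.trained_backbone_path}")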
| import cv2 | |||||
| from PIL import Image | |||||
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode | |||||
| class RawVideoExtractor(): | |||||
| def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True): | |||||
| self.centercrop = centercrop | |||||
| self.framerate = framerate | |||||
| self.to_tensor = to_tensor | |||||
| self.transform = self._transform(size) | |||||
| def _transform(self, n_px): | |||||
| if self.to_tensor: | |||||
| return Compose([ | |||||
| Resize(n_px, interpolation=InterpolationMode.BICUBIC), | |||||
| CenterCrop(n_px), ToTensor(), | |||||
| Normalize((0.48145466, 0.4578275, 0.40821073), | |||||
| (0.26862954, 0.26130258, 0.27577711))]) | |||||
| else: | |||||
| return Compose([Resize(n_px, interpolation=InterpolationMode.BICUBIC),CenterCrop(n_px)]) | |||||
| def get_video_data(self, video_path, start_time=None, end_time=None): | |||||
| if start_time is not None or end_time is not None: | |||||
| assert start_time > -1 and end_time > start_time | |||||
| assert self.framerate > -1 | |||||
| cap = cv2.VideoCapture(video_path) | |||||
| frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||||
| video_fps = cap.get(cv2.CAP_PROP_FPS) | |||||
| start_frame = int(start_time * video_fps) if start_time else 0 | |||||
| end_frame = int(end_time * video_fps) if end_time else frame_count - 1 | |||||
| interval = 1 | |||||
| if self.framerate > 0: | |||||
| interval = video_fps / self.framerate | |||||
| else: | |||||
| self.framerate = video_fps | |||||
| if interval == 0: | |||||
| interval = 1 | |||||
| images = [] | |||||
| for i in range(frame_count): | |||||
| ret, frame = cap.read() | |||||
| if not ret: | |||||
| break | |||||
| if i >= start_frame and i <= end_frame: | |||||
| if len(images) * interval < i - start_frame: | |||||
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||||
| image = Image.fromarray(frame_rgb) | |||||
| image = self.transform(image) | |||||
| images.append(image) | |||||
| cap.release() | |||||
| return images |
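
# Illustrative usage sketch only: decode a clip at roughly 1 fps and stack the
# frames into a tensor ready for the student backbone. "path/to/video.mp4" is a
# placeholder path; with to_tensor=True every frame is a normalized 3x224x224 tensor.
if __name__ == "__main__":
    import torch
    extractor = RawVideoExtractor(framerate=1, size=224, to_tensor=True)
    frames = extractor.get_video_data("path/to/video.mp4", start_time=0, end_time=10)
    if frames:
        video_tensor = torch.stack(frames)  # (num_frames, 3, 224, 224)
        print(video_tensor.shape)
    else:
        print("No frames decoded; check the video path.")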
| import os | |||||
| import json | |||||
| import torch | |||||
| from torchvision.io import read_image | |||||
| from torch.utils.data import Dataset | |||||
| from torchvision import transforms | |||||
| class VideoDescriptionDataloader(Dataset): | |||||
| def __init__(self, cfg, data_type): | |||||
| self.frames_dir = cfg.paths[data_type]['frames'] | |||||
| self.frame_features_dir = cfg.paths[data_type]['frame_features'] | |||||
| self.text_features_dir = cfg.paths[data_type]['text_features'] | |||||
| self.gt_scores_file_path = cfg.paths[data_type]['gt_scores'] | |||||
| self.transform = self._get_transforms() | |||||
| self.binary_percentage = cfg.binary_percentage | |||||
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        elif cfg.normalization == 'raw':
            # 'raw' is listed in the config choices: keep the teacher scores unchanged.
            self.normalize_scores = lambda score: score
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")
| with open(self.gt_scores_file_path, "r") as f: | |||||
| self.gt_scores = json.load(f) | |||||
| self.video_ids = list(self.gt_scores.keys()) | |||||
| def __len__(self): | |||||
| return len(self.video_ids) | |||||
| def __getitem__(self, idx): | |||||
        video_id = self.video_ids[idx]
        video_frames_dir = os.path.join(self.frames_dir, video_id)
        # read_image returns uint8 values in [0, 255]; scale to [0, 1] before applying the ImageNet normalization.
        frames_tensor = torch.stack([self.transform(read_image(os.path.join(video_frames_dir, frame_file)).float() / 255.0)
                                     for frame_file in sorted(os.listdir(video_frames_dir))])
        text_features = torch.load(os.path.join(self.text_features_dir, f'{video_id}.pt'))
        frame_features = torch.load(os.path.join(self.frame_features_dir, f'{video_id}.pt'))
        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
| gt_score = self.normalize_scores(gt_score) | |||||
| return frames_tensor, text_features, frame_features, gt_score | |||||
| @staticmethod | |||||
| def _get_transforms(): | |||||
| return transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406], | |||||
| [0.229, 0.224, 0.225])]) | |||||
| @staticmethod | |||||
| def _min_max_normalize_scores(score): | |||||
| min_val, max_val = score.min(), score.max() | |||||
| if min_val != max_val: | |||||
| return (score - min_val) / (max_val - min_val) | |||||
| return torch.full_like(score, 0.5) | |||||
| def _to_binary_labels(self, score): | |||||
| num_top_elements = max(int(len(score) * self.binary_percentage) , 1) | |||||
| sorted_indices = score.argsort(descending=True) | |||||
| binary_labels = torch.zeros_like(score) | |||||
| binary_labels[sorted_indices[:num_top_elements]] = 1 | |||||
| return binary_labels |
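
# Illustrative usage sketch only: wrap the dataset in a DataLoader the same way
# main.py does (batch size stays at the default of 1 because videos have a
# variable number of frames). `cfg` is assumed to be the parsed Config().args
# namespace, and the paths it points to must already contain the extracted data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from config import Config

    cfg = Config().args
    dataset = VideoDescriptionDataloader(cfg, data_type='train')
    loader = DataLoader(dataset, shuffle=True, num_workers=0)
    frames, text_features, frame_features, gt_score = next(iter(loader))
    print(frames.shape, frame_features.shape, text_features.shape, gt_score.shape)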
| video6513 | |||||
| video6514 | |||||
| video6515 | |||||
| video6516 | |||||
| video6517 | |||||
| video6518 | |||||
| video6519 | |||||
| video6520 | |||||
| video6521 | |||||
| video6522 | |||||
| video6523 | |||||
| video6524 | |||||
| video6525 | |||||
| video6526 | |||||
| video6527 | |||||
| video6528 | |||||
| video6529 | |||||
| video6530 | |||||
| video6531 | |||||
| video6532 | |||||
| video6533 | |||||
| video6534 | |||||
| video6535 | |||||
| video6536 | |||||
| video6537 | |||||
| video6538 | |||||
| video6539 | |||||
| video6540 | |||||
| video6541 | |||||
| video6542 | |||||
| video6543 | |||||
| video6544 | |||||
| video6545 | |||||
| video6546 | |||||
| video6547 | |||||
| video6548 | |||||
| video6549 | |||||
| video6550 | |||||
| video6551 | |||||
| video6552 | |||||
| video6553 | |||||
| video6554 | |||||
| video6555 | |||||
| video6556 | |||||
| video6557 | |||||
| video6558 | |||||
| video6559 | |||||
| video6560 | |||||
| video6561 | |||||
| video6562 | |||||
| video6563 | |||||
| video6564 | |||||
| video6565 | |||||
| video6566 | |||||
| video6567 | |||||
| video6568 | |||||
| video6569 | |||||
| video6570 | |||||
| video6571 | |||||
| video6572 | |||||
| video6573 | |||||
| video6574 | |||||
| video6575 | |||||
| video6576 | |||||
| video6577 | |||||
| video6578 | |||||
| video6579 | |||||
| video6580 | |||||
| video6581 | |||||
| video6582 | |||||
| video6583 | |||||
| video6584 | |||||
| video6585 | |||||
| video6586 | |||||
| video6587 | |||||
| video6588 | |||||
| video6589 | |||||
| video6590 | |||||
| video6591 | |||||
| video6592 | |||||
| video6593 | |||||
| video6594 | |||||
| video6595 | |||||
| video6596 | |||||
| video6597 | |||||
| video6598 | |||||
| video6599 | |||||
| video6600 | |||||
| video6601 | |||||
| video6602 | |||||
| video6603 | |||||
| video6604 | |||||
| video6605 | |||||
| video6606 | |||||
| video6607 | |||||
| video6608 | |||||
| video6609 | |||||
| video6610 | |||||
| video6611 | |||||
| video6612 | |||||
| video6613 | |||||
| video6614 | |||||
| video6615 | |||||
| video6616 | |||||
| video6617 | |||||
| video6618 | |||||
| video6619 | |||||
| video6620 | |||||
| video6621 | |||||
| video6622 | |||||
| video6623 | |||||
| video6624 | |||||
| video6625 | |||||
| video6626 | |||||
| video6627 | |||||
| video6628 | |||||
| video6629 | |||||
| video6630 | |||||
| video6631 | |||||
| video6632 | |||||
| video6633 | |||||
| video6634 | |||||
| video6635 | |||||
| video6636 | |||||
| video6637 | |||||
| video6638 | |||||
| video6639 | |||||
| video6640 | |||||
| video6641 | |||||
| video6642 | |||||
| video6643 | |||||
| video6644 | |||||
| video6645 | |||||
| video6646 | |||||
| video6647 | |||||
| video6648 | |||||
| video6649 | |||||
| video6650 | |||||
| video6651 | |||||
| video6652 | |||||
| video6653 | |||||
| video6654 | |||||
| video6655 | |||||
| video6656 | |||||
| video6657 | |||||
| video6658 | |||||
| video6659 | |||||
| video6660 | |||||
| video6661 | |||||
| video6662 | |||||
| video6663 | |||||
| video6664 | |||||
| video6665 | |||||
| video6666 | |||||
| video6667 | |||||
| video6668 | |||||
| video6669 | |||||
| video6670 | |||||
| video6671 | |||||
| video6672 | |||||
| video6673 | |||||
| video6674 | |||||
| video6675 | |||||
| video6676 | |||||
| video6677 | |||||
| video6678 | |||||
| video6679 | |||||
| video6680 | |||||
| video6681 | |||||
| video6682 | |||||
| video6683 | |||||
| video6684 | |||||
| video6685 | |||||
| video6686 | |||||
| video6687 | |||||
| video6688 | |||||
| video6689 | |||||
| video6690 | |||||
| video6691 | |||||
| video6692 | |||||
| video6693 | |||||
| video6694 | |||||
| video6695 | |||||
| video6696 | |||||
| video6697 | |||||
| video6698 | |||||
| video6699 | |||||
| video6700 | |||||
| video6701 | |||||
| video6702 | |||||
| video6703 | |||||
| video6704 | |||||
| video6705 | |||||
| video6706 | |||||
| video6707 | |||||
| video6708 | |||||
| video6709 | |||||
| video6710 | |||||
| video6711 | |||||
| video6712 | |||||
| video6713 | |||||
| video6714 | |||||
| video6715 | |||||
| video6716 | |||||
| video6717 | |||||
| video6718 | |||||
| video6719 | |||||
| video6720 | |||||
| video6721 | |||||
| video6722 | |||||
| video6723 | |||||
| video6724 | |||||
| video6725 | |||||
| video6726 | |||||
| video6727 | |||||
| video6728 | |||||
| video6729 | |||||
| video6730 | |||||
| video6731 | |||||
| video6732 | |||||
| video6733 | |||||
| video6734 | |||||
| video6735 | |||||
| video6736 | |||||
| video6737 | |||||
| video6738 | |||||
| video6739 | |||||
| video6740 | |||||
| video6741 | |||||
| video6742 | |||||
| video6743 | |||||
| video6744 | |||||
| video6745 | |||||
| video6746 | |||||
| video6747 | |||||
| video6748 | |||||
| video6749 | |||||
| video6750 | |||||
| video6751 | |||||
| video6752 | |||||
| video6753 | |||||
| video6754 | |||||
| video6755 | |||||
| video6756 | |||||
| video6757 | |||||
| video6758 | |||||
| video6759 | |||||
| video6760 | |||||
| video6761 | |||||
| video6762 | |||||
| video6763 | |||||
| video6764 | |||||
| video6765 | |||||
| video6766 | |||||
| video6767 | |||||
| video6768 | |||||
| video6769 | |||||
| video6770 | |||||
| video6771 | |||||
| video6772 | |||||
| video6773 | |||||
| video6774 | |||||
| video6775 | |||||
| video6776 | |||||
| video6777 | |||||
| video6778 | |||||
| video6779 | |||||
| video6780 | |||||
| video6781 | |||||
| video6782 | |||||
| video6783 | |||||
| video6784 | |||||
| video6785 | |||||
| video6786 | |||||
| video6787 | |||||
| video6788 | |||||
| video6789 | |||||
| video6790 | |||||
| video6791 | |||||
| video6792 | |||||
| video6793 | |||||
| video6794 | |||||
| video6795 | |||||
| video6796 | |||||
| video6797 | |||||
| video6798 | |||||
| video6799 | |||||
| video6800 | |||||
| video6801 | |||||
| video6802 | |||||
| video6803 | |||||
| video6804 | |||||
| video6805 | |||||
| video6806 | |||||
| video6807 | |||||
| video6808 | |||||
| video6809 | |||||
| video6810 | |||||
| video6811 | |||||
| video6812 | |||||
| video6813 | |||||
| video6814 | |||||
| video6815 | |||||
| video6816 | |||||
| video6817 | |||||
| video6818 | |||||
| video6819 | |||||
| video6820 | |||||
| video6821 | |||||
| video6822 | |||||
| video6823 | |||||
| video6824 | |||||
| video6825 | |||||
| video6826 | |||||
| video6827 | |||||
| video6828 | |||||
| video6829 | |||||
| video6830 | |||||
| video6831 | |||||
| video6832 | |||||
| video6833 | |||||
| video6834 | |||||
| video6835 | |||||
| video6836 | |||||
| video6837 | |||||
| video6838 | |||||
| video6839 | |||||
| video6840 | |||||
| video6841 | |||||
| video6842 | |||||
| video6843 | |||||
| video6844 | |||||
| video6845 | |||||
| video6846 | |||||
| video6847 | |||||
| video6848 | |||||
| video6849 | |||||
| video6850 | |||||
| video6851 | |||||
| video6852 | |||||
| video6853 | |||||
| video6854 | |||||
| video6855 | |||||
| video6856 | |||||
| video6857 | |||||
| video6858 | |||||
| video6859 | |||||
| video6860 | |||||
| video6861 | |||||
| video6862 | |||||
| video6863 | |||||
| video6864 | |||||
| video6865 | |||||
| video6866 | |||||
| video6867 | |||||
| video6868 | |||||
| video6869 | |||||
| video6870 | |||||
| video6871 | |||||
| video6872 | |||||
| video6873 | |||||
| video6874 | |||||
| video6875 | |||||
| video6876 | |||||
| video6877 | |||||
| video6878 | |||||
| video6879 | |||||
| video6880 | |||||
| video6881 | |||||
| video6882 | |||||
| video6883 | |||||
| video6884 | |||||
| video6885 | |||||
| video6886 | |||||
| video6887 | |||||
| video6888 | |||||
| video6889 | |||||
| video6890 | |||||
| video6891 | |||||
| video6892 | |||||
| video6893 | |||||
| video6894 | |||||
| video6895 | |||||
| video6896 | |||||
| video6897 | |||||
| video6898 | |||||
| video6899 | |||||
| video6900 | |||||
| video6901 | |||||
| video6902 | |||||
| video6903 | |||||
| video6904 | |||||
| video6905 | |||||
| video6906 | |||||
| video6907 | |||||
| video6908 | |||||
| video6909 | |||||
| video6910 | |||||
| video6911 | |||||
| video6912 | |||||
| video6913 | |||||
| video6914 | |||||
| video6915 | |||||
| video6916 | |||||
| video6917 | |||||
| video6918 | |||||
| video6919 | |||||
| video6920 | |||||
| video6921 | |||||
| video6922 | |||||
| video6923 | |||||
| video6924 | |||||
| video6925 | |||||
| video6926 | |||||
| video6927 | |||||
| video6928 | |||||
| video6929 | |||||
| video6930 | |||||
| video6931 | |||||
| video6932 | |||||
| video6933 | |||||
| video6934 | |||||
| video6935 | |||||
| video6936 | |||||
| video6937 | |||||
| video6938 | |||||
| video6939 | |||||
| video6940 | |||||
| video6941 | |||||
| video6942 | |||||
| video6943 | |||||
| video6944 | |||||
| video6945 | |||||
| video6946 | |||||
| video6947 | |||||
| video6948 | |||||
| video6949 | |||||
| video6950 | |||||
| video6951 | |||||
| video6952 | |||||
| video6953 | |||||
| video6954 | |||||
| video6955 | |||||
| video6956 | |||||
| video6957 | |||||
| video6958 | |||||
| video6959 | |||||
| video6960 | |||||
| video6961 | |||||
| video6962 | |||||
| video6963 | |||||
| video6964 | |||||
| video6965 | |||||
| video6966 | |||||
| video6967 | |||||
| video6968 | |||||
| video6969 | |||||
| video6970 | |||||
| video6971 | |||||
| video6972 | |||||
| video6973 | |||||
| video6974 | |||||
| video6975 | |||||
| video6976 | |||||
| video6977 | |||||
| video6978 | |||||
| video6979 | |||||
| video6980 | |||||
| video6981 | |||||
| video6982 | |||||
| video6983 | |||||
| video6984 | |||||
| video6985 | |||||
| video6986 | |||||
| video6987 | |||||
| video6988 | |||||
| video6989 | |||||
| video6990 | |||||
| video6991 | |||||
| video6992 | |||||
| video6993 | |||||
| video6994 | |||||
| video6995 | |||||
| video6996 | |||||
| video6997 | |||||
| video6998 | |||||
| video6999 | |||||
| video7000 | |||||
| video7001 | |||||
| video7002 | |||||
| video7003 | |||||
| video7004 | |||||
| video7005 | |||||
| video7006 | |||||
| video7007 | |||||
| video7008 | |||||
| video7009 |
| import os | |||||
| import torch | |||||
| from config import Config | |||||
| from src.model import SaliencyNet | |||||
| from torch.utils.data import DataLoader | |||||
| from src.train import SalientFrameSamplerTrainer | |||||
| from dataloaders.t2v_dataloader import VideoDescriptionDataloader | |||||
| def load_weights(path, which_epoch): | |||||
| weights, last_epoch = None, None | |||||
| available_models = [name for name in os.listdir(path) if name.endswith(".pt")] | |||||
| if available_models: | |||||
| last_epoch = max([int(name[6:-3]) for name in available_models]) | |||||
| last_epoch = min(last_epoch, which_epoch) | |||||
| weights = torch.load(os.path.join(path, f'epoch_{last_epoch}.pt')) | |||||
| return weights, last_epoch | |||||
| def set_seeds(seed: int): | |||||
| os.environ['PYTHONHASHSEED'] = str(seed) | |||||
| torch.manual_seed(seed) | |||||
| torch.cuda.manual_seed(seed) | |||||
| torch.backends.cudnn.benchmark = False | |||||
| torch.backends.cudnn.deterministic = True | |||||
| def main(cfg): | |||||
| set_seeds(cfg.seed) | |||||
| model = SaliencyNet(cfg) | |||||
| model.to(cfg.device) | |||||
| if cfg.load_complete_model: | |||||
| weights, last_epoch = load_weights(cfg.trained_model_path, cfg.trained_model_epoch) | |||||
| if weights: | |||||
| model.load_state_dict(weights) | |||||
            print(f'Complete model ({cfg.backbone}) trained with the {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}')
| elif cfg.load_backbone: | |||||
| weights, last_epoch = load_weights(cfg.trained_backbone_path, cfg.trained_backbone_epoch) | |||||
| if weights: | |||||
| model.pretrained.load_state_dict(weights) | |||||
| print(f'{cfg.backbone} backbone trained on {cfg.similarity_matrix_loss} loss loaded from epoch #{last_epoch}') | |||||
| else: | |||||
| print(f'{cfg.backbone} backbone loaded from scratch.') | |||||
| train_dataset = VideoDescriptionDataloader(cfg, data_type='train') | |||||
| val_dataset = VideoDescriptionDataloader(cfg, data_type='valid') | |||||
| train_dataloader = DataLoader(train_dataset, shuffle=True, num_workers=cfg.workers) | |||||
| val_dataloader = DataLoader(val_dataset, shuffle=False, num_workers=cfg.workers) | |||||
| trainer = SalientFrameSamplerTrainer(model, train_dataloader, val_dataloader, cfg) | |||||
| trainer.train() | |||||
| if __name__ == "__main__": | |||||
| torch.multiprocessing.set_start_method("spawn") | |||||
| cfg = Config().args | |||||
| main(cfg) |
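
# Example invocation (illustrative; values are placeholders, all flags are defined in config.py):
#   python main.py --dataset msrvtt --teacher_model CLIP --backbone resnet18 \
#       --batch_size 16 --epochs 20 --lr 1e-4 --similarity_matrix_loss kl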
| torch | |||||
| opencv-python | |||||
| Pillow | |||||
| timm | |||||
| transformers | |||||
| torchvision | |||||
scikit-learn
numpy
tqdm
tabulate
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| class GumbelSoftmaxTopK(nn.Module): | |||||
| def __init__(self, k, tau): | |||||
| super(GumbelSoftmaxTopK, self).__init__() | |||||
| self.k = k | |||||
| self.tau = tau | |||||
    def forward(self, logits):
        """
        Relaxed top-k sampling over frame logits with Gumbel noise.
        Adapted from: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html#Subset-Sampler-Class
        and https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
        """
        logits = torch.log(F.softmax(logits, dim=0).clamp(min=1e-8))
| m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format), | |||||
| torch.ones_like(logits, memory_format= torch.legacy_contiguous_format)) | |||||
| gumbels = m.sample() | |||||
| y = logits + gumbels | |||||
| _EPS = torch.tensor(1e-40).to(logits.device) | |||||
| khot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) | |||||
| onehot_approx = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format) | |||||
| for i in range(self.k): | |||||
| khot_mask = torch.maximum(1.0 - onehot_approx, _EPS) | |||||
| y += khot_mask.log() | |||||
| onehot_approx = nn.Softmax(dim=0)(y / self.tau) | |||||
| khot = torch.add(khot, onehot_approx) | |||||
| return khot | |||||
| class GumbelSoftmax(nn.Module): | |||||
| def __init__(self, tau, hard=False): | |||||
| super(GumbelSoftmax, self).__init__() | |||||
| self.tau = tau | |||||
| self.hard = hard | |||||
| def forward(self, logits): | |||||
| m = torch.distributions.gumbel.Gumbel(torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format), | |||||
| torch.ones_like(logits, memory_format= torch.legacy_contiguous_format)) | |||||
| gumbels = m.sample() | |||||
| y = logits + gumbels | |||||
| y_soft = nn.Softmax(dim=0)(y / self.tau) | |||||
| return y_soft |
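
# Illustrative sanity check only: the relaxed k-hot vector from GumbelSoftmaxTopK
# sums to approximately k, while GumbelSoftmax returns a distribution over frames
# that sums to 1.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(16)  # e.g. one saliency logit per frame
    khot = GumbelSoftmaxTopK(k=4, tau=1.0)(logits)
    print(khot.sum())  # ~4.0
    soft = GumbelSoftmax(tau=1.0)(logits)
    print(soft.sum())  # ~1.0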
import math
import timm
| import torch | |||||
| import torch.nn as nn | |||||
| from src.gumbel_softmax import GumbelSoftmax, GumbelSoftmaxTopK | |||||
| class SaliencyNet(nn.Module): | |||||
| def __init__(self, cfg): | |||||
| super(SaliencyNet, self).__init__() | |||||
| self.k = 8 | |||||
| self.tau = cfg.init_tau | |||||
| self.min_tau = cfg.min_tau | |||||
| self.decay_factor = cfg.decay_factor | |||||
| self.raw_backbone_path = cfg.raw_backbone_path | |||||
| self.configure_pretrained(cfg.backbone, cfg.n_trainable_backbone_layers) | |||||
| embed_dim = self.pretrained.num_features | |||||
| self.fc1 = nn.Linear(embed_dim, cfg.fc1_size) | |||||
| self.fc2 = nn.Linear(cfg.fc1_size, cfg.fc2_size) | |||||
| self.fc3 = nn.Linear(cfg.fc2_size, cfg.fc3_size) | |||||
| self.fc4 = nn.Linear(cfg.fc3_size, 1) | |||||
| # self.bn1 = nn.BatchNorm1d(fc1_size) # Batch normalization layer | |||||
| self.dropout = nn.Dropout(cfg.dropout_ratio) | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| if cfg.activation == 'relu': | |||||
| self.activation = nn.ReLU() | |||||
| elif cfg.activation == 'leakyrelu': | |||||
| self.activation = nn.LeakyReLU() | |||||
| elif cfg.activation == 'gelu': | |||||
| self.activation = nn.GELU() | |||||
| else: | |||||
| raise ValueError("Invalid activation type. Choose 'relu', 'leakyrelu', or 'gelu'.") | |||||
| self.normalization_layer = self.init_normalization_layer(cfg.normalization_layer, self.tau, self.k) | |||||
| self.initialize_weights() | |||||
| def init_normalization_layer(self, normalization_type, tau, k): | |||||
| if normalization_type == 'raw': | |||||
| normalization_layer = lambda x: x | |||||
| elif normalization_type == 'sigmoid': | |||||
| normalization_layer = nn.Sigmoid() | |||||
| elif normalization_type == 'softmax': | |||||
| normalization_layer = nn.Softmax(dim=1) | |||||
| elif normalization_type == 'gumbel': | |||||
| normalization_layer = GumbelSoftmax(tau) | |||||
| elif normalization_type == 'k_top': | |||||
| normalization_layer = GumbelSoftmaxTopK(k, tau) | |||||
        else:
            raise ValueError(f"Invalid normalization layer type: {normalization_type}")
        return normalization_layer
| def initialize_weights(self): | |||||
| for module in self.modules(): | |||||
| if isinstance(module, nn.Linear): | |||||
| nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') | |||||
| if module.bias is not None: | |||||
| nn.init.constant_(module.bias, 0) | |||||
| def forward(self, preprocessed_images): | |||||
| features = self.pretrained(preprocessed_images) | |||||
| x = self.activation(self.fc1(features)) | |||||
| x = self.dropout(x) | |||||
| x = self.activation(self.fc2(x)) | |||||
| x = self.dropout(x) | |||||
| x = self.activation(self.fc3(x)) | |||||
| x = self.fc4(x) | |||||
| frame_scores = torch.flatten(x) | |||||
| frame_weights = torch.flatten(self.normalization_layer(x)) | |||||
| return frame_scores, frame_weights, features | |||||
    def anneal_tau(self):
        """Anneal the temperature parameter τ with exponential decay, clamped at min_tau."""
        self.tau = max(self.tau * math.exp(-self.decay_factor), self.min_tau)  # prevent τ from becoming too small
        # The Gumbel-based normalization layers keep their own copy of tau, so propagate the annealed value.
        if hasattr(self.normalization_layer, 'tau'):
            self.normalization_layer.tau = self.tau
| def configure_pretrained(self, backbone, n_trainable_backbone_layers=1): | |||||
| # backbones = ['resnet18', 'efficientnet_b0', 'mobilenetv2_100', 'mobilenetv3_large_100'] | |||||
| self.pretrained = timm.create_model(backbone, pretrained=False, checkpoint_path=self.raw_backbone_path) | |||||
| # Freeze the early layers of the model | |||||
| for param in self.pretrained.parameters(): | |||||
| param.requires_grad = False | |||||
| # disable last layer of resnet18 | |||||
| if backbone =='resnet18': | |||||
| self.pretrained.fc = nn.Identity() | |||||
| sequential_modules = [child for child in self.pretrained.children() if isinstance(child, nn.Sequential)] | |||||
| elif backbone in ['efficientnet_b0', 'mobilenetv2_100','mobilenetv3_large_100'] : | |||||
| self.pretrained.classifier = nn.Identity() | |||||
| sequential_modules = [child for child in self.pretrained.blocks.children()] | |||||
| else: | |||||
            raise ValueError(f'Unsupported backbone model: {backbone}')
| num_seq_blocks = len(sequential_modules) | |||||
| if n_trainable_backbone_layers ==0: | |||||
| pass | |||||
| elif n_trainable_backbone_layers >= num_seq_blocks: | |||||
| for param in self.pretrained.parameters(): | |||||
| param.requires_grad = True | |||||
| else: | |||||
| # Select the last `n_trainable_backbone_layers` of sequential modules to be trainable | |||||
| layers_to_train = sequential_modules[-n_trainable_backbone_layers:] | |||||
| for layer in layers_to_train: | |||||
| for param in layer.parameters(): | |||||
| param.requires_grad = True | |||||
| if n_trainable_backbone_layers > 0: | |||||
| last_layers= [] | |||||
| if hasattr(self.pretrained, 'conv_head'): | |||||
| last_layers.append(self.pretrained.conv_head.weight) | |||||
| if hasattr(self.pretrained, 'bn2'): | |||||
| last_layers.append(self.pretrained.bn2.weight) | |||||
| last_layers.append(self.pretrained.bn2.bias) | |||||
| for param in last_layers: | |||||
| param.requires_grad = True |
| import os | |||||
| import torch | |||||
| from tqdm import tqdm, trange | |||||
| import torch.nn as nn | |||||
| import logging | |||||
| import torch.nn.functional as F | |||||
| from tabulate import tabulate | |||||
| from torch.utils.data import DataLoader | |||||
| from utils.cluster_frames import VideoClusterer | |||||
| from utils.metrics import compute_metrics | |||||
| class SalientFrameSamplerTrainer: | |||||
| def __init__(self, model, train_dataloader: DataLoader, val_dataloader: DataLoader, cfg): | |||||
| self.cfg = cfg | |||||
| self.model = model | |||||
| self.train_dataloader = train_dataloader | |||||
| self.val_dataloader = val_dataloader | |||||
| self.init_logging() | |||||
| self.criterion1 = kl_divergence_loss() | |||||
| self.criterion2 = nn.BCEWithLogitsLoss() | |||||
| self.criterion3 = CrossEn() | |||||
| self.optimizer = self.init_optimizer() | |||||
| if self.cfg.use_scheduler: | |||||
| self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.cfg.scheduler_step_size, gamma=self.cfg.scheduler_gamma) | |||||
| def train(self): | |||||
| for epoch in trange(self.cfg.epochs, desc="Epoch", ncols=100): | |||||
| log_data = {} | |||||
| if self.cfg.do_train: | |||||
| train_loss = self.run_training_epoch() | |||||
| log_data.update({ | |||||
| "train_loss(f2f)": train_loss['f2f'], | |||||
| "train_loss(f2t)": train_loss['f2t'], | |||||
| "train_loss(v2t)": train_loss['v2t'], | |||||
| }) | |||||
| if self.cfg.do_eval: | |||||
| valid_loss, t2v_metrics, v2t_metrics = self.run_validation_epoch() | |||||
| log_data.update({ | |||||
| "valid_loss(f2f)": valid_loss['f2f'], | |||||
| "valid_loss(f2t)": valid_loss['f2t'], | |||||
| "valid_loss(v2t)": valid_loss['v2t'], | |||||
| }) | |||||
| for key in t2v_metrics.keys(): | |||||
| for r in [1, 5, 10]: | |||||
| log_data[f"[t2v] R@{r} (k={key})"] = t2v_metrics[key][f'R@{r}'] | |||||
| log_data[f"[v2t] R@{r} (k={key})"] = v2t_metrics[key][f'R@{r}'] | |||||
| self.log_metrics(epoch, log_data) | |||||
| self.model.anneal_tau() | |||||
| if hasattr(self,'scheduler'): | |||||
| self.scheduler.step() | |||||
| if self.cfg.store_backbone: | |||||
| torch.save(self.model.pretrained.state_dict(), os.path.join(self.cfg.trained_backbone_path, f"epoch_{epoch+1}.pt")) | |||||
| if self.cfg.store_complete_model: | |||||
                torch.save(self.model.state_dict(), os.path.join(self.cfg.trained_model_path, f"epoch_{epoch+1}.pt"))
| def run_training_epoch(self): | |||||
| self.model.train() | |||||
| epoch_train_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0} | |||||
| video_embeddings = [] | |||||
| text_embeddings = [] | |||||
| f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||||
| for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.train_dataloader, desc='Train', leave=False, ncols=100)): | |||||
| torch.cuda.empty_cache() | |||||
| frames = frames[0].to(self.cfg.device) | |||||
| teacher_frame_features = teacher_frame_features[0].to(self.cfg.device) | |||||
| teacher_text_features = teacher_text_features[0].to(self.cfg.device) | |||||
| gt_score = gt_score[0].to(self.cfg.device) | |||||
| frame_scores, frame_weights, student_frame_features = self.model(frames) | |||||
| # Frame-to-Frame loss | |||||
| f2f_loss = self.criterion1(student_frame_features, teacher_frame_features) | |||||
| # Frame-to-Text loss | |||||
| f2t_loss = self.criterion2(frame_scores, gt_score) | |||||
| # Update accumulators | |||||
| f2f_loss_accum += f2f_loss.item() / self.cfg.batch_size | |||||
| f2t_loss_accum += f2t_loss.item() / self.cfg.batch_size | |||||
| # Calculating the in-pipeline loss for v2t | |||||
| video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights) | |||||
| video_embeddings.append(video_features) | |||||
| text_embeddings.append(teacher_text_features) | |||||
| if (i + 1) % self.cfg.batch_size == 0 or i == len(self.train_dataloader) - 1: | |||||
| sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings) | |||||
| sim_loss1 = self.criterion3(sim_matrix) | |||||
| sim_loss2 = self.criterion3(sim_matrix.T) | |||||
| v2t_loss = (sim_loss1 + sim_loss2) / 2 | |||||
| # Aggregate the total loss with regularization terms | |||||
                total_loss = self.cfg.alpha * f2f_loss_accum + self.cfg.beta * f2t_loss_accum + self.cfg.gamma * v2t_loss
| self.optimizer.zero_grad() | |||||
| total_loss.backward() | |||||
| # Gradient clipping | |||||
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) | |||||
| self.optimizer.step() | |||||
| # Update epoch losses | |||||
| epoch_train_loss['f2f'] += f2f_loss_accum | |||||
| epoch_train_loss['f2t'] += f2t_loss_accum | |||||
| epoch_train_loss['v2t'] += v2t_loss.item() | |||||
| # Reset accumulators | |||||
| f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||||
| video_embeddings = [] | |||||
| text_embeddings = [] | |||||
| # Normalize epoch losses by the number of batches | |||||
| for key in epoch_train_loss.keys(): | |||||
| epoch_train_loss[key] /= len(self.train_dataloader) | |||||
| return epoch_train_loss | |||||
| def run_validation_epoch(self): | |||||
| self.model.eval() | |||||
| epoch_valid_loss = {'f2f': 0.0, 'f2t': 0.0, 'v2t': 0.0} | |||||
| video_embeddings = [] | |||||
| text_embeddings = [] | |||||
| all_video_embeddings = {str(k): {'without_clustering': [], | |||||
| 'uniform': [], 'spectral': [], 'agglomerative': [], 'kmeans': []} for k in [1,2,4,8,16]} | |||||
| all_text_embeddings = [] | |||||
| all_st_scores = [] | |||||
| f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||||
| for i, (frames, teacher_text_features, teacher_frame_features, gt_score) in enumerate(tqdm(self.val_dataloader, desc='Valid', leave=False, ncols=100)): | |||||
| torch.cuda.empty_cache() | |||||
| frames = frames[0].to(self.cfg.device) | |||||
| teacher_frame_features = teacher_frame_features[0].to(self.cfg.device) | |||||
| teacher_text_features = teacher_text_features[0].to(self.cfg.device) | |||||
| gt_score = gt_score[0].to(self.cfg.device) | |||||
| with torch.no_grad(): | |||||
| frame_scores, frame_weights, student_frame_features = self.model(frames) | |||||
| f2f_loss = self.criterion1(student_frame_features, teacher_frame_features) | |||||
| f2t_loss = self.criterion2(frame_scores, gt_score) | |||||
| f2f_loss_accum += f2f_loss.item() | |||||
| f2t_loss_accum += f2t_loss.item() | |||||
| # Calculating the all embeddings, used for calculating t2v_retrieval metrics | |||||
| all_text_embeddings.append(teacher_text_features) | |||||
| n_frames = teacher_frame_features.size(0) | |||||
            sorted_indices = sorted(range(n_frames), key=lambda idx: frame_scores[idx], reverse=True)
| for k in [1,2,4,8,16]: | |||||
| n_selected_frames = min(n_frames, k) | |||||
| indices = sorted_indices[:n_selected_frames] | |||||
| all_video_embeddings[str(k)]['without_clustering'].append(teacher_frame_features[indices].mean(dim=0)) | |||||
| if self.cfg.do_cluster and k>2: | |||||
| for clustering_method in ['uniform', 'spectral', 'agglomerative', 'kmeans']: | |||||
| if len(frame_scores) <= k: #mean of all frames | |||||
| all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features.mean(dim=0)) | |||||
| else: | |||||
| clusterer = VideoClusterer(clustering_method, k) | |||||
| clusters = clusterer.get_clusters(student_frame_features) | |||||
| selected_indices = self.select_representative_frame_per_cluster(clusters, frame_scores) | |||||
| all_video_embeddings[str(k)][clustering_method].append(teacher_frame_features[selected_indices].mean(dim=0)) | |||||
| # Calculating the in-pipeline loss | |||||
| video_features = torch.matmul(teacher_frame_features.T, frame_weights) / torch.sum(frame_weights) | |||||
| video_embeddings.append(video_features) | |||||
| text_embeddings.append(teacher_text_features) | |||||
| if (i + 1) % self.cfg.batch_size == 0 or i == len(self.val_dataloader) - 1: | |||||
| sim_matrix = self.compute_sim_matrix(video_embeddings, text_embeddings) | |||||
| sim_loss1 = self.criterion3(sim_matrix) | |||||
| sim_loss2 = self.criterion3(sim_matrix.T) | |||||
                v2t_loss = (sim_loss1 + sim_loss2) / 2
| # Update epoch losses | |||||
| epoch_valid_loss['f2f'] += f2f_loss_accum | |||||
| epoch_valid_loss['f2t'] += f2t_loss_accum | |||||
| epoch_valid_loss['v2t'] += v2t_loss.item() | |||||
| # Reset accumulators | |||||
| f2f_loss_accum, f2t_loss_accum = 0.0, 0.0 | |||||
| video_embeddings = [] | |||||
| text_embeddings = [] | |||||
| # Normalize epoch losses by the number of batches | |||||
| for key in epoch_valid_loss.keys(): | |||||
| epoch_valid_loss[key] /= len(self.val_dataloader) | |||||
| # Compute t2v metrics | |||||
| t2v_metrics = {} | |||||
| v2t_metrics = {} | |||||
| for k, embeddings_dict in all_video_embeddings.items(): | |||||
| t2v_metrics[str(k)]={} | |||||
| for scenario, embeddings_list in embeddings_dict.items(): | |||||
| if embeddings_list: # Check if there are any embeddings for this scenario | |||||
| sim_matrix = self.compute_sim_matrix(embeddings_list, all_text_embeddings) | |||||
| _key = f"{k}" if scenario=='without_clustering' else f"{k}-{scenario}" | |||||
| t2v_metrics[_key] = compute_metrics(sim_matrix.detach().cpu().numpy()) | |||||
| v2t_metrics[_key] = compute_metrics(sim_matrix.T.detach().cpu().numpy()) | |||||
| return epoch_valid_loss, t2v_metrics, v2t_metrics | |||||
| def compute_sim_matrix(self, video_embeddings, text_embeddings): | |||||
| video_embeddings_tensor = torch.stack(video_embeddings, dim=0) | |||||
| text_embeddings_tensor = torch.stack(text_embeddings, dim=0) | |||||
| video_embeddings_norm = video_embeddings_tensor / video_embeddings_tensor.norm(dim=-1, keepdim=True) | |||||
| text_embeddings_norm = text_embeddings_tensor / text_embeddings_tensor.norm(dim=-1, keepdim=True) | |||||
| sim_matrix = torch.matmul(video_embeddings_norm, text_embeddings_norm.T) | |||||
| return sim_matrix | |||||
| def init_optimizer(self): | |||||
| if not hasattr(self.cfg, 'backbone_lr_scale'): | |||||
| return torch.optim.Adam(params=self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay) | |||||
| backbone_params = [] | |||||
| last_params = [] | |||||
| for name, param in self.model.named_parameters(): | |||||
| if 'pretrain' in name: | |||||
| backbone_params.append(param) | |||||
| else: | |||||
| last_params.append(param) | |||||
        # Scale the backbone learning rate by backbone_lr_scale; the newly added head keeps the base lr.
        return torch.optim.Adam([
            {'params': backbone_params, 'lr': self.cfg.lr * self.cfg.backbone_lr_scale},
            {'params': last_params, 'lr': self.cfg.lr}
        ], weight_decay=self.cfg.weight_decay)
| def init_logging(self): | |||||
| available_logs = os.listdir(self.cfg.log_dir) | |||||
| last_run = 0 | |||||
| if available_logs: | |||||
| last_run = max([int(name.split('-')[0]) for name in available_logs]) | |||||
| log_file_name = f"{last_run + 1}-local-run.log" | |||||
| log_file = os.path.join(self.cfg.log_dir, log_file_name) | |||||
| logging.basicConfig(filename=log_file, level=logging.INFO, | |||||
| format='%(asctime)s [%(levelname)s] - %(message)s', | |||||
| datefmt='%Y-%m-%d %H:%M:%S') | |||||
| hyperparameters = {key: value for key, value in vars(self.cfg).items()} | |||||
| hyperparameters_str = "\n".join([f"{key}: {value}" for key, value in hyperparameters.items() if ('path' not in key and 'log_dir' not in key)]) | |||||
| logging.info("Hyperparameters:\n" + hyperparameters_str) | |||||
| def log_metrics(self, epoch, log_data): | |||||
| # Loss Metrics Table | |||||
| loss_header = ["", "f2f", "f2t", "v2t"] | |||||
| # Check if train_loss exists, otherwise assign "N/A" | |||||
| train_loss_f2f = log_data.get("train_loss(f2f)", "N/A") | |||||
| train_loss_f2t = log_data.get("train_loss(f2t)", "N/A") | |||||
| train_loss_v2t = log_data.get("train_loss(v2t)", "N/A") | |||||
| # Check if valid_loss exists, otherwise assign "N/A" | |||||
| valid_loss_f2f = log_data.get("valid_loss(f2f)", "N/A") | |||||
| valid_loss_f2t = log_data.get("valid_loss(f2t)", "N/A") | |||||
| valid_loss_v2t = log_data.get("valid_loss(v2t)", "N/A") | |||||
| loss_rows = [ | |||||
| ["train", train_loss_f2f, train_loss_f2t, train_loss_v2t], | |||||
| ["val", valid_loss_f2f, valid_loss_f2t, valid_loss_v2t], | |||||
| ] | |||||
| loss_table = tabulate(loss_rows, headers=loss_header, tablefmt="grid") | |||||
| # Recall Metrics Table | |||||
| recall_header = ["", "R@1", "R@5", "R@10"] | |||||
| clustering_names = ['uniform', 'spectral', 'agglomerative', 'kmeans'] | |||||
| recall_rows_t2v = [] | |||||
| recall_rows_v2t = [] | |||||
| for k in [1, 2, 4, 8, 16]: | |||||
| # Without clustering | |||||
| base_key = f"k={k}" | |||||
| if any(f"[t2v] R@{r} ({base_key})" in log_data for r in [1, 5, 10]): | |||||
| row = [f"{k}"] | |||||
| for r in [1, 5, 10]: | |||||
| metric_key = f"[t2v] R@{r} ({base_key})" | |||||
| row.append(log_data.get(metric_key, "N/A")) | |||||
| recall_rows_t2v.append(row) | |||||
| if any(f"[v2t] R@{r} ({base_key})" in log_data for r in [1, 5, 10]): | |||||
| row = [f"{k}"] | |||||
| for r in [1, 5, 10]: | |||||
| metric_key = f"[v2t] R@{r} ({base_key})" | |||||
| row.append(log_data.get(metric_key, "N/A")) | |||||
| recall_rows_v2t.append(row) | |||||
| # With clustering | |||||
| for clustering_name in clustering_names: | |||||
| cluster_key = f"{base_key}-{clustering_name}" | |||||
| if any(f"[t2v] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]): | |||||
| row = [f"{k}-{clustering_name}"] | |||||
| for r in [1, 5, 10]: | |||||
| metric_key = f"[t2v] R@{r} ({cluster_key})" | |||||
| row.append(log_data.get(metric_key, "N/A")) | |||||
| recall_rows_t2v.append(row) | |||||
| if any(f"[v2t] R@{r} ({cluster_key})" in log_data for r in [1, 5, 10]): | |||||
| row = [f"{k}-{clustering_name}"] | |||||
| for r in [1, 5, 10]: | |||||
| metric_key = f"[v2t] R@{r} ({cluster_key})" | |||||
| row.append(log_data.get(metric_key, "N/A")) | |||||
| recall_rows_v2t.append(row) | |||||
| recall_table_t2v = tabulate(recall_rows_t2v, headers=recall_header, tablefmt="grid") | |||||
| recall_table_v2t = tabulate(recall_rows_v2t, headers=recall_header, tablefmt="grid") | |||||
| logging.info(f"Epoch {epoch+1} - Loss and Recall Metrics:") | |||||
| logging.info("\n" + loss_table) | |||||
| logging.info("\nText-to-Video Recall Metrics:") | |||||
| logging.info("\n" + recall_table_t2v) | |||||
| logging.info("\nVideo-to-Text Recall Metrics:") | |||||
| logging.info("\n" + recall_table_v2t) | |||||
| def select_representative_frame_per_cluster(self, clusters, frame_scores): | |||||
| representative_frames = [] | |||||
| for cluster in clusters: | |||||
| best_frame_index = max(cluster, key=lambda index: frame_scores[index]) | |||||
| representative_frames.append(int(best_frame_index)) | |||||
| return representative_frames | |||||
| class kl_divergence_loss(nn.Module): | |||||
| def __init__(self,): | |||||
| super(kl_divergence_loss, self).__init__() | |||||
| def forward(self, features1, features2): | |||||
| features1 = F.normalize(features1, p=2, dim=-1) | |||||
| features2 = F.normalize(features2, p=2, dim=-1) | |||||
| cos_sim_features1 = torch.mm(features1, features1.t()) | |||||
| cos_sim_features2 = torch.mm(features2, features2.t()) | |||||
| probs_features1 = F.softmax(cos_sim_features1, dim=-1) | |||||
| probs_features2 = F.softmax(cos_sim_features2, dim=-1) | |||||
| loss = F.kl_div(probs_features1.log(), probs_features2, reduction='batchmean') | |||||
| return loss | |||||
| class CrossEn(nn.Module): | |||||
| def __init__(self,): | |||||
| super(CrossEn, self).__init__() | |||||
| def forward(self, sim_matrix): | |||||
| logpt = F.log_softmax(sim_matrix, dim=-1) | |||||
| logpt = torch.diag(logpt) | |||||
| nce_loss = -logpt | |||||
| sim_loss = nce_loss.mean() | |||||
| return sim_loss |
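
# Illustrative sanity check only: exercise the two distillation losses on random
# features. CrossEn treats the diagonal of the similarity matrix as the matching
# (video, text) pairs, and kl_divergence_loss compares the frame-to-frame
# similarity distributions of student and teacher features.
if __name__ == "__main__":
    torch.manual_seed(0)
    student = torch.randn(10, 128)  # 10 frames, student feature dimension
    teacher = torch.randn(10, 512)  # 10 frames, teacher feature dimension
    print("f2f KL loss:", kl_divergence_loss()(student, teacher).item())
    sim_matrix = torch.randn(4, 4)  # 4 videos vs. 4 captions
    print("v2t CrossEn loss:", CrossEn()(sim_matrix).item())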
| from sklearn.cluster import SpectralClustering, AgglomerativeClustering, KMeans | |||||
| import numpy as np | |||||
| from sklearn.metrics.pairwise import cosine_similarity | |||||
| class VideoClusterer: | |||||
| def __init__(self, clustering_method='uniform', n_clusters=2, similarity_threshold=0.8): | |||||
| self.n_clusters = n_clusters | |||||
| self.similarity_threshold = similarity_threshold | |||||
| self.clustering_method = clustering_method | |||||
| # Decide on the clustering method to use | |||||
| if clustering_method == 'uniform': | |||||
| self.clusterer = self.uniform_clustering | |||||
| elif clustering_method == 'spectral': | |||||
| self.clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed') | |||||
| elif clustering_method == 'agglomerative': | |||||
| self.clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward') | |||||
| elif clustering_method == 'kmeans': | |||||
| self.clusterer = KMeans(n_clusters=n_clusters, n_init=1) | |||||
| else: | |||||
| raise ValueError(f"Invalid clustering method: {clustering_method}") | |||||
| def uniform_clustering(self, features): | |||||
| n = len(features) | |||||
| clusters = [] | |||||
| cluster_size = n // self.n_clusters | |||||
| remainder = n % self.n_clusters | |||||
| start = 0 | |||||
| for i in range(self.n_clusters): | |||||
| if i < remainder: | |||||
| end = start + cluster_size + 1 | |||||
| else: | |||||
| end = start + cluster_size | |||||
| clusters.append(list(range(start, end))) | |||||
| start = end | |||||
| return clusters | |||||
| def detect_outliers(self, features): | |||||
| dot_product_matrix = features.dot(features.T) | |||||
| average_similarities = np.mean(dot_product_matrix, axis=0) | |||||
| # Adding a small constant epsilon to the standard deviation to prevent division by zero | |||||
| epsilon = 1e-8 | |||||
| normal = (average_similarities - np.mean(average_similarities)) / (np.std(average_similarities) + epsilon) | |||||
| outlier_mask = np.logical_or(normal > 1.5, normal < -1.5) | |||||
| return outlier_mask | |||||
    def get_clusters(self, features):
        features = features.cpu().numpy()
        if self.clustering_method == 'uniform':
            return self.uniform_clustering(features)
        else:
            # For non-uniform methods, optionally drop outlier frames first,
            # keeping track of the original frame indices.
            original_indices = np.arange(len(features))
            outlier_mask = self.detect_outliers(features)
            if np.sum(~outlier_mask) > self.n_clusters:
                features = features[~outlier_mask]
                original_indices = original_indices[~outlier_mask]
            # Compute the cosine similarity matrix for spectral clustering
            if self.clustering_method == 'spectral':
                similarity_matrix = cosine_similarity(features)
                labels = self.clusterer.fit_predict(similarity_matrix)
            else:
                # Agglomerative, k-means, and other methods cluster the features directly
                labels = self.clusterer.fit_predict(features)
            # Organize frames into clusters, mapping back to the original frame indices
            clusters = [[] for _ in range(self.n_clusters)]
            for idx, label in enumerate(labels):
                clusters[label].append(int(original_indices[idx]))
            return clusters
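
# Illustrative usage sketch only: cluster 12 random non-negative frame features
# into 4 groups and print one representative index per cluster (here simply the
# first index of each cluster).
if __name__ == "__main__":
    import torch
    feats = torch.rand(12, 64)
    for method in ['uniform', 'kmeans', 'agglomerative', 'spectral']:
        clusterer = VideoClusterer(clustering_method=method, n_clusters=4)
        clusters = clusterer.get_clusters(feats)
        print(method, [cluster[0] for cluster in clusters if cluster])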
| from __future__ import absolute_import | |||||
| from __future__ import division | |||||
| from __future__ import unicode_literals | |||||
| from __future__ import print_function | |||||
| import numpy as np | |||||
| import torch | |||||
| def compute_metrics(x): | |||||
| sx = np.sort(-x, axis=1) | |||||
| d = np.diag(-x) | |||||
| d = d[:, np.newaxis] | |||||
| ind = sx - d | |||||
| ind = np.where(ind == 0) | |||||
| ind = ind[1] | |||||
| metrics = {} | |||||
| metrics['R@1'] = float(np.sum(ind == 0)) * 100 / len(ind) | |||||
| metrics['R@5'] = float(np.sum(ind < 5)) * 100 / len(ind) | |||||
| metrics['R@10'] = float(np.sum(ind < 10)) * 100 / len(ind) | |||||
| metrics["MedianR"] = np.median(ind) + 1 | |||||
| metrics["MeanR"] = np.mean(ind) + 1 | |||||
| # metrics["cols"] = [int(i) for i in list(ind)] | |||||
| return metrics | |||||
| def print_computed_metrics(metrics): | |||||
| r1 = metrics['R@1'] | |||||
| r5 = metrics['R@5'] | |||||
| r10 = metrics['R@10'] | |||||
    # compute_metrics() reports the median rank as 'MedianR'; tensor_text_to_video_metrics() also adds 'MR'.
    mr = metrics.get('MR', metrics.get('MedianR'))
| print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) | |||||
| # below two functions directly come from: https://github.com/Deferf/Experiments | |||||
| def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10]): | |||||
| if not torch.is_tensor(sim_tensor): | |||||
| sim_tensor = torch.tensor(sim_tensor) | |||||
| # Permute sim_tensor so it represents a sequence of text-video similarity matrices. | |||||
| # Then obtain the double argsort to position the rank on the diagonal | |||||
| stacked_sim_matrices = sim_tensor.permute(1, 0, 2) | |||||
| first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) | |||||
| second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) | |||||
| # Extracts ranks i.e diagonals | |||||
| ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) | |||||
| # Now we need to extract valid ranks, as some belong to inf padding values | |||||
| permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) | |||||
| mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) | |||||
| valid_ranks = ranks[mask] | |||||
| # A quick dimension check validates our results, there may be other correctness tests pending | |||||
| # Such as dot product localization, but that is for other time. | |||||
| #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) | |||||
| if not torch.is_tensor(valid_ranks): | |||||
| valid_ranks = torch.tensor(valid_ranks) | |||||
| results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} | |||||
| results["MedianR"] = float(torch.median(valid_ranks + 1)) | |||||
| results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) | |||||
| results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) | |||||
| results['MR'] = results["MedianR"] | |||||
| return results | |||||
| def tensor_video_to_text_sim(sim_tensor): | |||||
| if not torch.is_tensor(sim_tensor): | |||||
| sim_tensor = torch.tensor(sim_tensor) | |||||
| # Code to avoid nans | |||||
| sim_tensor[sim_tensor != sim_tensor] = float('-inf') | |||||
| # Forms a similarity matrix for use with rank at k | |||||
| values, _ = torch.max(sim_tensor, dim=1, keepdim=True) | |||||
| return torch.squeeze(values).T |
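
# Illustrative sanity check only: on a near-identity similarity matrix the correct
# caption is always ranked first, so every recall metric is 100 and MedianR is 1.
if __name__ == "__main__":
    sim = np.eye(5) + 0.01 * np.random.rand(5, 5)
    print(compute_metrics(sim))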