venv/
**/videos/
**/__pycache__/
{
  "[python]": {
    "editor.defaultFormatter": "ms-python.autopep8"
  },
  "python.formatting.provider": "none",
  "cSpell.words": ["embedder", "pytube"]
}
# Video Action Recognition Using Transfer Learning and Attention Mechanisms

This project focuses on video action recognition using deep learning techniques, leveraging transfer learning from language models and attention mechanisms.
## Getting Started

### 1. Dataset Preparation

1.1. Download the Kinetics dataset:
   - Use `save_kinetics_dataset.ipynb` to download the dataset.
   - Alternatively, you can use `download_k400.ipynb`.

1.2. Save the dataset:
   - Store the downloaded dataset in your Google Drive for easy access (see the sketch after this list).
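If you work in Google Colab, the "save to Drive" step can look roughly like the snippet below. This is a minimal sketch, assuming a Colab runtime and that the download notebook wrote clips to `/content/kinetics`; both paths are placeholders, not values hard-coded in the notebooks.

```python
# Hypothetical paths for illustration; adjust to wherever your download notebook saved the clips.
import shutil
from google.colab import drive

drive.mount("/content/drive")  # authorise access to your Google Drive

shutil.copytree(
    "/content/kinetics",                # local download location (assumption)
    "/content/drive/MyDrive/kinetics",  # destination inside your Drive (assumption)
    dirs_exist_ok=True,
)
```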
### 2. Label Preprocessing

2.1. Update Kinetics labels:
   - Run `preprocess_kinetics_labels.ipynb`.
   - This notebook uses GPT-4 to generate a detailed description for each video action (a sketch of the idea follows below).
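Conceptually, the label-enrichment step asks GPT-4 for a richer sentence per class name, which is later embedded with CLIP's text encoder. The prompt wording and client usage below are illustrative assumptions; the exact prompt and output format live in `preprocess_kinetics_labels.ipynb`.

```python
# Illustrative sketch of GPT-4 label enrichment (not the notebook's exact code).
from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment


def describe_action(label: str) -> str:
    """Ask GPT-4 for a detailed one-sentence description of an action label."""
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{
            "role": "user",
            "content": f"Describe the human action '{label}' in one detailed sentence.",
        }],
    )
    return response.choices[0].message.content


print(describe_action("playing cello"))
```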
### 3. Model Training

3.1. Post-pretraining of VideoMAE:
   - Execute `postpretrain_VideoMAE_to_CLIP_Space.ipynb`.
   - This notebook trains a transformer layer that maps VideoMAE embeddings into the CLIP space (a minimal sketch of such a mapping head is shown below).
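As a rough picture of what that mapping head can look like: one transformer encoder layer over the VideoMAE tokens, mean pooling, and a linear projection to CLIP's embedding width, trained to align with the CLIP text embedding of the video's (enriched) label. The dimensions (384 for VideoMAE ViT-S, 512 for CLIP ViT-B/32) and the cosine loss are assumptions for this sketch; the notebook defines the actual architecture and objective.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VideoMAEToCLIP(nn.Module):
    """Sketch: map VideoMAE token embeddings to a single CLIP-space vector."""

    def __init__(self, videomae_dim: int = 384, clip_dim: int = 512):
        super().__init__()
        self.encoder = nn.TransformerEncoderLayer(d_model=videomae_dim, nhead=6, batch_first=True)
        self.proj = nn.Linear(videomae_dim, clip_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, num_tokens, videomae_dim) from the frozen VideoMAE backbone
        x = self.encoder(tokens)
        x = x.mean(dim=1)  # pool tokens into one video-level embedding
        return self.proj(x)


def clip_alignment_loss(video_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    # Pull each video embedding towards the CLIP text embedding of its label.
    return 1.0 - F.cosine_similarity(video_emb, text_emb, dim=-1).mean()
```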
### 4. Testing

4.1. Prepare the test dataset:
   - Download the UCF101 dataset.
   - Update the UCF101 labels using GPT-4, following the same procedure as the Kinetics label preprocessing step.

4.2. Run the test:
   - Use `test.ipynb` to evaluate the model's performance (a sketch of the zero-shot scoring step follows below).
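Evaluation follows the usual CLIP-style zero-shot recipe: embed every class name (or its GPT-4 description) with the CLIP text encoder, compare a video's CLIP-space embedding against all class embeddings, and count top-1/top-5 hits. A minimal sketch, assuming `video_embedding` was already produced by the mapping head and that `class_names` holds the UCF101 labels:

```python
import torch
from transformers import AutoTokenizer, CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

class_names = ["archery", "long jump", "playing cello"]  # placeholder labels
inputs = tokenizer(class_names, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    text_embeddings = clip_model.get_text_features(**inputs)  # (num_classes, 512)


def top1_top5(video_embedding: torch.Tensor, target_index: int):
    """Return (top-1 hit, top-5 hit) for one CLIP-space video embedding."""
    similarities = video_embedding @ text_embeddings.T  # (num_classes,)
    top5 = torch.topk(similarities, k=min(5, len(class_names))).indices
    return int(top5[0] == target_index), int(target_index in top5)
```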
## Prerequisites

- Python 3.x
- Jupyter Notebook
- PyTorch
- Transformers library
- CLIP model
- VideoMAE model
- Access to the GPT-4 API for label preprocessing
- Google Drive (for storing datasets)
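The notebooks install their Python dependencies inline; the same packages can be installed up front. The list below is taken from the test notebook, so the other notebooks may need additional packages (e.g., `pytube` for downloading videos).

```
pip install decord deepspeed einops timm==0.4.12 tensorboardX mpi4py transformers
```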
## Usage

1. Follow the steps in the "Getting Started" section to prepare your data and train the model.
2. Ensure all datasets are properly saved in your Google Drive.
3. Run the notebooks in the order specified above.
4. For testing, make sure you have the UCF101 dataset prepared and labels updated before running `test.ipynb`.

The model processes multiple frames from a video scene and creates rich representations in the CLIP space.
## Future Work

- Implement an adaptive frame selection unit
- Extend to more diverse datasets
- Integrate multimodal inputs (e.g., audio)
- Fine-tune hyperparameters
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "ZGGB7jSjF1RD", | |||||
"outputId": "95aaefdd-a00e-46d9-86f9-556ab6ff53c2" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Downloading...\n", | |||||
"From: https://drive.google.com/uc?id=18AzsP1DOBECBL3vjzUhokro40mtu4Ggn\n", | |||||
"To: /content/vit_s_k710_dl_from_giant.pth\n", | |||||
"100% 44.3M/44.3M [00:00<00:00, 76.7MB/s]\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"!gdown 18AzsP1DOBECBL3vjzUhokro40mtu4Ggn # vit_s_k710_dl_from_giant.pth" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "G9r7PONyCbMO", | |||||
"outputId": "6758cef5-cac8-45a9-f424-3525dd38a25b" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Requirement already satisfied: decord in /usr/local/lib/python3.10/dist-packages (0.6.0)\n", | |||||
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.10/dist-packages (0.12.2)\n", | |||||
"Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (0.7.0)\n", | |||||
"Requirement already satisfied: timm==0.4.12 in /usr/local/lib/python3.10/dist-packages (0.4.12)\n", | |||||
"Requirement already satisfied: tensorboardX in /usr/local/lib/python3.10/dist-packages (2.6.2.2)\n", | |||||
"Requirement already satisfied: mpi4py in /usr/local/lib/python3.10/dist-packages (3.1.5)\n", | |||||
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.0)\n", | |||||
"Requirement already satisfied: torch>=1.4 in /usr/local/lib/python3.10/dist-packages (from timm==0.4.12) (2.1.0+cu118)\n", | |||||
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from timm==0.4.12) (0.16.0+cu118)\n", | |||||
"Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from decord) (1.23.5)\n", | |||||
"Requirement already satisfied: hjson in /usr/local/lib/python3.10/dist-packages (from deepspeed) (3.1.0)\n", | |||||
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from deepspeed) (1.11.1.1)\n", | |||||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from deepspeed) (23.2)\n", | |||||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from deepspeed) (5.9.5)\n", | |||||
"Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.10/dist-packages (from deepspeed) (9.0.0)\n", | |||||
"Requirement already satisfied: pydantic in /usr/local/lib/python3.10/dist-packages (from deepspeed) (1.10.13)\n", | |||||
"Requirement already satisfied: pynvml in /usr/local/lib/python3.10/dist-packages (from deepspeed) (11.5.0)\n", | |||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from deepspeed) (4.66.1)\n", | |||||
"Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.10/dist-packages (from tensorboardX) (3.20.3)\n", | |||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.4)\n", | |||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.17.3)\n", | |||||
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", | |||||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", | |||||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", | |||||
"Requirement already satisfied: tokenizers<0.15,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.14.1)\n", | |||||
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.0)\n", | |||||
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n", | |||||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", | |||||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (1.12)\n", | |||||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (3.2)\n", | |||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (3.1.2)\n", | |||||
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (2.1.0)\n", | |||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.1)\n", | |||||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", | |||||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", | |||||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n", | |||||
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->timm==0.4.12) (9.4.0)\n", | |||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4->timm==0.4.12) (2.1.3)\n", | |||||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.4->timm==0.4.12) (1.3.0)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"!pip install decord deepspeed einops timm==0.4.12 tensorboardX mpi4py transformers" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "6qwB9RtaEVl5", | |||||
"outputId": "43f46c83-0f1d-404e-a1af-b08a5af05f4a" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/content\n", | |||||
"fatal: destination path 'VideoMAEv2' already exists and is not an empty directory.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"%cd /content\n", | |||||
"!git clone https://github.com/OpenGVLab/VideoMAEv2" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "Bwj7eNpwGLDI" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"class_to_label = {}\n", | |||||
"\n", | |||||
"with open(\"/content/ucfTrainTestlist/classInd.txt\", 'r') as f:\n", | |||||
" for line in f:\n", | |||||
" label, class_name = line.split()\n", | |||||
" class_to_label[class_name] = label\n", | |||||
"\n", | |||||
"modes = ['train', 'test']\n", | |||||
"\n", | |||||
"for mode in modes:\n", | |||||
" files = [\n", | |||||
" f\"/content/ucfTrainTestlist/{mode}list01.txt\",\n", | |||||
" f\"/content/ucfTrainTestlist/{mode}list02.txt\",\n", | |||||
" f\"/content/ucfTrainTestlist/{mode}list03.txt\",\n", | |||||
" ]\n", | |||||
"\n", | |||||
" output_file_path = f\"/content/UCF101/{'val' if mode == 'test' else mode}.csv\"\n", | |||||
"\n", | |||||
" with open(output_file_path, 'w') as outfile:\n", | |||||
" for file_path in files:\n", | |||||
" with open(file_path, 'r') as infile:\n", | |||||
" for line in infile:\n", | |||||
" line_text = line.strip()\n", | |||||
" if mode == 'train':\n", | |||||
" video_path, label = line_text.split(' ')\n", | |||||
" else:\n", | |||||
" video_path = line_text\n", | |||||
" class_name = os.path.dirname(line_text)\n", | |||||
" label = class_to_label[class_name]\n", | |||||
" outfile.write(f\"/content/UCF101/{video_path} {label}\\n\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "7Tfx84ZrEfFB", | |||||
"outputId": "b25afa77-92f1-4b69-c312-e21012794ba3" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/content/VideoMAEv2\n", | |||||
"/content\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"%cd /content/VideoMAEv2\n", | |||||
"from dataset.datasets import VideoClsDataset\n", | |||||
"\n", | |||||
"from models.modeling_finetune import VisionTransformer, _cfg\n", | |||||
"from utils import (\n", | |||||
" load_state_dict,\n", | |||||
")\n", | |||||
"from optim_factory import (\n", | |||||
" LayerDecayValueAssigner,\n", | |||||
" get_parameter_groups,\n", | |||||
")\n", | |||||
"%cd /content/" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "WG9MzWkQI5tP", | |||||
"outputId": "40e59861-2ca7-400b-9d02-43285a68a85b" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"[2023-11-05 21:20:09,304] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"import os\n", | |||||
"import sys\n", | |||||
"import json\n", | |||||
"import warnings\n", | |||||
"import math\n", | |||||
"import argparse\n", | |||||
"import logging\n", | |||||
"import random\n", | |||||
"import gc\n", | |||||
"import tqdm\n", | |||||
"\n", | |||||
"from collections import OrderedDict\n", | |||||
"\n", | |||||
"import numpy as np\n", | |||||
"import pandas as pd\n", | |||||
"import deepspeed\n", | |||||
"import torch\n", | |||||
"import torch.nn as nn\n", | |||||
"import torch.nn.functional as F\n", | |||||
"import torch.backends.cudnn as cudnn\n", | |||||
"import torch.utils.checkpoint as cp\n", | |||||
"from torch.utils.data import Dataset\n", | |||||
"from torch.utils.data._utils.collate import default_collate\n", | |||||
"from torchvision import transforms\n", | |||||
"from timm.models import create_model\n", | |||||
"from timm.models.layers import trunc_normal_\n", | |||||
"from timm.models.registry import register_model\n", | |||||
"from timm.loss import SoftTargetCrossEntropy\n", | |||||
"from functools import partial\n", | |||||
"from datetime import datetime\n", | |||||
"\n", | |||||
"from transformers import AutoTokenizer, CLIPModel" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "yFhdWFvhbLwF" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"class EnhancedVisionTransformer(VisionTransformer):\n", | |||||
" def get_embeddings(self, x):\n", | |||||
" B, _, T, H, W = x.shape\n", | |||||
"\n", | |||||
" x = self.patch_embed(x)\n", | |||||
"\n", | |||||
" if self.pos_embed is not None:\n", | |||||
" x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()\n", | |||||
" x = self.pos_drop(x)\n", | |||||
"\n", | |||||
" for blk in self.blocks:\n", | |||||
" if self.with_cp:\n", | |||||
" x = cp.checkpoint(blk, x)\n", | |||||
" else:\n", | |||||
" x = blk(x)\n", | |||||
"\n", | |||||
" B, num_patches, embed_dim = x.shape\n", | |||||
"\n", | |||||
" T = T // self.tubelet_size\n", | |||||
" x = x.view(B, T, num_patches // T, embed_dim)\n", | |||||
" x = x.reshape(B, T, -1)\n", | |||||
"\n", | |||||
" return x\n", | |||||
"\n", | |||||
"@register_model\n", | |||||
"def vit_small_patch16_224(pretrained=False, **kwargs):\n", | |||||
" model = EnhancedVisionTransformer(\n", | |||||
" patch_size=16,\n", | |||||
" embed_dim=384,\n", | |||||
" depth=12,\n", | |||||
" num_heads=6,\n", | |||||
" mlp_ratio=4,\n", | |||||
" qkv_bias=True,\n", | |||||
" norm_layer=partial(nn.LayerNorm, eps=1e-6),\n", | |||||
" **kwargs)\n", | |||||
" model.default_cfg = _cfg()\n", | |||||
" return model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "gRgLU0IzIcxQ" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"def add_space_before_uppercase(s):\n", | |||||
" result = [s[0]]\n", | |||||
"\n", | |||||
" for char in s[1:]:\n", | |||||
" if char.isupper():\n", | |||||
" result.append(' ')\n", | |||||
" result.append(char)\n", | |||||
"\n", | |||||
" return ''.join(result)\n", | |||||
"\n", | |||||
"class EnhancedVideoClsDataset(VideoClsDataset):\n", | |||||
" def __getitem__(self, index):\n", | |||||
" original_data = super().__getitem__(index)\n", | |||||
"\n", | |||||
" video_filename = self.dataset_samples[index].split('/')[-1].split('.')[0].split('_')[1]\n", | |||||
"\n", | |||||
" label = add_space_before_uppercase(video_filename)\n", | |||||
"\n", | |||||
" return (*original_data, ''.join(label))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "WhYEWCxiJBMa" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"args = argparse.Namespace()\n", | |||||
"args.model = 'vit_small_patch16_224'\n", | |||||
"args.data_set = 'UCF101'\n", | |||||
"args.nb_classes = 101\n", | |||||
"args.data_path = '/content/UCF101/'\n", | |||||
"args.finetune = '/content/vit_s_k710_dl_from_giant.pth'\n", | |||||
"args.batch_size = 6\n", | |||||
"args.input_size = 224\n", | |||||
"args.short_side_size = 224\n", | |||||
"args.num_frames = 16\n", | |||||
"args.sampling_rate = 10\n", | |||||
"args.num_sample = 2\n", | |||||
"args.num_workers = 2\n", | |||||
"args.opt = 'adamw'\n", | |||||
"args.opt_eps = 1e-8\n", | |||||
"args.opt_betas = [0.9, 0.999]\n", | |||||
"args.lr = 1e-3\n", | |||||
"args.min_lr = 1e-6\n", | |||||
"args.drop = 0.0\n", | |||||
"args.attn_drop_rate = 0.0\n", | |||||
"args.drop_path = 0.35\n", | |||||
"args.clip_grad = None # 5.0\n", | |||||
"args.aa = 'rand-m7-n4-mstd0.5-inc1'\n", | |||||
"args.layer_decay = 0.92 # 0.9\n", | |||||
"args.weight_decay = 0.06 # 0.05\n", | |||||
"args.epochs = 5\n", | |||||
"\n", | |||||
"args.tubelet_size = 2\n", | |||||
"args.with_checkpoint = True\n", | |||||
"args.train_interpolation = 'bicubic'\n", | |||||
"args.reprob = 0.25\n", | |||||
"args.remode = 'pixel'\n", | |||||
"args.recount = 1\n", | |||||
"args.data_root = ''\n", | |||||
"\n", | |||||
"args.num_segments = 1\n", | |||||
"\n", | |||||
"args.start_epoch = 0\n", | |||||
"\n", | |||||
"args.pin_mem = True" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "rhsNVcWuKMm0" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |||||
"\n", | |||||
"torch.manual_seed(0)\n", | |||||
"np.random.seed(0)\n", | |||||
"random.seed(0)\n", | |||||
"cudnn.benchmark = True" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "eRRzjmFsMM0T" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"def multiple_samples_collate(batch):\n", | |||||
" inputs, labels, video_idx, extra_data, video_file_names = zip(*batch)\n", | |||||
" inputs = [item for sublist in inputs for item in sublist]\n", | |||||
" labels = [item for sublist in labels for item in sublist]\n", | |||||
" video_idx = [item for sublist in video_idx for item in sublist]\n", | |||||
" inputs, labels, video_idx, extra_data, video_file_names = (\n", | |||||
" default_collate(inputs),\n", | |||||
" default_collate(labels),\n", | |||||
" default_collate(video_idx),\n", | |||||
" default_collate(extra_data),\n", | |||||
" default_collate(video_file_names),\n", | |||||
" )\n", | |||||
" return inputs, labels, video_idx, extra_data, video_file_names" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "Twu9hFl_KhbN" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"train_dataset = EnhancedVideoClsDataset(\n", | |||||
" anno_path=os.path.join(args.data_path, 'train.csv'),\n", | |||||
" data_root=args.data_root,\n", | |||||
" clip_len=args.num_frames,\n", | |||||
" frame_sample_rate=args.sampling_rate,\n", | |||||
" num_segment=1,\n", | |||||
" crop_size=args.input_size,\n", | |||||
" short_side_size=args.short_side_size,\n", | |||||
" mode='train',\n", | |||||
" args=args)\n", | |||||
"\n", | |||||
"\n", | |||||
"data_loader_train = torch.utils.data.DataLoader(\n", | |||||
" train_dataset,\n", | |||||
" batch_size=args.batch_size,\n", | |||||
" num_workers=args.num_workers,\n", | |||||
" pin_memory=args.pin_mem,\n", | |||||
" shuffle=True,\n", | |||||
" drop_last=True,\n", | |||||
" collate_fn=partial(multiple_samples_collate),\n", | |||||
" persistent_workers=True)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "eWxUWHrlLBYN" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"class FrameScorePredictor(nn.Module):\n", | |||||
" def __init__(self, create_model_fn, args):\n", | |||||
" super().__init__()\n", | |||||
"\n", | |||||
"\n", | |||||
" self.base_model = create_model_fn(\n", | |||||
" args.model,\n", | |||||
" img_size=args.input_size,\n", | |||||
" pretrained=False,\n", | |||||
" all_frames=args.num_frames * args.num_segments,\n", | |||||
" tubelet_size=args.tubelet_size,\n", | |||||
" drop_rate=args.drop,\n", | |||||
" drop_path_rate=args.drop_path,\n", | |||||
" attn_drop_rate=args.attn_drop_rate,\n", | |||||
" drop_block_rate=None,\n", | |||||
" with_cp=args.with_checkpoint,\n", | |||||
" )\n", | |||||
"\n", | |||||
" self.base_embedding_dim = self.base_model.embed_dim * 14 ** 2 # TODO: use parameters\n", | |||||
" self.embedding_frame_number = args.num_frames // args.tubelet_size\n", | |||||
"\n", | |||||
" self.score_predictor = nn.Sequential(\n", | |||||
" nn.Linear(self.base_embedding_dim * self.embedding_frame_number, self.embedding_frame_number)\n", | |||||
" )\n", | |||||
"\n", | |||||
" def forward(self, x):\n", | |||||
" batch_size, num_channels, num_frames, height, width = x.shape\n", | |||||
"\n", | |||||
" embeddings = self.base_model.get_embeddings(x)\n", | |||||
"\n", | |||||
" embeddings = embeddings.reshape(batch_size, self.base_embedding_dim * self.embedding_frame_number)\n", | |||||
"\n", | |||||
" frame_scores = self.score_predictor(embeddings)\n", | |||||
" return frame_scores\n", | |||||
"\n", | |||||
" def load_base_model_state_dict(self, checkpoint_model):\n", | |||||
" load_state_dict(self.base_model, checkpoint_model)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "VO4DA_QkKl0v", | |||||
"outputId": "a107c6dd-272a-4970-ced6-cd0e16abbaa5" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"size mismatch for head.weight: copying a param with shape torch.Size([710, 384]) from checkpoint, the shape in current model is torch.Size([1000, 384]).\n", | |||||
"size mismatch for head.bias: copying a param with shape torch.Size([710]) from checkpoint, the shape in current model is torch.Size([1000]).\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"124" | |||||
] | |||||
}, | |||||
"execution_count": 15, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"model = FrameScorePredictor(create_model, args=args)\n", | |||||
"\n", | |||||
"checkpoint = torch.load(args.finetune, map_location='cpu')\n", | |||||
"\n", | |||||
"checkpoint_model = checkpoint['model'] if 'model' in checkpoint else checkpoint['module']\n", | |||||
"\n", | |||||
"model.load_base_model_state_dict(checkpoint_model)\n", | |||||
"\n", | |||||
"model = model.to(device)\n", | |||||
"\n", | |||||
"del checkpoint\n", | |||||
"del checkpoint_model\n", | |||||
"\n", | |||||
"gc.collect()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "_aFo0-9J93FN" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"class PairwiseLoss(torch.nn.Module):\n", | |||||
" def __init__(self, margin=0.2):\n", | |||||
" super(PairwiseLoss, self).__init__()\n", | |||||
" self.margin = margin\n", | |||||
"\n", | |||||
" def forward(self, scores, labels):\n", | |||||
" pairwise_diff = scores.unsqueeze(1) - scores.unsqueeze(0)\n", | |||||
"\n", | |||||
" positive_mask = (labels.unsqueeze(1) - labels.unsqueeze(0)) > self.margin\n", | |||||
"\n", | |||||
" positive_diffs = pairwise_diff[positive_mask]\n", | |||||
"\n", | |||||
" loss = torch.clamp(positive_diffs, min=0).sum()\n", | |||||
"\n", | |||||
" num_positive_pairs = positive_mask.sum()\n", | |||||
" return loss / num_positive_pairs if num_positive_pairs > 0 else loss\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "GVGSAGKlPJye" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n", | |||||
"clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "E9nez-e8TOLp" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"all_classes = [add_space_before_uppercase(text) for text in class_to_label.keys()]\n", | |||||
"all_classes_inputs = tokenizer(all_classes, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n", | |||||
"with torch.no_grad():\n", | |||||
" text_embeddings = clip_model.get_text_features(**all_classes_inputs)\n", | |||||
"\n", | |||||
"all_classes_clip_embedding = {text: emb for text, emb in zip(all_classes, text_embeddings)}\n", | |||||
"all_classes_clip_embedding_list = torch.stack(list(all_classes_clip_embedding.values())).to(device)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "9lRGZ1ZuYmEe" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"def compute_accuracies(frame_features, target_index):\n", | |||||
" similarities = torch.matmul(frame_features, all_classes_clip_embedding_list.T)\n", | |||||
"\n", | |||||
" top5_indices = torch.topk(similarities, 5, largest=True).indices\n", | |||||
"\n", | |||||
" top1 = int(top5_indices[0] == target_index)\n", | |||||
" top5 = int(target_index in top5_indices)\n", | |||||
"\n", | |||||
" return top1, top5" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"background_save": true | |||||
}, | |||||
"id": "nQ3f69MnGLcW" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"embedding_frame_number = args.num_frames // args.tubelet_size\n", | |||||
"ranking_loss_scale = embedding_frame_number**3-embedding_frame_number\n", | |||||
"\n", | |||||
"# mse_criterion = nn.MSELoss()\n", | |||||
"# margin_ranking_loss = nn.MarginRankingLoss(margin=1)\n", | |||||
"\n", | |||||
"optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n", | |||||
"\n", | |||||
"gc.collect()\n", | |||||
"torch.cuda.empty_cache()\n", | |||||
"\n", | |||||
"model.train()\n", | |||||
"\n", | |||||
"bce_loss_sum = 0.0\n", | |||||
"mse_loss_sum = 0.0\n", | |||||
"rank_loss_sum = 0.0\n", | |||||
"combined_loss_sum = 0.0\n", | |||||
"print_step = 10\n", | |||||
"\n", | |||||
"strategies_items = ['random', 'best_pseudo', 'best_predicted']\n", | |||||
"\n", | |||||
"acc1_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||||
"acc5_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||||
"acc_counter = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||||
"\n", | |||||
"for step, (samples, targets, _, _, classes) in enumerate(data_loader_train):\n", | |||||
" optimizer.zero_grad()\n", | |||||
"\n", | |||||
" pseudo_labels = []\n", | |||||
" samples = samples.to(device)\n", | |||||
" targets = targets.to(device)\n", | |||||
" inputs = tokenizer(classes, padding=True, return_tensors=\"pt\").to(device)\n", | |||||
" text_features = clip_model.get_text_features(**inputs)\n", | |||||
"\n", | |||||
" all_frames_features = []\n", | |||||
"\n", | |||||
" for i in range(samples.shape[0]):\n", | |||||
" video_frames = samples[i].permute(1, 0, 2, 3)\n", | |||||
" video_frames = video_frames[1::2]\n", | |||||
" frames_features = clip_model.get_image_features(pixel_values=video_frames)\n", | |||||
" all_frames_features.append(frames_features)\n", | |||||
" similarity = torch.matmul(frames_features, text_features[i // args.num_sample])\n", | |||||
" pseudo_labels.append(similarity)\n", | |||||
" pseudo_labels = torch.stack(pseudo_labels).to(device)\n", | |||||
"\n", | |||||
" predicted_scores = model(samples)\n", | |||||
"\n", | |||||
" # mse_loss = mse_criterion(predicted_scores, pseudo_labels)\n", | |||||
"\n", | |||||
" best_frame_targets = torch.zeros_like(pseudo_labels)\n", | |||||
" best_frame_indices = pseudo_labels.argmax(dim=1)\n", | |||||
" best_frame_targets[torch.arange(pseudo_labels.size(0)), best_frame_indices] = 1\n", | |||||
" bce_loss = F.binary_cross_entropy_with_logits(predicted_scores, best_frame_targets)\n", | |||||
"\n", | |||||
" # target = (pseudo_labels[1:] > pseudo_labels[:-1]).float() * 2 - 1\n", | |||||
" # rank_loss = margin_ranking_loss(predicted_scores[:-1], predicted_scores[1:], target)\n", | |||||
"\n", | |||||
" # combined_loss = bce_loss + 0.02 * mse_loss + rank_loss\n", | |||||
"\n", | |||||
" # combined_loss.backward()\n", | |||||
" bce_loss.backward()\n", | |||||
" optimizer.step()\n", | |||||
"\n", | |||||
" # mse_loss_sum += mse_loss.item()\n", | |||||
" bce_loss_sum += bce_loss.item()\n", | |||||
" # rank_loss_sum += rank_loss.item()\n", | |||||
" # combined_loss_sum += combined_loss.item()\n", | |||||
"\n", | |||||
" strategies = {\n", | |||||
" 'random': torch.randint(low=0, high=args.num_frames // 2, size=(samples.size(0),)).to(device),\n", | |||||
" 'best_pseudo': pseudo_labels.argmax(dim=1),\n", | |||||
" 'best_predicted': predicted_scores.argmax(dim=1)\n", | |||||
" }\n", | |||||
"\n", | |||||
" for strategy_name, frame_indices in strategies.items():\n", | |||||
" for frames_features, frame_indice, target in zip(all_frames_features, frame_indices, targets):\n", | |||||
" frame_features = frames_features[frame_indice]\n", | |||||
" acc1, acc5 = compute_accuracies(frame_features, target - 1)\n", | |||||
"\n", | |||||
" acc1_sum[strategy_name] += acc1\n", | |||||
" acc5_sum[strategy_name] += acc5\n", | |||||
" acc_counter[strategy_name] += 1\n", | |||||
"\n", | |||||
" if (step + 1) % print_step == 0:\n", | |||||
" print(\n", | |||||
" f'Step {step}/{len(data_loader_train)}\\n'\n", | |||||
" # f'Average MSE Loss: {mse_loss_sum / print_step}\\n'\n", | |||||
" f'Average BCE Loss: {bce_loss_sum / print_step}\\n'\n", | |||||
" # f'Average Ranking Loss: {rank_loss_sum / print_step}\\n'\n", | |||||
" # f'Average Combined Loss: {combined_loss_sum / print_step}'\n", | |||||
" )\n", | |||||
"\n", | |||||
" for strategy_name in strategies.keys():\n", | |||||
" acc1_avg = acc1_sum[strategy_name] / acc_counter[strategy_name]\n", | |||||
" acc5_avg = acc5_sum[strategy_name] / acc_counter[strategy_name]\n", | |||||
"\n", | |||||
" print(f\"{strategy_name} Average Frame Acc@1: {acc1_avg:.2f}%, Acc@5: {acc5_avg:.2f}%\")\n", | |||||
"\n", | |||||
" # mse_loss_sum = 0.0\n", | |||||
" bce_loss_sum = 0.0\n", | |||||
" # rank_loss_sum = 0.0\n", | |||||
" # combined_loss_sum = 0.0\n", | |||||
"\n", | |||||
"\n", | |||||
" acc1_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||||
" acc5_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||||
" acc_counter = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||||
"\n", | |||||
" print('------------------')\n", | |||||
" print()\n", | |||||
"\n", | |||||
"\n", | |||||
" gc.collect()\n", | |||||
" torch.cuda.empty_cache()" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"accelerator": "GPU", | |||||
"colab": { | |||||
"provenance": [] | |||||
}, | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"name": "python" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0 | |||||
} |
{ | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0, | |||||
"metadata": { | |||||
"colab": { | |||||
"provenance": [] | |||||
}, | |||||
"kernelspec": { | |||||
"name": "python3", | |||||
"display_name": "Python 3" | |||||
}, | |||||
"language_info": { | |||||
"name": "python" | |||||
} | |||||
}, | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "OoqkDAv64cdu", | |||||
"outputId": "e0d30294-01b0-48d7-8199-d3ade0698c52" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"output_type": "stream", | |||||
"name": "stdout", | |||||
"text": [ | |||||
"Downloading...\n", | |||||
"From: https://drive.google.com/uc?id=1bLhNoh7VNY3nkVlhQ0TBo-KQzTa-WW8c\n", | |||||
"To: /content/enhanced-ucf-101-label-CLIP-embedding.txt\n", | |||||
"\r 0% 0.00/667k [00:00<?, ?B/s]\r100% 667k/667k [00:00<00:00, 24.2MB/s]\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"!gdown 1bLhNoh7VNY3nkVlhQ0TBo-KQzTa-WW8c" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"import numpy as np\n", | |||||
"from sklearn.metrics.pairwise import cosine_similarity\n", | |||||
"from operator import itemgetter\n", | |||||
"\n", | |||||
"def parse_embedding_line(line):\n", | |||||
" parts = line.strip().split('|')\n", | |||||
" label = parts[0].strip()\n", | |||||
" description = parts[1].strip()\n", | |||||
" embedding = np.array([float(x) for x in parts[2].strip(' []').split(',')])\n", | |||||
" return label, description, embedding\n", | |||||
"\n", | |||||
"# Read the file\n", | |||||
"with open('/content/enhanced-ucf-101-label-CLIP-embedding.txt', 'r') as f:\n", | |||||
" lines = f.readlines()\n", | |||||
"\n", | |||||
"# Parse the lines\n", | |||||
"embeddings = [parse_embedding_line(line) for line in lines]\n", | |||||
"\n", | |||||
"# Calculate similarity and sort\n", | |||||
"similarities = []\n", | |||||
"for i in range(len(embeddings)):\n", | |||||
" for j in range(i+1, len(embeddings)):\n", | |||||
" similarity = cosine_similarity([embeddings[i][2]], [embeddings[j][2]])[0][0]\n", | |||||
" similarities.append((embeddings[i][0], embeddings[j][0], similarity))\n", | |||||
"\n", | |||||
"# Sort by similarity\n", | |||||
"similarities.sort(key=itemgetter(2), reverse=True)\n", | |||||
"\n", | |||||
"# Print the sorted similarities\n", | |||||
"for similarity in similarities[:300]:\n", | |||||
" print(f\"{similarity[0]} and {similarity[1]}: {similarity[2]}\")\n" | |||||
], | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "9_AHfllm7Y1F", | |||||
"outputId": "1eb26326-78ae-4640-bb42-9baaacfe81cb" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [ | |||||
{ | |||||
"output_type": "stream", | |||||
"name": "stdout", | |||||
"text": [ | |||||
"BreastStroke and FrontCrawl: 0.9135641566364405\n", | |||||
"HighJump and LongJump: 0.9058382170839308\n", | |||||
"PushUps and WallPushups: 0.893513324165334\n", | |||||
"BoxingPunchingBag and BoxingSpeedBag: 0.8610437315536404\n", | |||||
"PlayingCello and PlayingViolin: 0.8607566078247353\n", | |||||
"Kayaking and Rafting: 0.8516765812785851\n", | |||||
"HandstandPushups and PushUps: 0.8291473360831765\n", | |||||
"HammerThrow and ThrowDiscus: 0.8285788251299882\n", | |||||
"HighJump and PoleVault: 0.827215227814937\n", | |||||
"LongJump and Shotput: 0.8249689845096113\n", | |||||
"Shotput and ThrowDiscus: 0.8225986807562835\n", | |||||
"BreastStroke and Diving: 0.8198360899246\n", | |||||
"HammerThrow and Shotput: 0.8194067921502493\n", | |||||
"ParallelBars and UnevenBars: 0.8172351546564418\n", | |||||
"HandstandPushups and WallPushups: 0.8063430148554194\n", | |||||
"PlayingDhol and PlayingTabla: 0.8044929381515984\n", | |||||
"Drumming and PlayingTabla: 0.8036557686356149\n", | |||||
"BaseballPitch and TennisSwing: 0.7926770716298019\n", | |||||
"HighJump and Shotput: 0.7916822672394149\n", | |||||
"FieldHockeyPenalty and SoccerPenalty: 0.7912218906905013\n", | |||||
"PlayingFlute and PlayingViolin: 0.7907740407852453\n", | |||||
"BodyWeightSquats and Lunges: 0.790197964204252\n", | |||||
"Diving and SkyDiving: 0.7900033596110797\n", | |||||
"PlayingPiano and PlayingViolin: 0.7898563658083394\n", | |||||
"Diving and FrontCrawl: 0.7884088912671918\n", | |||||
"CricketBowling and CricketShot: 0.786969527453953\n", | |||||
"Kayaking and Rowing: 0.786382273584344\n", | |||||
"Haircut and HeadMassage: 0.7813618969695116\n", | |||||
"PlayingSitar and PlayingTabla: 0.7788554840779363\n", | |||||
"SkateBoarding and Skiing: 0.7762970903271766\n", | |||||
"Rafting and Rowing: 0.7752746043074626\n", | |||||
"GolfSwing and TennisSwing: 0.7734050524041192\n", | |||||
"PullUps and WallPushups: 0.7729951565091107\n", | |||||
"HandstandPushups and HandstandWalking: 0.7717042057111059\n", | |||||
"HighJump and TrampolineJumping: 0.7710276014756723\n", | |||||
"HandstandWalking and SkateBoarding: 0.7684993331159056\n", | |||||
"Biking and HorseRiding: 0.768091642690889\n", | |||||
"PlayingGuitar and PlayingViolin: 0.766646619086345\n", | |||||
"LongJump and TennisSwing: 0.7665787465707985\n", | |||||
"RopeClimbing and TrampolineJumping: 0.7659114066820266\n", | |||||
"SkyDiving and TrampolineJumping: 0.76511061957433\n", | |||||
"PlayingCello and PlayingPiano: 0.7647740312136126\n", | |||||
"Diving and HighJump: 0.7629321426272753\n", | |||||
"JugglingBalls and YoYo: 0.7610768612305048\n", | |||||
"BrushingTeeth and ShavingBeard: 0.7591425280885891\n", | |||||
"Diving and LongJump: 0.7574804843984438\n", | |||||
"LongJump and TrampolineJumping: 0.7571481857197025\n", | |||||
"HighJump and TennisSwing: 0.7566156464932003\n", | |||||
"HammerThrow and LongJump: 0.7558300803658048\n", | |||||
"Diving and SkateBoarding: 0.7536973390587902\n", | |||||
"JumpRope and RopeClimbing: 0.753327828396938\n", | |||||
"LongJump and ThrowDiscus: 0.7529842080950556\n", | |||||
"Diving and TrampolineJumping: 0.7518814257833246\n", | |||||
"Biking and SkateBoarding: 0.7517431292455309\n", | |||||
"LongJump and PoleVault: 0.7507306877543533\n", | |||||
"BaseballPitch and CricketShot: 0.7499026266438265\n", | |||||
"Shotput and TennisSwing: 0.7489834916216616\n", | |||||
"PoleVault and Shotput: 0.7488596433999501\n", | |||||
"Swing and TrampolineJumping: 0.7487908446460982\n", | |||||
"CliffDiving and Diving: 0.7483162287151933\n", | |||||
"CricketShot and TableTennisShot: 0.7481777580061397\n", | |||||
"Diving and TennisSwing: 0.7471226907881188\n", | |||||
"FrisbeeCatch and ThrowDiscus: 0.7443262070746838\n", | |||||
"Skiing and Surfing: 0.7431443081264507\n", | |||||
"SkateBoarding and SkyDiving: 0.7426388562026282\n", | |||||
"PlayingGuitar and PlayingSitar: 0.7421187542614355\n", | |||||
"Punch and SkyDiving: 0.741055650388424\n", | |||||
"Rafting and SkyDiving: 0.7403432891302099\n", | |||||
"CricketShot and SoccerPenalty: 0.7399070298540545\n", | |||||
"RopeClimbing and Swing: 0.739812652356636\n", | |||||
"SkateBoarding and TrampolineJumping: 0.7398099398134427\n", | |||||
"Hammering and Typing: 0.7391169263001942\n", | |||||
"HandstandWalking and LongJump: 0.7376174197128719\n", | |||||
"HorseRiding and SkateBoarding: 0.7373558640663447\n", | |||||
"HandstandWalking and HighJump: 0.7373394124461938\n", | |||||
"CricketShot and TennisSwing: 0.7367685620369944\n", | |||||
"LongJump and SkateBoarding: 0.7363108945108562\n", | |||||
"Diving and HandstandWalking: 0.7361879858427884\n", | |||||
"PlayingGuitar and PlayingPiano: 0.7360426048849933\n", | |||||
"Hammering and Punch: 0.7355299374133767\n", | |||||
"Punch and TennisSwing: 0.7349357411747358\n", | |||||
"BalanceBeam and ParallelBars: 0.734626447206699\n", | |||||
"PlayingPiano and Typing: 0.7342818000329813\n", | |||||
"Drumming and PlayingDhol: 0.7342338545960496\n", | |||||
"Diving and Shotput: 0.733686875936528\n", | |||||
"PlayingFlute and PlayingPiano: 0.7334535071686105\n", | |||||
"FloorGymnastics and PoleVault: 0.7332110651309142\n", | |||||
"HorseRace and HorseRiding: 0.7326100067023478\n", | |||||
"HandstandWalking and TrampolineJumping: 0.7324930780008992\n", | |||||
"Basketball and BasketballDunk: 0.7320448669471811\n", | |||||
"HighJump and ThrowDiscus: 0.7312358151246964\n", | |||||
"ApplyEyeMakeup and ApplyLipstick: 0.7308442688180419\n", | |||||
"RopeClimbing and SkyDiving: 0.7306117431890142\n", | |||||
"TennisSwing and VolleyballSpiking: 0.7303710434696256\n", | |||||
"Billiards and Bowling: 0.7291336136266361\n", | |||||
"Skiing and SkyDiving: 0.7289941358107048\n", | |||||
"Diving and Skiing: 0.7279314246716805\n", | |||||
"Swing and TennisSwing: 0.7275821346958722\n", | |||||
"BaseballPitch and LongJump: 0.7264541939600715\n", | |||||
"PlayingCello and PlayingGuitar: 0.7263333055199999\n", | |||||
"HammerThrow and JavelinThrow: 0.7257228448556832\n", | |||||
"Haircut and Typing: 0.7257034588639546\n", | |||||
"RopeClimbing and SkateBoarding: 0.7255599711121894\n", | |||||
"HandstandPushups and PullUps: 0.7253599334545812\n", | |||||
"Rowing and Skiing: 0.7249357111858055\n", | |||||
"TennisSwing and TrampolineJumping: 0.7240195069941607\n", | |||||
"HammerThrow and PoleVault: 0.7239577542426181\n", | |||||
"LongJump and Skiing: 0.7237669589998139\n", | |||||
"BreastStroke and Shotput: 0.7236121929503939\n", | |||||
"PlayingCello and PlayingSitar: 0.7232122696549205\n", | |||||
"Punch and Swing: 0.7230845724023044\n", | |||||
"SkateBoarding and TennisSwing: 0.7228170165773073\n", | |||||
"BodyWeightSquats and HandstandPushups: 0.7220405937519339\n", | |||||
"FloorGymnastics and StillRings: 0.7219493456486239\n", | |||||
"PullUps and PushUps: 0.7213810950131618\n", | |||||
"HorseRiding and Skiing: 0.7206407102268063\n", | |||||
"SkyDiving and Typing: 0.7202455770918296\n", | |||||
"HorseRiding and RopeClimbing: 0.7191262064512949\n", | |||||
"SkyDiving and Swing: 0.7180675832908672\n", | |||||
"PlayingFlute and PlayingGuitar: 0.7180613359728372\n", | |||||
"HammerThrow and HighJump: 0.7179640718224373\n", | |||||
"Diving and Punch: 0.7176298910268708\n", | |||||
"PlayingViolin and TennisSwing: 0.7174823425623374\n", | |||||
"Diving and RopeClimbing: 0.7167796270997888\n", | |||||
"Rafting and RopeClimbing: 0.7167763884184823\n", | |||||
"TableTennisShot and TennisSwing: 0.7165475089712395\n", | |||||
"Biking and Skiing: 0.7164028469479011\n", | |||||
"JavelinThrow and PoleVault: 0.7159371318136043\n", | |||||
"RopeClimbing and Typing: 0.7155954413924253\n", | |||||
"HeadMassage and Typing: 0.7153357826215936\n", | |||||
"Rafting and Typing: 0.7138269972502138\n", | |||||
"Biking and Rowing: 0.7137891147805048\n", | |||||
"Kayaking and Typing: 0.7137610634362122\n", | |||||
"BreastStroke and HighJump: 0.7126134869505956\n", | |||||
"Diving and Rowing: 0.7125020136521408\n", | |||||
"HighJump and SkateBoarding: 0.7117778149802355\n", | |||||
"SoccerPenalty and VolleyballSpiking: 0.7111150090759842\n", | |||||
"LongJump and SkyDiving: 0.7110640850737046\n", | |||||
"PlayingFlute and PlayingSitar: 0.7110108149484404\n", | |||||
"BaseballPitch and Shotput: 0.7109615116869581\n", | |||||
"PlayingSitar and PlayingViolin: 0.7104350026580752\n", | |||||
"HandstandPushups and JumpingJack: 0.7103055239175267\n", | |||||
"SkateBoarding and Surfing: 0.7098772586463374\n", | |||||
"Punch and RopeClimbing: 0.7086769659664385\n", | |||||
"Punch and TrampolineJumping: 0.7078881926354027\n", | |||||
"Diving and Surfing: 0.707117497226668\n", | |||||
"PlayingPiano and TennisSwing: 0.706919504587053\n", | |||||
"Shotput and SkateBoarding: 0.7061596986078766\n", | |||||
"PlayingCello and PlayingFlute: 0.7058423238727487\n", | |||||
"FloorGymnastics and UnevenBars: 0.7057777115167209\n", | |||||
"ParallelBars and StillRings: 0.7055355076859835\n", | |||||
"CleanAndJerk and Shotput: 0.7055125658073746\n", | |||||
"Kayaking and SkyDiving: 0.704658976975876\n", | |||||
"TableTennisShot and VolleyballSpiking: 0.7045804240564966\n", | |||||
"Drumming and PlayingPiano: 0.7045712194927514\n", | |||||
"Haircut and Hammering: 0.7044598702031014\n", | |||||
"Rowing and TennisSwing: 0.7042651727056333\n", | |||||
"HorseRiding and Typing: 0.7040858174512816\n", | |||||
"Punch and Shotput: 0.704052525603613\n", | |||||
"Archery and TennisSwing: 0.7036815296561322\n", | |||||
"SkateBoarding and YoYo: 0.7030016415760869\n", | |||||
"RopeClimbing and YoYo: 0.7029277827820148\n", | |||||
"Punch and YoYo: 0.7023995431749964\n", | |||||
"LongJump and Lunges: 0.7018532908775557\n", | |||||
"JumpRope and TrampolineJumping: 0.7016024452998298\n", | |||||
"Diving and Rafting: 0.7013908045763904\n", | |||||
"Bowling and CricketBowling: 0.7013511724501809\n", | |||||
"Punch and TaiChi: 0.7008104891179683\n", | |||||
"BalanceBeam and UnevenBars: 0.7005837043328924\n", | |||||
"SkateBoarding and Typing: 0.7005140816781765\n", | |||||
"PlayingViolin and YoYo: 0.7003404861306218\n", | |||||
"BasketballDunk and TennisSwing: 0.6999227936958208\n", | |||||
"HighJump and VolleyballSpiking: 0.6998266362653691\n", | |||||
"FloorGymnastics and ParallelBars: 0.6988077818792078\n", | |||||
"PlayingViolin and Punch: 0.6973703128445751\n", | |||||
"Punch and Typing: 0.6967775549981869\n", | |||||
"Archery and Shotput: 0.6967583917284877\n", | |||||
"Billiards and TableTennisShot: 0.6967073608377916\n", | |||||
"BreastStroke and TennisSwing: 0.6966987394749857\n", | |||||
"Punch and SkateBoarding: 0.6964323711013919\n", | |||||
"TennisSwing and ThrowDiscus: 0.6964309660329386\n", | |||||
"Biking and TennisSwing: 0.6962187137964084\n", | |||||
"BaseballPitch and GolfSwing: 0.6961587111393436\n", | |||||
"SkateBoarding and Swing: 0.6960365072022532\n", | |||||
"BalanceBeam and FloorGymnastics: 0.6956748992112082\n", | |||||
"RopeClimbing and Skiing: 0.6954019941797185\n", | |||||
"BasketballDunk and HighJump: 0.6952324557182699\n", | |||||
"RopeClimbing and TennisSwing: 0.6952311829515747\n", | |||||
"BreastStroke and Rowing: 0.6949829501769483\n", | |||||
"Skiing and TennisSwing: 0.6947422694408114\n", | |||||
"JumpingJack and TrampolineJumping: 0.6946418413900224\n", | |||||
"ParallelBars and PommelHorse: 0.6943619896075937\n", | |||||
"BreastStroke and LongJump: 0.6926574264131027\n", | |||||
"FloorGymnastics and HandstandPushups: 0.6924127157339901\n", | |||||
"Shotput and Skiing: 0.692409162454707\n", | |||||
"BodyWeightSquats and JumpingJack: 0.6915140858910102\n", | |||||
"Archery and PlayingViolin: 0.6914364939922757\n", | |||||
"BasketballDunk and TrampolineJumping: 0.6910720435738926\n", | |||||
"Haircut and SkyDiving: 0.6906279918436358\n", | |||||
"HorseRiding and Rowing: 0.6906078683829708\n", | |||||
"Diving and Kayaking: 0.6904115966869461\n", | |||||
"HighJump and SkyDiving: 0.6903223255840882\n", | |||||
"HandstandWalking and RopeClimbing: 0.6900177976361846\n", | |||||
"Biking and Typing: 0.688663446247988\n", | |||||
"Archery and Punch: 0.6885543671849701\n", | |||||
"BreastStroke and Surfing: 0.6882605898627008\n", | |||||
"Diving and PlayingPiano: 0.6877803133466125\n", | |||||
"PlayingPiano and SkateBoarding: 0.6872490075085536\n", | |||||
"BaseballPitch and Punch: 0.6871945116845465\n", | |||||
"HandstandWalking and UnevenBars: 0.6869197476332304\n", | |||||
"CricketShot and GolfSwing: 0.6868492889521802\n", | |||||
"Kayaking and Surfing: 0.6865820428631303\n", | |||||
"Basketball and TennisSwing: 0.6864463021798647\n", | |||||
"Shotput and YoYo: 0.6862387737044783\n", | |||||
"PlayingViolin and SkateBoarding: 0.6861837432385506\n", | |||||
"SoccerPenalty and TennisSwing: 0.6861144923920035\n", | |||||
"BaseballPitch and SkateBoarding: 0.6855318206995243\n", | |||||
"Shotput and TrampolineJumping: 0.6854082143018928\n", | |||||
"CricketShot and Shotput: 0.6852120255606943\n", | |||||
"Kayaking and Skiing: 0.6850696489537216\n", | |||||
"Archery and Fencing: 0.6850545206297898\n", | |||||
"HighJump and Skiing: 0.6847553551867225\n", | |||||
"BaseballPitch and PlayingViolin: 0.6840371878413228\n", | |||||
"HorseRiding and Swing: 0.6835796664282903\n", | |||||
"Archery and SkyDiving: 0.683544816768237\n", | |||||
"PlayingCello and YoYo: 0.6833837981485343\n", | |||||
"Rowing and Typing: 0.6833516303139973\n", | |||||
"HorseRiding and SkyDiving: 0.6832857317563026\n", | |||||
"Lunges and TennisSwing: 0.6827234718924219\n", | |||||
"BasketballDunk and LongJump: 0.6825552137357147\n", | |||||
"Rowing and Surfing: 0.6820142752266611\n", | |||||
"JumpingJack and LongJump: 0.6819915349194994\n", | |||||
"Haircut and Punch: 0.6819490061511206\n", | |||||
"Haircut and SkateBoarding: 0.6818481724348707\n", | |||||
"FloorGymnastics and HighJump: 0.6808541160844553\n", | |||||
"Punch and Rafting: 0.6807997071512121\n", | |||||
"PlayingPiano and Rowing: 0.680677304303248\n", | |||||
"JavelinThrow and ThrowDiscus: 0.6806693704418652\n", | |||||
"Rowing and SkateBoarding: 0.6803773211285751\n", | |||||
"Hammering and HeadMassage: 0.6802512293515253\n", | |||||
"Punch and Rowing: 0.6802002555149966\n", | |||||
"BodyWeightSquats and PushUps: 0.6799776131017082\n", | |||||
"Diving and Typing: 0.6797726174709029\n", | |||||
"Shotput and SkyDiving: 0.6797673354951406\n", | |||||
"Hammering and PlayingViolin: 0.6796496669473412\n", | |||||
"HandstandWalking and SkyDiving: 0.6794054995491691\n", | |||||
"FloorGymnastics and HandstandWalking: 0.6794028936136237\n", | |||||
"BasketballDunk and CricketShot: 0.6792758965914687\n", | |||||
"Biking and Diving: 0.6790006973364628\n", | |||||
"BaseballPitch and CricketBowling: 0.6786719503376799\n", | |||||
"LongJump and VolleyballSpiking: 0.6784090342928275\n", | |||||
"PullUps and RopeClimbing: 0.678258564371885\n", | |||||
"Rafting and Skiing: 0.6781782308964603\n", | |||||
"SkyDiving and TennisSwing: 0.6780321924063827\n", | |||||
"Biking and Kayaking: 0.6778740166200692\n", | |||||
"HorseRiding and Punch: 0.6775237101716206\n", | |||||
"Hammering and SkateBoarding: 0.6770651672183081\n", | |||||
"Mixing and Typing: 0.6770458400648384\n", | |||||
"Shotput and VolleyballSpiking: 0.6769503514713261\n", | |||||
"LongJump and Punch: 0.6768260936231084\n", | |||||
"Billiards and TennisSwing: 0.676687811162065\n", | |||||
"SoccerJuggling and SoccerPenalty: 0.6765064274372669\n", | |||||
"StillRings and UnevenBars: 0.6764089115086491\n", | |||||
"JugglingBalls and SoccerJuggling: 0.6760817346996303\n", | |||||
"BaseballPitch and SoccerPenalty: 0.6759697039170522\n", | |||||
"Hammering and HorseRiding: 0.6757606742710943\n", | |||||
"Billiards and PlayingPiano: 0.6756887070040185\n", | |||||
"Basketball and Shotput: 0.6756210059393115\n", | |||||
"TrampolineJumping and Typing: 0.6756177121880147\n", | |||||
"BalanceBeam and PommelHorse: 0.6755478243357194\n", | |||||
"CliffDiving and SkyDiving: 0.6754688197958925\n", | |||||
"BenchPress and HandstandPushups: 0.6753462315331547\n", | |||||
"Rowing and Skijet: 0.6749020427705266\n", | |||||
"PlayingCello and TennisSwing: 0.6748126664110354\n", | |||||
"CricketBowling and TennisSwing: 0.6747936810472541\n", | |||||
"TrampolineJumping and UnevenBars: 0.6747665868776003\n", | |||||
"BodyWeightSquats and HandstandWalking: 0.6744448381666313\n", | |||||
"Hammering and RopeClimbing: 0.6740811403604476\n", | |||||
"Lunges and Shotput: 0.6737521021449214\n", | |||||
"Drumming and PlayingSitar: 0.6736439263992084\n", | |||||
"Skiing and Skijet: 0.6734635666570382\n", | |||||
"HandstandWalking and Shotput: 0.6732521658396622\n", | |||||
"BandMarching and SkateBoarding: 0.6730512018894181\n", | |||||
"HandstandWalking and TennisSwing: 0.672699463584204\n", | |||||
"HorseRiding and TennisSwing: 0.6722271368365996\n", | |||||
"Biking and RopeClimbing: 0.6721048261242591\n", | |||||
"PlayingViolin and Typing: 0.6720289417432124\n", | |||||
"HandstandWalking and JumpingJack: 0.671939227917232\n", | |||||
"PlayingPiano and PlayingSitar: 0.671923104534426\n", | |||||
"HorseRiding and Rafting: 0.6716409889288361\n", | |||||
"Biking and LongJump: 0.6713809775815504\n", | |||||
"Lunges and TaiChi: 0.6711312367339834\n", | |||||
"BandMarching and HorseRiding: 0.6711082186179306\n", | |||||
"CricketBowling and Shotput: 0.6705659431969646\n", | |||||
"BodyWeightSquats and WallPushups: 0.6705453623257535\n", | |||||
"SoccerPenalty and TableTennisShot: 0.6703895759334059\n", | |||||
"HandstandWalking and PushUps: 0.6702445259363492\n", | |||||
"RopeClimbing and UnevenBars: 0.6699618121733547\n", | |||||
"Rowing and Swing: 0.6697346266800764\n", | |||||
"PlayingCello and SkateBoarding: 0.6696594880529304\n" | |||||
] | |||||
} | |||||
] | |||||
} | |||||
] | |||||
} |
import torch | |||||
from torch.utils.data import DataLoader | |||||
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader | |||||
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader | |||||
from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader | |||||
from dataloaders.dataloader_lsmdc_retrieval import LSMDC_DataLoader | |||||
from dataloaders.dataloader_activitynet_retrieval import ActivityNet_DataLoader | |||||
from dataloaders.dataloader_didemo_retrieval import DiDeMo_DataLoader | |||||
def dataloader_msrvtt_train(args, tokenizer): | |||||
msrvtt_dataset = MSRVTT_TrainDataLoader( | |||||
csv_path=args.train_csv, | |||||
json_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
unfold_sentences=args.expand_msrvtt_sentences, | |||||
frame_order=args.train_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) | |||||
dataloader = DataLoader( | |||||
msrvtt_dataset, | |||||
batch_size=args.batch_size // args.n_gpu, | |||||
num_workers=args.num_thread_reader, | |||||
pin_memory=False, | |||||
shuffle=(train_sampler is None), | |||||
sampler=train_sampler, | |||||
drop_last=True, | |||||
) | |||||
return dataloader, len(msrvtt_dataset), train_sampler | |||||
def dataloader_msrvtt_test(args, tokenizer, subset="test"): | |||||
msrvtt_testset = MSRVTT_DataLoader( | |||||
csv_path=args.val_csv, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.eval_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
dataloader_msrvtt = DataLoader( | |||||
msrvtt_testset, | |||||
batch_size=args.batch_size_val, | |||||
num_workers=args.num_thread_reader, | |||||
shuffle=False, | |||||
drop_last=False, | |||||
) | |||||
return dataloader_msrvtt, len(msrvtt_testset) | |||||
def dataloader_msvd_train(args, tokenizer): | |||||
msvd_dataset = MSVD_DataLoader( | |||||
subset="train", | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.train_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) | |||||
dataloader = DataLoader( | |||||
msvd_dataset, | |||||
batch_size=args.batch_size // args.n_gpu, | |||||
num_workers=args.num_thread_reader, | |||||
pin_memory=False, | |||||
shuffle=(train_sampler is None), | |||||
sampler=train_sampler, | |||||
drop_last=True, | |||||
) | |||||
return dataloader, len(msvd_dataset), train_sampler | |||||
def dataloader_msvd_test(args, tokenizer, subset="test"): | |||||
msvd_testset = MSVD_DataLoader( | |||||
subset=subset, | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.eval_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
    dataloader_msvd = DataLoader(
msvd_testset, | |||||
batch_size=args.batch_size_val, | |||||
num_workers=args.num_thread_reader, | |||||
shuffle=False, | |||||
drop_last=False, | |||||
) | |||||
    return dataloader_msvd, len(msvd_testset)
def dataloader_lsmdc_train(args, tokenizer): | |||||
lsmdc_dataset = LSMDC_DataLoader( | |||||
subset="train", | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.train_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) | |||||
dataloader = DataLoader( | |||||
lsmdc_dataset, | |||||
batch_size=args.batch_size // args.n_gpu, | |||||
num_workers=args.num_thread_reader, | |||||
pin_memory=False, | |||||
shuffle=(train_sampler is None), | |||||
sampler=train_sampler, | |||||
drop_last=True, | |||||
) | |||||
return dataloader, len(lsmdc_dataset), train_sampler | |||||
def dataloader_lsmdc_test(args, tokenizer, subset="test"): | |||||
lsmdc_testset = LSMDC_DataLoader( | |||||
subset=subset, | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.eval_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
dataloader_lsmdc = DataLoader(
lsmdc_testset, | |||||
batch_size=args.batch_size_val, | |||||
num_workers=args.num_thread_reader, | |||||
shuffle=False, | |||||
drop_last=False, | |||||
) | |||||
return dataloader_lsmdc, len(lsmdc_testset)
def dataloader_activity_train(args, tokenizer): | |||||
activity_dataset = ActivityNet_DataLoader( | |||||
subset="train", | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.train_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) | |||||
dataloader = DataLoader( | |||||
activity_dataset, | |||||
batch_size=args.batch_size // args.n_gpu, | |||||
num_workers=args.num_thread_reader, | |||||
pin_memory=False, | |||||
shuffle=(train_sampler is None), | |||||
sampler=train_sampler, | |||||
drop_last=True, | |||||
) | |||||
return dataloader, len(activity_dataset), train_sampler | |||||
def dataloader_activity_test(args, tokenizer, subset="test"): | |||||
activity_testset = ActivityNet_DataLoader( | |||||
subset=subset, | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.eval_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
dataloader_activity = DataLoader(
activity_testset, | |||||
batch_size=args.batch_size_val, | |||||
num_workers=args.num_thread_reader, | |||||
shuffle=False, | |||||
drop_last=False, | |||||
) | |||||
return dataloader_activity, len(activity_testset)
def dataloader_didemo_train(args, tokenizer): | |||||
didemo_dataset = DiDeMo_DataLoader( | |||||
subset="train", | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.train_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) | |||||
dataloader = DataLoader( | |||||
didemo_dataset, | |||||
batch_size=args.batch_size // args.n_gpu, | |||||
num_workers=args.num_thread_reader, | |||||
pin_memory=False, | |||||
shuffle=(train_sampler is None), | |||||
sampler=train_sampler, | |||||
drop_last=True, | |||||
) | |||||
return dataloader, len(didemo_dataset), train_sampler | |||||
def dataloader_didemo_test(args, tokenizer, subset="test"): | |||||
didemo_testset = DiDeMo_DataLoader( | |||||
subset=subset, | |||||
data_path=args.data_path, | |||||
features_path=args.features_path, | |||||
max_words=args.max_words, | |||||
feature_framerate=args.feature_framerate, | |||||
tokenizer=tokenizer, | |||||
max_frames=args.max_frames, | |||||
frame_order=args.eval_frame_order, | |||||
slice_framepos=args.slice_framepos, | |||||
) | |||||
dataloader_didemo = DataLoader( | |||||
didemo_testset, | |||||
batch_size=args.batch_size_val, | |||||
num_workers=args.num_thread_reader, | |||||
shuffle=False, | |||||
drop_last=False, | |||||
) | |||||
return dataloader_didemo, len(didemo_testset) | |||||
DATALOADER_DICT = {} | |||||
DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test, "test":None} | |||||
DATALOADER_DICT["msvd"] = {"train":dataloader_msvd_train, "val":dataloader_msvd_test, "test":dataloader_msvd_test} | |||||
DATALOADER_DICT["lsmdc"] = {"train":dataloader_lsmdc_train, "val":dataloader_lsmdc_test, "test":dataloader_lsmdc_test} | |||||
DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, "val":dataloader_activity_test, "test":None} | |||||
DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, "val":dataloader_didemo_test, "test":dataloader_didemo_test} |
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import unicode_literals | |||||
from __future__ import print_function | |||||
import os | |||||
from torch.utils.data import Dataset | |||||
import numpy as np | |||||
import json | |||||
import math | |||||
from dataloaders.rawvideo_util import RawVideoExtractor | |||||
class ActivityNet_DataLoader(Dataset): | |||||
def __init__( | |||||
self, | |||||
subset, | |||||
data_path, | |||||
features_path, | |||||
tokenizer, | |||||
max_words=30, | |||||
feature_framerate=1.0, | |||||
max_frames=100, | |||||
image_resolution=224, | |||||
frame_order=0, | |||||
slice_framepos=0, | |||||
): | |||||
self.data_path = data_path | |||||
self.features_path = features_path | |||||
self.feature_framerate = feature_framerate | |||||
self.max_words = max_words | |||||
self.max_frames = max_frames | |||||
self.tokenizer = tokenizer | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
self.frame_order = frame_order | |||||
assert self.frame_order in [0, 1, 2] | |||||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||||
self.slice_framepos = slice_framepos | |||||
assert self.slice_framepos in [0, 1, 2] | |||||
self.subset = subset | |||||
assert self.subset in ["train", "val"] | |||||
video_id_path_dict = {} | |||||
video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json") | |||||
video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json") | |||||
video_json_path_dict = {} | |||||
video_json_path_dict["train"] = os.path.join(self.data_path, "train.json") | |||||
video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json") | |||||
pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset]) | |||||
pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset]) | |||||
print("video id list: {}".format(len(video_id_list))) | |||||
print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys()))) | |||||
video_dict = {} | |||||
for root, dub_dir, video_files in os.walk(self.features_path): | |||||
for video_file in video_files: | |||||
video_id_ = ".".join(video_file.split(".")[:-1]) | |||||
if video_id_ not in video_id_list: | |||||
continue | |||||
file_path_ = os.path.join(root, video_file) | |||||
video_dict[video_id_] = file_path_ | |||||
self.video_dict = video_dict | |||||
print("video dict: {}".format(len(video_dict))) | |||||
self.pseudo_video_id_list = pseudo_video_id_list | |||||
self.video_id_list = video_id_list | |||||
self.pseudo_caption_dict = pseudo_caption_dict | |||||
# Get iterator video ids | |||||
self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)} | |||||
# Get all captions | |||||
self.iter2video_pairs_dict = {} | |||||
for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list): | |||||
if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict: | |||||
continue | |||||
caption = self.pseudo_caption_dict[pseudo_video_id] | |||||
n_caption = len(caption['start']) | |||||
for sub_id in range(n_caption): | |||||
self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id) | |||||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||||
def __len__(self): | |||||
return len(self.iter2video_pairs_dict) | |||||
def _get_video_id_from_pseduo(self, pseudo_video_id): | |||||
video_id = pseudo_video_id[2:] | |||||
return video_id | |||||
def _get_video_id_single(self, path): | |||||
pseudo_video_id_list = [] | |||||
video_id_list = [] | |||||
print('Loading json: {}'.format(path)) | |||||
with open(path, 'r') as f: | |||||
json_data = json.load(f) | |||||
for pseudo_video_id in json_data: | |||||
if pseudo_video_id in pseudo_video_id_list: | |||||
print("reduplicate.") | |||||
else: | |||||
video_id = self._get_video_id_from_pseduo(pseudo_video_id) | |||||
pseudo_video_id_list.append(pseudo_video_id) | |||||
video_id_list.append(video_id) | |||||
return pseudo_video_id_list, video_id_list | |||||
def _get_captions_single(self, path): | |||||
pseudo_caption_dict = {} | |||||
with open(path, 'r') as f: | |||||
json_data = json.load(f) | |||||
for pseudo_video_id, v_ in json_data.items(): | |||||
pseudo_caption_dict[pseudo_video_id] = {} | |||||
duration = v_["duration"] | |||||
pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object) | |||||
pseudo_caption_dict[pseudo_video_id]["end"] = np.array([int(math.ceil(float(duration)))], dtype=object) | |||||
pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object) | |||||
return pseudo_caption_dict | |||||
def _get_text(self, pseudo_video_id, sub_id): | |||||
caption = self.pseudo_caption_dict[pseudo_video_id] | |||||
k = 1 | |||||
r_ind = [sub_id] | |||||
starts = np.zeros(k, dtype=np.int64)
ends = np.zeros(k, dtype=np.int64)
pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i in range(k): | |||||
ind = r_ind[i] | |||||
start_, end_ = caption['start'][ind], caption['end'][ind] | |||||
words = self.tokenizer.tokenize(caption['text'][ind]) | |||||
starts[i], ends[i] = start_, end_ | |||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||||
total_length_with_CLS = self.max_words - 1 | |||||
if len(words) > total_length_with_CLS: | |||||
words = words[:total_length_with_CLS] | |||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||||
input_mask = [1] * len(input_ids) | |||||
segment_ids = [0] * len(input_ids) | |||||
while len(input_ids) < self.max_words: | |||||
input_ids.append(0) | |||||
input_mask.append(0) | |||||
segment_ids.append(0) | |||||
assert len(input_ids) == self.max_words | |||||
assert len(input_mask) == self.max_words | |||||
assert len(segment_ids) == self.max_words | |||||
pairs_text[i] = np.array(input_ids) | |||||
pairs_mask[i] = np.array(input_mask) | |||||
pairs_segment[i] = np.array(segment_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, starts, ends | |||||
def _get_rawvideo(self, idx, s, e): | |||||
video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
max_video_length = [0] * len(s) | |||||
# Pair x L x T x 3 x H x W | |||||
video = np.zeros((len(s), self.max_frames, 1, 3,
self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
video_path = self.video_dict[idx] | |||||
try: | |||||
for i in range(len(s)): | |||||
start_time = int(s[i]) | |||||
end_time = int(e[i]) | |||||
start_time = start_time if start_time >= 0. else 0. | |||||
end_time = end_time if end_time >= 0. else 0. | |||||
if start_time > end_time: | |||||
start_time, end_time = end_time, start_time | |||||
elif start_time == end_time: | |||||
end_time = end_time + 1 | |||||
# Could be optimized by gathering all requests for this video and decoding it once
raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) | |||||
raw_video_data = raw_video_data['video'] | |||||
if len(raw_video_data.shape) > 3: | |||||
raw_video_data_clip = raw_video_data | |||||
# L x T x 3 x H x W | |||||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||||
if self.max_frames < raw_video_slice.shape[0]: | |||||
if self.slice_framepos == 0: | |||||
video_slice = raw_video_slice[:self.max_frames, ...] | |||||
elif self.slice_framepos == 1: | |||||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||||
else: | |||||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||||
video_slice = raw_video_slice[sample_indx, ...] | |||||
else: | |||||
video_slice = raw_video_slice | |||||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||||
slice_len = video_slice.shape[0] | |||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||||
if slice_len < 1: | |||||
pass | |||||
else: | |||||
video[i][:slice_len, ...] = video_slice | |||||
else: | |||||
print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) | |||||
except Exception as excep: | |||||
print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) | |||||
raise excep | |||||
for i, v_length in enumerate(max_video_length): | |||||
video_mask[i][:v_length] = [1] * v_length | |||||
return video, video_mask | |||||
def __getitem__(self, feature_idx): | |||||
pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx] | |||||
idx = self.video_id2idx_dict[pseudo_video_id] | |||||
pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id) | |||||
video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends) | |||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import unicode_literals | |||||
from __future__ import print_function | |||||
import os | |||||
from torch.utils.data import Dataset | |||||
import numpy as np | |||||
import json | |||||
from dataloaders.rawvideo_util import RawVideoExtractor | |||||
class DiDeMo_DataLoader(Dataset): | |||||
def __init__( | |||||
self, | |||||
subset, | |||||
data_path, | |||||
features_path, | |||||
tokenizer, | |||||
max_words=30, | |||||
feature_framerate=1.0, | |||||
max_frames=100, | |||||
image_resolution=224, | |||||
frame_order=0, | |||||
slice_framepos=0, | |||||
): | |||||
self.data_path = data_path | |||||
self.features_path = features_path | |||||
self.feature_framerate = feature_framerate | |||||
self.max_words = max_words | |||||
self.max_frames = max_frames | |||||
self.tokenizer = tokenizer | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
self.frame_order = frame_order | |||||
assert self.frame_order in [0, 1, 2] | |||||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||||
self.slice_framepos = slice_framepos | |||||
assert self.slice_framepos in [0, 1, 2] | |||||
self.subset = subset | |||||
assert self.subset in ["train", "val", "test"] | |||||
video_id_path_dict = {} | |||||
video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") | |||||
video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") | |||||
video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") | |||||
video_json_path_dict = {} | |||||
video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json") | |||||
video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json") | |||||
video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json") | |||||
with open(video_id_path_dict[self.subset], 'r') as fp: | |||||
video_ids = [itm.strip() for itm in fp.readlines()] | |||||
caption_dict = {} | |||||
with open(video_json_path_dict[self.subset], 'r') as f: | |||||
json_data = json.load(f) | |||||
for itm in json_data: | |||||
description = itm["description"] | |||||
times = itm["times"] | |||||
video = itm["video"] | |||||
if video not in video_ids: | |||||
continue | |||||
# each video is split into 5-second temporal chunks | |||||
# average the points from each annotator | |||||
start_ = np.mean([t_[0] for t_ in times]) * 5 | |||||
end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5 | |||||
if video in caption_dict: | |||||
caption_dict[video]["start"].append(start_) | |||||
caption_dict[video]["end"].append(end_) | |||||
caption_dict[video]["text"].append(description) | |||||
else: | |||||
caption_dict[video] = {} | |||||
caption_dict[video]["start"] = [start_] | |||||
caption_dict[video]["end"] = [end_] | |||||
caption_dict[video]["text"] = [description] | |||||
for k_ in caption_dict.keys(): | |||||
caption_dict[k_]["start"] = [0] | |||||
# trick to save time on obtaining each video length | |||||
# [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: | |||||
# Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. | |||||
caption_dict[k_]["end"] = [31] | |||||
caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])] | |||||
video_dict = {} | |||||
for root, dub_dir, video_files in os.walk(self.features_path): | |||||
for video_file in video_files: | |||||
video_id_ = video_file | |||||
if video_id_ not in video_ids: | |||||
continue | |||||
file_path_ = os.path.join(root, video_file) | |||||
video_dict[video_id_] = file_path_ | |||||
self.caption_dict = caption_dict | |||||
self.video_dict = video_dict | |||||
video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys())) | |||||
# Get all captions | |||||
self.iter2video_pairs_dict = {} | |||||
for video_id in self.caption_dict.keys(): | |||||
if video_id not in video_ids: | |||||
continue | |||||
caption = self.caption_dict[video_id] | |||||
n_caption = len(caption['start']) | |||||
for sub_id in range(n_caption): | |||||
self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id) | |||||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||||
def __len__(self): | |||||
return len(self.iter2video_pairs_dict) | |||||
def _get_text(self, video_id, sub_id): | |||||
caption = self.caption_dict[video_id] | |||||
k = 1 | |||||
r_ind = [sub_id] | |||||
starts = np.zeros(k, dtype=np.int64)
ends = np.zeros(k, dtype=np.int64)
pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i in range(k): | |||||
ind = r_ind[i] | |||||
start_, end_ = caption['start'][ind], caption['end'][ind] | |||||
words = self.tokenizer.tokenize(caption['text'][ind]) | |||||
starts[i], ends[i] = start_, end_ | |||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||||
total_length_with_CLS = self.max_words - 1 | |||||
if len(words) > total_length_with_CLS: | |||||
words = words[:total_length_with_CLS] | |||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||||
input_mask = [1] * len(input_ids) | |||||
segment_ids = [0] * len(input_ids) | |||||
while len(input_ids) < self.max_words: | |||||
input_ids.append(0) | |||||
input_mask.append(0) | |||||
segment_ids.append(0) | |||||
assert len(input_ids) == self.max_words | |||||
assert len(input_mask) == self.max_words | |||||
assert len(segment_ids) == self.max_words | |||||
pairs_text[i] = np.array(input_ids) | |||||
pairs_mask[i] = np.array(input_mask) | |||||
pairs_segment[i] = np.array(segment_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, starts, ends | |||||
def _get_rawvideo(self, idx, s, e): | |||||
video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
max_video_length = [0] * len(s) | |||||
# Pair x L x T x 3 x H x W | |||||
video = np.zeros((len(s), self.max_frames, 1, 3,
self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
video_path = self.video_dict[idx] | |||||
try: | |||||
for i in range(len(s)): | |||||
start_time = int(s[i]) | |||||
end_time = int(e[i]) | |||||
start_time = start_time if start_time >= 0. else 0. | |||||
end_time = end_time if end_time >= 0. else 0. | |||||
if start_time > end_time: | |||||
start_time, end_time = end_time, start_time | |||||
elif start_time == end_time: | |||||
end_time = end_time + 1 | |||||
cache_id = "{}_{}_{}".format(video_path, start_time, end_time) | |||||
# Could be optimized by gathering all requests for this video and decoding it once
raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) | |||||
raw_video_data = raw_video_data['video'] | |||||
if len(raw_video_data.shape) > 3: | |||||
raw_video_data_clip = raw_video_data | |||||
# L x T x 3 x H x W | |||||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||||
if self.max_frames < raw_video_slice.shape[0]: | |||||
if self.slice_framepos == 0: | |||||
video_slice = raw_video_slice[:self.max_frames, ...] | |||||
elif self.slice_framepos == 1: | |||||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||||
else: | |||||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||||
video_slice = raw_video_slice[sample_indx, ...] | |||||
else: | |||||
video_slice = raw_video_slice | |||||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||||
slice_len = video_slice.shape[0] | |||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||||
if slice_len < 1: | |||||
pass | |||||
else: | |||||
video[i][:slice_len, ...] = video_slice | |||||
else: | |||||
print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) | |||||
except Exception as excep: | |||||
print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) | |||||
pass | |||||
# raise e | |||||
for i, v_length in enumerate(max_video_length): | |||||
video_mask[i][:v_length] = [1] * v_length | |||||
return video, video_mask | |||||
def __getitem__(self, feature_idx): | |||||
video_id, sub_id = self.iter2video_pairs_dict[feature_idx] | |||||
pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id) | |||||
video, video_mask = self._get_rawvideo(video_id, starts, ends) | |||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import unicode_literals | |||||
from __future__ import print_function | |||||
import os | |||||
from torch.utils.data import Dataset | |||||
import numpy as np | |||||
import json | |||||
import math | |||||
from dataloaders.rawvideo_util import RawVideoExtractor | |||||
class LSMDC_DataLoader(Dataset): | |||||
"""LSMDC dataset loader.""" | |||||
def __init__( | |||||
self, | |||||
subset, | |||||
data_path, | |||||
features_path, | |||||
tokenizer, | |||||
max_words=30, | |||||
feature_framerate=1.0, | |||||
max_frames=100, | |||||
image_resolution=224, | |||||
frame_order=0, | |||||
slice_framepos=0, | |||||
): | |||||
self.data_path = data_path | |||||
self.features_path = features_path | |||||
self.feature_framerate = feature_framerate | |||||
self.max_words = max_words | |||||
self.max_frames = max_frames | |||||
self.tokenizer = tokenizer | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
self.frame_order = frame_order | |||||
assert self.frame_order in [0, 1, 2] | |||||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||||
self.slice_framepos = slice_framepos | |||||
assert self.slice_framepos in [0, 1, 2] | |||||
self.subset = subset | |||||
assert self.subset in ["train", "val", "test"] | |||||
video_json_path_dict = {} | |||||
video_json_path_dict["train"] = os.path.join(self.data_path, "LSMDC16_annos_training.csv") | |||||
video_json_path_dict["val"] = os.path.join(self.data_path, "LSMDC16_annos_val.csv") | |||||
video_json_path_dict["test"] = os.path.join(self.data_path, "LSMDC16_challenge_1000_publictect.csv") | |||||
# <CLIP_ID>\t<START_ALIGNED>\t<END_ALIGNED>\t<START_EXTRACTED>\t<END_EXTRACTED>\t<SENTENCE> | |||||
# <CLIP_ID> is not a unique identifier, i.e. the same <CLIP_ID> can be associated with multiple sentences. | |||||
# However, LSMDC16_challenge_1000_publictect.csv has no repeat instances | |||||
video_id_list = [] | |||||
caption_dict = {} | |||||
with open(video_json_path_dict[self.subset], 'r') as fp: | |||||
for line in fp: | |||||
line = line.strip() | |||||
line_split = line.split("\t") | |||||
assert len(line_split) == 6 | |||||
clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split | |||||
caption_dict[len(caption_dict)] = (clip_id, sentence) | |||||
if clip_id not in video_id_list: video_id_list.append(clip_id) | |||||
video_dict = {} | |||||
for root, dub_dir, video_files in os.walk(self.features_path): | |||||
for video_file in video_files: | |||||
video_id_ = ".".join(video_file.split(".")[:-1]) | |||||
if video_id_ not in video_id_list: | |||||
continue | |||||
file_path_ = os.path.join(root, video_file) | |||||
video_dict[video_id_] = file_path_ | |||||
self.video_dict = video_dict | |||||
# Get all captions | |||||
self.iter2video_pairs_dict = {} | |||||
for clip_id, sentence in caption_dict.values(): | |||||
if clip_id not in self.video_dict: | |||||
continue | |||||
self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (clip_id, sentence) | |||||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||||
def __len__(self): | |||||
return len(self.iter2video_pairs_dict) | |||||
def _get_video_id_from_pseduo(self, pseudo_video_id): | |||||
video_id = pseudo_video_id[2:] | |||||
return video_id | |||||
def _get_video_id_single(self, path): | |||||
pseudo_video_id_list = [] | |||||
video_id_list = [] | |||||
print('Loading json: {}'.format(path)) | |||||
with open(path, 'r') as f: | |||||
json_data = json.load(f) | |||||
for pseudo_video_id in json_data: | |||||
if pseudo_video_id in pseudo_video_id_list: | |||||
print("reduplicate.") | |||||
else: | |||||
video_id = self._get_video_id_from_pseduo(pseudo_video_id) | |||||
pseudo_video_id_list.append(pseudo_video_id) | |||||
video_id_list.append(video_id) | |||||
return pseudo_video_id_list, video_id_list | |||||
def _get_captions_single(self, path): | |||||
pseudo_caption_dict = {} | |||||
with open(path, 'r') as f: | |||||
json_data = json.load(f) | |||||
for pseudo_video_id, v_ in json_data.items(): | |||||
pseudo_caption_dict[pseudo_video_id] = {} | |||||
timestamps = v_["timestamps"] | |||||
pseudo_caption_dict[pseudo_video_id]["start"] = \ | |||||
np.array([int(math.floor(float(itm[0]))) for itm in timestamps], dtype=object) | |||||
pseudo_caption_dict[pseudo_video_id]["end"] = \ | |||||
np.array([int(math.ceil(float(itm[1]))) for itm in timestamps], dtype=object) | |||||
pseudo_caption_dict[pseudo_video_id]["text"] = np.array(v_["sentences"], dtype=object) | |||||
return pseudo_caption_dict | |||||
def _get_text(self, video_id, caption): | |||||
k = 1 | |||||
choice_video_ids = [video_id] | |||||
pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||||
words = self.tokenizer.tokenize(caption) | |||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||||
total_length_with_CLS = self.max_words - 1 | |||||
if len(words) > total_length_with_CLS: | |||||
words = words[:total_length_with_CLS] | |||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||||
input_mask = [1] * len(input_ids) | |||||
segment_ids = [0] * len(input_ids) | |||||
while len(input_ids) < self.max_words: | |||||
input_ids.append(0) | |||||
input_mask.append(0) | |||||
segment_ids.append(0) | |||||
assert len(input_ids) == self.max_words | |||||
assert len(input_mask) == self.max_words | |||||
assert len(segment_ids) == self.max_words | |||||
pairs_text[i] = np.array(input_ids) | |||||
pairs_mask[i] = np.array(input_mask) | |||||
pairs_segment[i] = np.array(segment_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||||
def _get_rawvideo(self, choice_video_ids): | |||||
video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
max_video_length = [0] * len(choice_video_ids) | |||||
# Pair x L x T x 3 x H x W | |||||
video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
try: | |||||
for i, video_id in enumerate(choice_video_ids): | |||||
video_path = self.video_dict[video_id] | |||||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||||
raw_video_data = raw_video_data['video'] | |||||
if len(raw_video_data.shape) > 3: | |||||
raw_video_data_clip = raw_video_data | |||||
# L x T x 3 x H x W | |||||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||||
if self.max_frames < raw_video_slice.shape[0]: | |||||
if self.slice_framepos == 0: | |||||
video_slice = raw_video_slice[:self.max_frames, ...] | |||||
elif self.slice_framepos == 1: | |||||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||||
else: | |||||
sample_indx = np.linspace(0, raw_video_slice.shape[0]-1, num=self.max_frames, dtype=int) | |||||
video_slice = raw_video_slice[sample_indx, ...] | |||||
else: | |||||
video_slice = raw_video_slice | |||||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||||
slice_len = video_slice.shape[0] | |||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||||
if slice_len < 1: | |||||
pass | |||||
else: | |||||
video[i][:slice_len, ...] = video_slice | |||||
else: | |||||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||||
except Exception as excep: | |||||
print("Video ids: {}".format(choice_video_ids)) | |||||
raise excep | |||||
for i, v_length in enumerate(max_video_length): | |||||
video_mask[i][:v_length] = [1] * v_length | |||||
return video, video_mask | |||||
def __getitem__(self, feature_idx): | |||||
clip_id, sentence = self.iter2video_pairs_dict[feature_idx] | |||||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(clip_id, sentence) | |||||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import unicode_literals | |||||
from __future__ import print_function | |||||
import os | |||||
from torch.utils.data import Dataset | |||||
import numpy as np | |||||
import pandas as pd | |||||
from collections import defaultdict | |||||
import json | |||||
import random | |||||
from dataloaders.rawvideo_util import RawVideoExtractor | |||||
class MSRVTT_DataLoader(Dataset): | |||||
"""MSRVTT dataset loader.""" | |||||
def __init__( | |||||
self, | |||||
csv_path, | |||||
features_path, | |||||
tokenizer, | |||||
max_words=30, | |||||
feature_framerate=1.0, | |||||
max_frames=100, | |||||
image_resolution=224, | |||||
frame_order=0, | |||||
slice_framepos=0, | |||||
): | |||||
self.data = pd.read_csv(csv_path) | |||||
self.features_path = features_path | |||||
self.feature_framerate = feature_framerate | |||||
self.max_words = max_words | |||||
self.max_frames = max_frames | |||||
self.tokenizer = tokenizer | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
self.frame_order = frame_order | |||||
assert self.frame_order in [0, 1, 2] | |||||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||||
self.slice_framepos = slice_framepos | |||||
assert self.slice_framepos in [0, 1, 2] | |||||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||||
def __len__(self): | |||||
return len(self.data) | |||||
def _get_text(self, video_id, sentence): | |||||
choice_video_ids = [video_id] | |||||
n_caption = len(choice_video_ids) | |||||
k = n_caption | |||||
pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||||
words = self.tokenizer.tokenize(sentence) | |||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||||
total_length_with_CLS = self.max_words - 1 | |||||
if len(words) > total_length_with_CLS: | |||||
words = words[:total_length_with_CLS] | |||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||||
input_mask = [1] * len(input_ids) | |||||
segment_ids = [0] * len(input_ids) | |||||
while len(input_ids) < self.max_words: | |||||
input_ids.append(0) | |||||
input_mask.append(0) | |||||
segment_ids.append(0) | |||||
assert len(input_ids) == self.max_words | |||||
assert len(input_mask) == self.max_words | |||||
assert len(segment_ids) == self.max_words | |||||
pairs_text[i] = np.array(input_ids) | |||||
pairs_mask[i] = np.array(input_mask) | |||||
pairs_segment[i] = np.array(segment_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||||
def _get_rawvideo(self, choice_video_ids): | |||||
video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
max_video_length = [0] * len(choice_video_ids) | |||||
# Pair x L x T x 3 x H x W | |||||
video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
for i, video_id in enumerate(choice_video_ids): | |||||
# Kept for YouCookII-style data: fall back to .webm when the .mp4 file does not exist
video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) | |||||
if os.path.exists(video_path) is False: | |||||
video_path = video_path.replace(".mp4", ".webm") | |||||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||||
raw_video_data = raw_video_data['video'] | |||||
if len(raw_video_data.shape) > 3: | |||||
raw_video_data_clip = raw_video_data | |||||
# L x T x 3 x H x W | |||||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||||
if self.max_frames < raw_video_slice.shape[0]: | |||||
if self.slice_framepos == 0: | |||||
video_slice = raw_video_slice[:self.max_frames, ...] | |||||
elif self.slice_framepos == 1: | |||||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||||
else: | |||||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||||
video_slice = raw_video_slice[sample_indx, ...] | |||||
else: | |||||
video_slice = raw_video_slice | |||||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||||
slice_len = video_slice.shape[0] | |||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||||
if slice_len < 1: | |||||
pass | |||||
else: | |||||
video[i][:slice_len, ...] = video_slice | |||||
else: | |||||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||||
for i, v_length in enumerate(max_video_length): | |||||
video_mask[i][:v_length] = [1] * v_length | |||||
return video, video_mask | |||||
def __getitem__(self, idx): | |||||
video_id = self.data['video_id'].values[idx] | |||||
sentence = self.data['sentence'].values[idx] | |||||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, sentence) | |||||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask | |||||
class MSRVTT_TrainDataLoader(Dataset): | |||||
"""MSRVTT train dataset loader.""" | |||||
def __init__( | |||||
self, | |||||
csv_path, | |||||
json_path, | |||||
features_path, | |||||
tokenizer, | |||||
max_words=30, | |||||
feature_framerate=1.0, | |||||
max_frames=100, | |||||
unfold_sentences=False, | |||||
image_resolution=224, | |||||
frame_order=0, | |||||
slice_framepos=0, | |||||
): | |||||
self.csv = pd.read_csv(csv_path) | |||||
self.data = json.load(open(json_path, 'r')) | |||||
self.features_path = features_path | |||||
self.feature_framerate = feature_framerate | |||||
self.max_words = max_words | |||||
self.max_frames = max_frames | |||||
self.tokenizer = tokenizer | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
self.frame_order = frame_order | |||||
assert self.frame_order in [0, 1, 2] | |||||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||||
self.slice_framepos = slice_framepos | |||||
assert self.slice_framepos in [0, 1, 2] | |||||
self.unfold_sentences = unfold_sentences | |||||
self.sample_len = 0 | |||||
if self.unfold_sentences: | |||||
train_video_ids = list(self.csv['video_id'].values) | |||||
self.sentences_dict = {} | |||||
for itm in self.data['sentences']: | |||||
if itm['video_id'] in train_video_ids: | |||||
self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption']) | |||||
self.sample_len = len(self.sentences_dict) | |||||
else: | |||||
num_sentences = 0 | |||||
self.sentences = defaultdict(list) | |||||
s_video_id_set = set() | |||||
for itm in self.data['sentences']: | |||||
self.sentences[itm['video_id']].append(itm['caption']) | |||||
num_sentences += 1 | |||||
s_video_id_set.add(itm['video_id']) | |||||
# Used to find clips that belong to the same source video
self.parent_ids = {} | |||||
self.children_video_ids = defaultdict(list) | |||||
for itm in self.data['videos']: | |||||
vid = itm["video_id"] | |||||
url_posfix = itm["url"].split("?v=")[-1] | |||||
self.parent_ids[vid] = url_posfix | |||||
self.children_video_ids[url_posfix].append(vid) | |||||
self.sample_len = len(self.csv) | |||||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||||
def __len__(self): | |||||
return self.sample_len | |||||
def _get_text(self, video_id, caption=None): | |||||
k = 1 | |||||
choice_video_ids = [video_id] | |||||
pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||||
if caption is not None: | |||||
words = self.tokenizer.tokenize(caption) | |||||
else: | |||||
words = self._get_single_text(video_id) | |||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||||
total_length_with_CLS = self.max_words - 1 | |||||
if len(words) > total_length_with_CLS: | |||||
words = words[:total_length_with_CLS] | |||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||||
input_mask = [1] * len(input_ids) | |||||
segment_ids = [0] * len(input_ids) | |||||
while len(input_ids) < self.max_words: | |||||
input_ids.append(0) | |||||
input_mask.append(0) | |||||
segment_ids.append(0) | |||||
assert len(input_ids) == self.max_words | |||||
assert len(input_mask) == self.max_words | |||||
assert len(segment_ids) == self.max_words | |||||
pairs_text[i] = np.array(input_ids) | |||||
pairs_mask[i] = np.array(input_mask) | |||||
pairs_segment[i] = np.array(segment_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||||
def _get_single_text(self, video_id): | |||||
rind = random.randint(0, len(self.sentences[video_id]) - 1) | |||||
caption = self.sentences[video_id][rind] | |||||
words = self.tokenizer.tokenize(caption) | |||||
return words | |||||
def _get_rawvideo(self, choice_video_ids): | |||||
video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
max_video_length = [0] * len(choice_video_ids) | |||||
# Pair x L x T x 3 x H x W | |||||
video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
for i, video_id in enumerate(choice_video_ids): | |||||
# Kept for YouCookII-style data: fall back to .webm when the .mp4 file does not exist
video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) | |||||
if os.path.exists(video_path) is False: | |||||
video_path = video_path.replace(".mp4", ".webm") | |||||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||||
raw_video_data = raw_video_data['video'] | |||||
if len(raw_video_data.shape) > 3: | |||||
raw_video_data_clip = raw_video_data | |||||
# L x T x 3 x H x W | |||||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||||
if self.max_frames < raw_video_slice.shape[0]: | |||||
if self.slice_framepos == 0: | |||||
video_slice = raw_video_slice[:self.max_frames, ...] | |||||
elif self.slice_framepos == 1: | |||||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||||
else: | |||||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||||
video_slice = raw_video_slice[sample_indx, ...] | |||||
else: | |||||
video_slice = raw_video_slice | |||||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||||
slice_len = video_slice.shape[0] | |||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||||
if slice_len < 1: | |||||
pass | |||||
else: | |||||
video[i][:slice_len, ...] = video_slice | |||||
else: | |||||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||||
for i, v_length in enumerate(max_video_length): | |||||
video_mask[i][:v_length] = [1] * v_length | |||||
return video, video_mask | |||||
def __getitem__(self, idx): | |||||
if self.unfold_sentences: | |||||
video_id, caption = self.sentences_dict[idx] | |||||
else: | |||||
video_id, caption = self.csv['video_id'].values[idx], None | |||||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) | |||||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import unicode_literals | |||||
from __future__ import print_function | |||||
import os | |||||
from torch.utils.data import Dataset | |||||
import numpy as np | |||||
import pickle | |||||
from dataloaders.rawvideo_util import RawVideoExtractor | |||||
class MSVD_DataLoader(Dataset): | |||||
"""MSVD dataset loader.""" | |||||
def __init__( | |||||
self, | |||||
subset, | |||||
data_path, | |||||
features_path, | |||||
tokenizer, | |||||
max_words=30, | |||||
feature_framerate=1.0, | |||||
max_frames=100, | |||||
image_resolution=224, | |||||
frame_order=0, | |||||
slice_framepos=0, | |||||
): | |||||
self.data_path = data_path | |||||
self.features_path = features_path | |||||
self.feature_framerate = feature_framerate | |||||
self.max_words = max_words | |||||
self.max_frames = max_frames | |||||
self.tokenizer = tokenizer | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
self.frame_order = frame_order | |||||
assert self.frame_order in [0, 1, 2] | |||||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||||
self.slice_framepos = slice_framepos | |||||
assert self.slice_framepos in [0, 1, 2] | |||||
self.subset = subset | |||||
assert self.subset in ["train", "val", "test"] | |||||
video_id_path_dict = {} | |||||
video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") | |||||
video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") | |||||
video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") | |||||
caption_file = os.path.join(self.data_path, "raw-captions.pkl") | |||||
with open(video_id_path_dict[self.subset], 'r') as fp: | |||||
video_ids = [itm.strip() for itm in fp.readlines()] | |||||
with open(caption_file, 'rb') as f: | |||||
captions = pickle.load(f) | |||||
video_dict = {} | |||||
for root, dub_dir, video_files in os.walk(self.features_path): | |||||
for video_file in video_files: | |||||
video_id_ = ".".join(video_file.split(".")[:-1]) | |||||
if video_id_ not in video_ids: | |||||
continue | |||||
file_path_ = os.path.join(root, video_file) | |||||
video_dict[video_id_] = file_path_ | |||||
self.video_dict = video_dict | |||||
self.sample_len = 0 | |||||
self.sentences_dict = {} | |||||
self.cut_off_points = [] | |||||
for video_id in video_ids: | |||||
assert video_id in captions | |||||
for cap in captions[video_id]: | |||||
cap_txt = " ".join(cap) | |||||
self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) | |||||
self.cut_off_points.append(len(self.sentences_dict)) | |||||
## The variables below are used for multi-sentence retrieval:
# self.cut_off_points: marks the sentence boundaries per video when computing the metric
# self.sentence_num: used to slice the sentence representations
# self.video_num: used to slice the video representations
self.multi_sentence_per_video = True # !!! important tag for eval | |||||
if self.subset == "val" or self.subset == "test": | |||||
self.sentence_num = len(self.sentences_dict) | |||||
self.video_num = len(video_ids) | |||||
assert len(self.cut_off_points) == self.video_num | |||||
print("For {}, sentence number: {}".format(self.subset, self.sentence_num)) | |||||
print("For {}, video number: {}".format(self.subset, self.video_num)) | |||||
print("Video number: {}".format(len(self.video_dict))) | |||||
print("Total Paire: {}".format(len(self.sentences_dict))) | |||||
self.sample_len = len(self.sentences_dict) | |||||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||||
def __len__(self): | |||||
return self.sample_len | |||||
def _get_text(self, video_id, caption): | |||||
k = 1 | |||||
choice_video_ids = [video_id] | |||||
pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||||
words = self.tokenizer.tokenize(caption) | |||||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||||
total_length_with_CLS = self.max_words - 1 | |||||
if len(words) > total_length_with_CLS: | |||||
words = words[:total_length_with_CLS] | |||||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||||
input_mask = [1] * len(input_ids) | |||||
segment_ids = [0] * len(input_ids) | |||||
while len(input_ids) < self.max_words: | |||||
input_ids.append(0) | |||||
input_mask.append(0) | |||||
segment_ids.append(0) | |||||
assert len(input_ids) == self.max_words | |||||
assert len(input_mask) == self.max_words | |||||
assert len(segment_ids) == self.max_words | |||||
pairs_text[i] = np.array(input_ids) | |||||
pairs_mask[i] = np.array(input_mask) | |||||
pairs_segment[i] = np.array(segment_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||||
def _get_rawvideo(self, choice_video_ids): | |||||
video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
max_video_length = [0] * len(choice_video_ids) | |||||
# Pair x L x T x 3 x H x W | |||||
video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
for i, video_id in enumerate(choice_video_ids): | |||||
video_path = self.video_dict[video_id] | |||||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||||
raw_video_data = raw_video_data['video'] | |||||
if len(raw_video_data.shape) > 3: | |||||
raw_video_data_clip = raw_video_data | |||||
# L x T x 3 x H x W | |||||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||||
if self.max_frames < raw_video_slice.shape[0]: | |||||
if self.slice_framepos == 0: | |||||
video_slice = raw_video_slice[:self.max_frames, ...] | |||||
elif self.slice_framepos == 1: | |||||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||||
else: | |||||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||||
video_slice = raw_video_slice[sample_indx, ...] | |||||
else: | |||||
video_slice = raw_video_slice | |||||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||||
slice_len = video_slice.shape[0] | |||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||||
if slice_len < 1: | |||||
pass | |||||
else: | |||||
video[i][:slice_len, ...] = video_slice | |||||
else: | |||||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||||
for i, v_length in enumerate(max_video_length): | |||||
video_mask[i][:v_length] = [1] * v_length | |||||
return video, video_mask | |||||
def __getitem__(self, idx): | |||||
video_id, caption = self.sentences_dict[idx] | |||||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) | |||||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
import torch as th | |||||
import numpy as np | |||||
from PIL import Image | |||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |||||
import cv2 | |||||
class RawVideoExtractorCV2(): | |||||
def __init__(self, centercrop=False, size=224, framerate=-1, ): | |||||
self.centercrop = centercrop | |||||
self.size = size | |||||
self.framerate = framerate | |||||
self.transform = self._transform(self.size) | |||||
def _transform(self, n_px): | |||||
return Compose([ | |||||
Resize(n_px, interpolation=Image.BICUBIC), | |||||
CenterCrop(n_px), | |||||
lambda image: image.convert("RGB"), | |||||
ToTensor(), | |||||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||||
(0.26862954, 0.26130258, 0.27577711)), | |||||
]) | |||||
def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): | |||||
if start_time is not None or end_time is not None: | |||||
assert isinstance(start_time, int) and isinstance(end_time, int) \ | |||||
and start_time > -1 and end_time > start_time | |||||
assert sample_fp > -1 | |||||
# Sample sample_fp frames per second of video.
cap = cv2.VideoCapture(video_file) | |||||
frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||||
fps = int(cap.get(cv2.CAP_PROP_FPS)) | |||||
total_duration = (frameCount + fps - 1) // fps | |||||
start_sec, end_sec = 0, total_duration | |||||
if start_time is not None: | |||||
start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration | |||||
cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) | |||||
interval = 1 | |||||
if sample_fp > 0: | |||||
interval = fps // sample_fp | |||||
else: | |||||
sample_fp = fps | |||||
if interval == 0: | |||||
interval = 1 | |||||
inds = [ind for ind in np.arange(0, fps, interval)] | |||||
assert len(inds) >= sample_fp | |||||
inds = inds[:sample_fp] | |||||
ret = True | |||||
images, included = [], [] | |||||
for sec in np.arange(start_sec, end_sec + 1): | |||||
if not ret: | |||||
break | |||||
sec_base = int(sec * fps) | |||||
for ind in inds: | |||||
cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) | |||||
ret, frame = cap.read() | |||||
if not ret: | |||||
break | |||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||||
images.append(preprocess( | |||||
Image.fromarray(frame_rgb).convert("RGB"))) | |||||
cap.release() | |||||
if len(images) > 0: | |||||
video_data = th.tensor(np.stack(images)) | |||||
else: | |||||
video_data = th.zeros(1) | |||||
return {'video': video_data} | |||||
def get_video_data(self, video_path, start_time=None, end_time=None): | |||||
image_input = self.video_to_tensor( | |||||
video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time) | |||||
return image_input | |||||
def process_raw_data(self, raw_video_data): | |||||
tensor_size = raw_video_data.size() | |||||
tensor = raw_video_data.view(-1, 1, | |||||
tensor_size[-3], tensor_size[-2], tensor_size[-1]) | |||||
return tensor | |||||
def process_frame_order(self, raw_video_data, frame_order=0): | |||||
# 0: ordinary order; 1: reverse order; 2: random order. | |||||
if frame_order == 0: | |||||
pass | |||||
elif frame_order == 1: | |||||
reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) | |||||
raw_video_data = raw_video_data[reverse_order, ...] | |||||
elif frame_order == 2: | |||||
random_order = np.arange(raw_video_data.size(0)) | |||||
np.random.shuffle(random_order) | |||||
raw_video_data = raw_video_data[random_order, ...] | |||||
return raw_video_data | |||||
# A raw video frame extractor based on OpenCV (cv2)
RawVideoExtractor = RawVideoExtractorCV2 |
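# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): "example.mp4" is a placeholder
# path, not a file shipped with the project. It extracts frames at 1 fps,
# reshapes them to L x T x 3 x H x W, and reverses the frame order, mirroring
# how the dataset loaders consume this extractor.
if __name__ == "__main__":
    extractor = RawVideoExtractor(framerate=1, size=224)
    data = extractor.get_video_data("example.mp4")
    if len(data["video"].shape) > 3:
        frames = extractor.process_raw_data(data["video"])             # L x 1 x 3 x 224 x 224
        frames = extractor.process_frame_order(frames, frame_order=1)  # reversed order
        print("extracted frame tensor:", frames.shape)
    else:
        print("no frames could be read from the placeholder path")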
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"!pip install transformers\n", | |||||
"!pip install pytube" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"!git clone https://github.com/abreza/clip-video-embedder\n", | |||||
"%cd clip-video-embedder/" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from utils.video_loader import download_video_from_youtube\n", | |||||
"\n", | |||||
"video_id = \"4uw4co69JUQ\"\n", | |||||
"video_path = download_video_from_youtube('4uw4co69JUQ', './videos')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from transformers import CLIPModel, CLIPProcessor\n", | |||||
"\n", | |||||
"model = CLIPModel.from_pretrained(\"clip-vit-large-patch14\")\n", | |||||
"processor = CLIPProcessor.from_pretrained(\"clip-vit-large-patch14\")" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import numpy as np\n", | |||||
"\n", | |||||
"from modules.frame_sampler.uniform_sampler import uniform_sample_frames\n", | |||||
"from modules.frame_sampler.aggregate_embedding import aggregate_embeddings\n", | |||||
"\n", | |||||
"num_frames = 10\n", | |||||
"video_frames = uniform_sample_frames(video_path, num_frames)\n", | |||||
"\n", | |||||
"frame_embeddings = []\n", | |||||
"for frame in video_frames:\n", | |||||
" inputs = processor(images=frame, return_tensors=\"pt\")\n", | |||||
" outputs = model.get_image_features(**inputs)\n", | |||||
" frame_embedding = outputs.detach().numpy()\n", | |||||
" frame_embeddings.append(frame_embedding)\n", | |||||
"\n", | |||||
"video_embedding = aggregate_embeddings(frame_embeddings, strategy=\"mean\")\n", | |||||
"\n", | |||||
"# Compare the video embedding to a given text\n", | |||||
"texts = [\"a dog playing in the park\", \"playing basketball\"]\n", | |||||
"for text in texts:\n", | |||||
" text_inputs = processor(text=text, return_tensors=\"pt\", padding=True)\n", | |||||
" text_outputs = model.get_text_features(**text_inputs)\n", | |||||
" text_embedding = text_outputs.detach().numpy()\n", | |||||
" \n", | |||||
" similarity = np.inner(video_embedding, text_embedding)\n", | |||||
" print(text, similarity)" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "venv", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.10.10" | |||||
}, | |||||
"orig_nbformat": 4 | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
from utils.dotdict import DotNotationDict | |||||
from dataloaders.data_dataloaders import DATALOADER_DICT | |||||
from transformers import CLIPTokenizer | |||||
def main(): | |||||
tokenizer = CLIPTokenizer.from_pretrained("clip-vit-large-patch14") | |||||
train_args = DotNotationDict({ | |||||
"train_csv": 'MSRVTT_train.9k.csv', | |||||
"data_path": 'MSRVTT_data.json', | |||||
"features_path": 'MSRVTT_Videos', | |||||
"max_words": 32, | |||||
"feature_framerate": 1, | |||||
"max_frames": 100, | |||||
"expand_msrvtt_sentences": True, | |||||
"train_frame_order": 0, | |||||
"slice_framepos": 2, | |||||
"batch_size": 128, | |||||
"n_gpu": 1, | |||||
"num_thread_reader": 0, | |||||
}) | |||||
train_dataloader, train_length = DATALOADER_DICT['msrvtt']["train"]( | |||||
train_args, tokenizer) | |||||
if __name__ == '__main__': | |||||
main() |
import numpy as np | |||||
from transformers import CLIPProcessor, CLIPModel | |||||
class VideoEmbedder: | |||||
def __init__(self, model_name="openai/clip-vit-large-patch14"): | |||||
self.model = CLIPModel.from_pretrained(model_name) | |||||
self.processor = CLIPProcessor.from_pretrained(model_name) | |||||
def uniform_sampling(self, frames, num_samples): | |||||
frame_count = len(frames) | |||||
indices = np.linspace(0, frame_count - 1, num_samples, dtype=int) | |||||
sampled_frames = [frames[i] for i in indices] | |||||
return sampled_frames | |||||
def get_frame_embeddings(self, frames): | |||||
inputs = self.processor(images=frames, return_tensors="pt") | |||||
outputs = self.model.get_image_features(**inputs) | |||||
frame_embeddings = outputs.detach().numpy() | |||||
return frame_embeddings | |||||
def aggregate_embeddings(self, embeddings, strategy="mean"): | |||||
if strategy == "mean": | |||||
return np.mean(embeddings, axis=0) | |||||
elif strategy == "average": | |||||
return np.average(embeddings, axis=0) | |||||
else: | |||||
raise ValueError("Invalid aggregation strategy") | |||||
def embed_video(self, frames, num_samples, aggregation_strategy="mean"): | |||||
sampled_frames = self.uniform_sampling(frames, num_samples) | |||||
frame_embeddings = self.get_frame_embeddings(sampled_frames) | |||||
aggregated_embedding = self.aggregate_embeddings( | |||||
frame_embeddings, aggregation_strategy) | |||||
return aggregated_embedding | |||||
def get_text_embedding(self, text): | |||||
text_inputs = self.processor( | |||||
text=text, return_tensors="pt", padding=True) | |||||
text_outputs = self.model.get_text_features(**text_inputs) | |||||
text_embedding = text_outputs.detach().numpy() | |||||
return text_embedding | |||||
def similarity(self, video_embedding, text_embedding): | |||||
similarity_score = np.dot(video_embedding, text_embedding.T) | |||||
return similarity_score |
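
# --- Usage sketch ---
# How VideoEmbedder is meant to be used: sample frames, embed the clip, and
# compare it with a text query. "example.mp4" and the query text are
# placeholders; any list of PIL images works as input, e.g. from the
# uniform_sample_frames helper in modules/frame_sampler.
if __name__ == "__main__":
    from modules.frame_sampler.uniform_sampler import uniform_sample_frames

    embedder = VideoEmbedder()
    frames = uniform_sample_frames("example.mp4", num_frames=16)
    video_embedding = embedder.embed_video(frames, num_samples=8)
    text_embedding = embedder.get_text_embedding("a dog playing in the park")
    print(embedder.similarity(video_embedding, text_embedding))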
import numpy as np | |||||
def aggregate_embeddings(embeddings, strategy="mean"): | |||||
if strategy == "mean": | |||||
return np.mean(embeddings, axis=0) | |||||
else: | |||||
raise ValueError(f"Unknown aggregation strategy: {strategy}") |
import timm | |||||
import torch.nn as nn | |||||
from transformers import CLIPProcessor, CLIPModel | |||||
class SaliencyNet(nn.Module): | |||||
def __init__(self): | |||||
super(SaliencyNet, self).__init__() | |||||
self.model = timm.create_model('resnet50', pretrained=True) | |||||
self.fc1 = nn.Linear(self.model.num_features, 512) | |||||
self.fc2 = nn.Linear(512, 1) | |||||
self.relu = nn.ReLU() | |||||
self.sigmoid = nn.Sigmoid() | |||||
    def forward(self, images):
        # timm's forward_features returns an unpooled (B, C, H, W) feature map
        # for resnet50, so global-average-pool it before the linear layers.
        features = self.model.forward_features(images)
        features = features.mean(dim=(-2, -1))
        x = self.fc1(features)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x
class CLIPTeacher(nn.Module): | |||||
def __init__(self): | |||||
super(CLIPTeacher, self).__init__() | |||||
self.model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14') | |||||
self.processor = CLIPProcessor.from_pretrained( | |||||
'openai/clip-vit-large-patch14') | |||||
def forward(self, images, descriptions): | |||||
inputs = self.processor( | |||||
text=descriptions, images=images, return_tensors="pt", padding=True | |||||
) | |||||
outputs = self.model(**inputs) | |||||
logits_per_image = outputs.logits_per_image | |||||
average_scores = logits_per_image.mean(dim=1) | |||||
return average_scores |
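
# --- Usage sketch ---
# The student (SaliencyNet) scores frames on its own, while the teacher
# (CLIPTeacher) scores them against a text description via CLIP logits.
# The random tensor, blank PIL images, and the caption are placeholders;
# everything stays on CPU here.
if __name__ == "__main__":
    import torch
    from PIL import Image

    student = SaliencyNet()
    frames = torch.randn(4, 3, 224, 224)    # batch of 4 RGB frames
    print(student(frames).shape)             # torch.Size([4, 1]), saliency in [0, 1]

    teacher = CLIPTeacher()
    pil_frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
    scores = teacher(pil_frames, ["a person riding a bike"])
    print(scores.shape)                      # torch.Size([4]), one score per frame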
import torch | |||||
import torch.nn as nn | |||||
from torch.utils.data import DataLoader | |||||
def train_salient_frame_sampler(teacher, student, train_dataloader: DataLoader, val_dataloader: DataLoader, epochs: int, optimizer, device): | |||||
teacher.eval() | |||||
teacher.to(device) | |||||
student.train() | |||||
student.to(device) | |||||
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for frames, descriptions in train_dataloader:
            # descriptions are raw text strings consumed by the CLIP processor
            # inside the teacher, so only the frame tensor moves to the device.
            frames = frames.to(device)
            student_scores = student(frames).squeeze()
            with torch.no_grad():
                teacher_scores = teacher(frames, descriptions)
optimizer.zero_grad() | |||||
loss = criterion(student_scores, teacher_scores) | |||||
loss.backward() | |||||
optimizer.step() | |||||
running_loss += loss.item() | |||||
train_loss = running_loss / len(train_dataloader) | |||||
print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {train_loss:.4f}") | |||||
# Calculate validation loss | |||||
with torch.no_grad(): | |||||
running_val_loss = 0.0 | |||||
            for val_frames, val_descriptions in val_dataloader:
                val_frames = val_frames.to(device)
                val_student_scores = student(val_frames).squeeze()
                val_teacher_scores = teacher(val_frames, val_descriptions)
val_loss = criterion(val_student_scores, val_teacher_scores) | |||||
running_val_loss += val_loss.item() | |||||
val_loss = running_val_loss / len(val_dataloader) | |||||
print( | |||||
f"Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss:.4f}") |
import cv2 | |||||
import numpy as np | |||||
from PIL import Image | |||||
def uniform_sample_frames(video_path, num_frames): | |||||
cap = cv2.VideoCapture(video_path) | |||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||||
frame_idxs = np.linspace(0, total_frames - 1, num=num_frames, dtype=int) | |||||
frames = [] | |||||
for idx in frame_idxs: | |||||
cap.set(cv2.CAP_PROP_POS_FRAMES, idx) | |||||
ret, frame = cap.read() | |||||
if ret: | |||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||||
frames.append(Image.fromarray(frame)) | |||||
cap.release() | |||||
return frames |
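
# --- Usage sketch ---
# "example.mp4" is a placeholder path; the helper returns PIL images that can
# be fed straight into CLIPProcessor.
if __name__ == "__main__":
    frames = uniform_sample_frames("example.mp4", num_frames=8)
    print(len(frames), frames[0].size if frames else None)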
{ | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0, | |||||
"metadata": { | |||||
"colab": { | |||||
"provenance": [] | |||||
}, | |||||
"kernelspec": { | |||||
"name": "python3", | |||||
"display_name": "Python 3" | |||||
}, | |||||
"language_info": { | |||||
"name": "python" | |||||
} | |||||
}, | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "QGkNRWwCX4Z4" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"!wget https://storage.googleapis.com/deepmind-media/Datasets/kinetics700_2020.tar.gz" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"!tar -xf kinetics700_2020.tar.gz" | |||||
], | |||||
"metadata": { | |||||
"id": "87GXSaprYGxe" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"import pandas as pd\n", | |||||
"\n", | |||||
"df = pd.read_csv('/content/kinetics700_2020/train.csv')\n", | |||||
"\n", | |||||
"unique_labels = df['label'].unique()\n", | |||||
"\n", | |||||
"for i, label in enumerate(unique_labels):\n", | |||||
" print(f\"{i}- {label}\")" | |||||
], | |||||
"metadata": { | |||||
"id": "wJ4USlBtZWi9" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"def merge_files(real_labels_file, gpt_labels_file):\n", | |||||
" with open(real_labels_file, 'r') as real_file, open(gpt_labels_file, 'r') as gpt_file:\n", | |||||
" real_labels = real_file.readlines()\n", | |||||
" gpt_labels = gpt_file.readlines()\n", | |||||
"\n", | |||||
" merged_labels = []\n", | |||||
" for real, gpt in zip(real_labels, gpt_labels):\n", | |||||
" real_label = real.split('-')[1].strip()\n", | |||||
" gpt_label = gpt.split('-')[1].strip()\n", | |||||
"\n", | |||||
" # Merge the labels and add to the list\n", | |||||
" merged_labels.append(f'{real_label}: {gpt_label}')\n", | |||||
"\n", | |||||
" return merged_labels\n", | |||||
"\n", | |||||
"real_labels_file = '/content/real-labels.txt'\n", | |||||
"gpt_labels_file = '/content/chatgpt-labels.txt'\n", | |||||
"\n", | |||||
"merged_labels = merge_files(real_labels_file, gpt_labels_file)" | |||||
], | |||||
"metadata": { | |||||
"id": "CUMXgnGkDXXB" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"!pip install transformers" | |||||
], | |||||
"metadata": { | |||||
"id": "gvWmOILaIUI1" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"from transformers import CLIPProcessor, CLIPModel\n", | |||||
"\n", | |||||
"model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n", | |||||
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")" | |||||
], | |||||
"metadata": { | |||||
"id": "gPfdGYiHIMss" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"with open('enhanced-kinetics-label-CLIP-embedding.txt', 'w') as write_file:\n", | |||||
" for label in merged_labels:\n", | |||||
" text_inputs = processor(text=label, return_tensors=\"pt\", padding=True)\n", | |||||
" text_features = model.get_text_features(**text_inputs).float()\n", | |||||
" # format output to only have 4 digits after decimal point\n", | |||||
" formatted_features = ['{:.4f}'.format(val) for val in text_features[0].tolist()]\n", | |||||
" write_file.write(label.replace(':', ' |') + ' | ' + str(formatted_features) + '\\n')" | |||||
], | |||||
"metadata": { | |||||
"id": "x84uh8PdIdXI" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"!wget https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip --no-check-certificate\n", | |||||
"!unzip UCF101TrainTestSplits-RecognitionTask.zip" | |||||
], | |||||
"metadata": { | |||||
"colab": { | |||||
"base_uri": "https://localhost:8080/" | |||||
}, | |||||
"id": "jbg7PJSXO9so", | |||||
"outputId": "f2e84ec0-37ae-49c3-a732-d35529d38eed" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [ | |||||
{ | |||||
"output_type": "stream", | |||||
"name": "stdout", | |||||
"text": [ | |||||
"--2023-08-02 13:01:00-- https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip\n", | |||||
"Resolving www.crcv.ucf.edu (www.crcv.ucf.edu)... 132.170.214.127\n", | |||||
"Connecting to www.crcv.ucf.edu (www.crcv.ucf.edu)|132.170.214.127|:443... connected.\n", | |||||
"WARNING: cannot verify www.crcv.ucf.edu's certificate, issued by ‘CN=InCommon RSA Server CA,OU=InCommon,O=Internet2,L=Ann Arbor,ST=MI,C=US’:\n", | |||||
" Unable to locally verify the issuer's authority.\n", | |||||
"HTTP request sent, awaiting response... 200 OK\n", | |||||
"Length: 113943 (111K) [application/zip]\n", | |||||
"Saving to: ‘UCF101TrainTestSplits-RecognitionTask.zip’\n", | |||||
"\n", | |||||
"\r UCF101Tra 0%[ ] 0 --.-KB/s \rUCF101TrainTestSpli 100%[===================>] 111.27K --.-KB/s in 0.05s \n", | |||||
"\n", | |||||
"2023-08-02 13:01:00 (2.03 MB/s) - ‘UCF101TrainTestSplits-RecognitionTask.zip’ saved [113943/113943]\n", | |||||
"\n", | |||||
"Archive: UCF101TrainTestSplits-RecognitionTask.zip\n", | |||||
" creating: ucfTrainTestlist/\n", | |||||
" inflating: ucfTrainTestlist/classInd.txt \n", | |||||
" inflating: ucfTrainTestlist/testlist01.txt \n", | |||||
" inflating: ucfTrainTestlist/testlist02.txt \n", | |||||
" inflating: ucfTrainTestlist/testlist03.txt \n", | |||||
" inflating: ucfTrainTestlist/trainlist01.txt \n", | |||||
" inflating: ucfTrainTestlist/trainlist02.txt \n", | |||||
" inflating: ucfTrainTestlist/trainlist03.txt \n" | |||||
] | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"source": [ | |||||
"with open('enhanced-ucf-101-label-CLIP-embedding.txt', 'w') as write_file:\n", | |||||
" with open('ucf101-labels.txt', 'r') as ucf101_labels_file:\n", | |||||
" ucf101_labels = ucf101_labels_file.readlines()\n", | |||||
" for label in ucf101_labels:\n", | |||||
" text_inputs = processor(text=label, return_tensors=\"pt\", padding=True)\n", | |||||
" text_features = model.get_text_features(**text_inputs).float()\n", | |||||
" # format output to only have 4 digits after decimal point\n", | |||||
" formatted_features = ['{:.4f}'.format(val) for val in text_features[0].tolist()]\n", | |||||
" write_file.write(label.replace(':', ' |') + ' | ' + str(formatted_features) + '\\n')" | |||||
], | |||||
"metadata": { | |||||
"id": "jAQgg2nMgIE6" | |||||
}, | |||||
"execution_count": null, | |||||
"outputs": [] | |||||
} | |||||
] | |||||
} |
pytube
timm
transformers
torchvision
opencv-python
numpy
Pillow
torch
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"background_save": true | |||||
}, | |||||
"id": "ucqaqtLhPY6T", | |||||
"outputId": "ceea603c-8b61-4c19-cdc3-1de8062ecfc4" | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Mounted at /content/drive\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from google.colab import drive\n", | |||||
"drive.mount('/content/drive')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"id": "BBOM9m-EQ7so" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# !pip install opendatalab\n", | |||||
"# !odl login\n", | |||||
"# %cd /content/drive/MyDrive/Kinetics\n", | |||||
"# !odl get Kinetics_700-2020" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"colab": { | |||||
"background_save": true | |||||
}, | |||||
"id": "MTa17FVePkqM" | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"%cd /content/\n", | |||||
"\n", | |||||
"!git clone https://github.com/cvdfoundation/kinetics-dataset\n", | |||||
"\n", | |||||
"%cd /content/drive/MyDrive/Kinetics\n", | |||||
"!bash /content/kinetics-dataset/k700_2020_downloader.sh\n", | |||||
"!bash /content/kinetics-dataset/k700_2020_extractor.sh" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"colab": { | |||||
"provenance": [] | |||||
}, | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"name": "python" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0 | |||||
} |
from datasets.msrvtt.dataset import MSRVTTDataset | |||||
from modules.clip_video_embedder import VideoEmbedder | |||||
def main(): | |||||
dataset = MSRVTTDataset() | |||||
video_embedder = VideoEmbedder() | |||||
for idx in range(len(dataset)): | |||||
frames, captions = dataset[idx] | |||||
video_embedding = video_embedder.embed_video(frames, num_samples=10) | |||||
for caption in captions: | |||||
text_embedding = video_embedder.get_text_embedding(caption) | |||||
similarity_score = video_embedder.similarity( | |||||
video_embedding, text_embedding) | |||||
print( | |||||
f"Similarity between video and caption '{caption}': {similarity_score}") | |||||
if __name__ == "__main__": | |||||
main() |
class DotNotationDict(dict): | |||||
"""Enables dot notation access to dictionary attributes.""" | |||||
    def __getattr__(self, key):
        return self.get(key)
def __setattr__(self, key, value): | |||||
self.__setitem__(key, value) | |||||
def __delattr__(self, key): | |||||
self.__delitem__(key) |
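
# --- Usage sketch ---
# DotNotationDict lets config dictionaries be read with attribute access,
# mirroring how train_args is consumed elsewhere in the repo; missing keys
# simply return None.
if __name__ == "__main__":
    args = DotNotationDict({"max_frames": 100, "batch_size": 128})
    args.feature_framerate = 1
    print(args.max_frames, args.feature_framerate)  # 100 1
    print(args.slice_framepos)                      # None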
import torch | |||||
def uniform_sample_frames(frames, num_samples):
    # Evenly spaced indices across the clip; torch.linspace also handles clips
    # shorter than num_samples (indices then repeat) and mirrors the NumPy
    # sampler used elsewhere in the repo.
    total_frames = frames.shape[0]
    indices = torch.linspace(0, total_frames - 1, num_samples).long()
    return frames[indices]
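
# --- Usage sketch ---
# Operates on an already-decoded frame tensor of shape (T, C, H, W); the
# random tensor below is a placeholder.
if __name__ == "__main__":
    clip = torch.randn(30, 3, 224, 224)
    sampled = uniform_sample_frames(clip, num_samples=8)
    print(sampled.shape)  # torch.Size([8, 3, 224, 224])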
import os | |||||
import pytube | |||||
def download_video_from_youtube(video_id_or_url, destination_path=None, video_name=None): | |||||
video_url = video_id_or_url if "://" in video_id_or_url else f'https://youtu.be/{video_id_or_url}' | |||||
video_id = pytube.extract.video_id(video_url) | |||||
destination_path = destination_path if destination_path else '/content/' | |||||
video_name = video_name if video_name else f'{video_id}.mp4' | |||||
video_path = os.path.join(destination_path, video_name) | |||||
if os.path.isfile(video_path): | |||||
return video_path | |||||
print(f'Downloading YouTube video {video_id}.') | |||||
pytube.YouTube(video_url).streams.get_highest_resolution().download( | |||||
destination_path, filename=video_name) | |||||
print(f'Download complete.') | |||||
return video_path |
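
# --- Usage sketch ---
# Either a bare video id or a full URL works; the id below is the same sample
# clip used in the demo notebook.
if __name__ == "__main__":
    path = download_video_from_youtube("4uw4co69JUQ", destination_path="./videos")
    print(path)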