@@ -0,0 +1,3 @@
venv/ | |||
**/videos/ | |||
**/__pycache__/ | |||
@@ -0,0 +1,7 @@
{ | |||
"[python]": { | |||
"editor.defaultFormatter": "ms-python.autopep8" | |||
}, | |||
"python.formatting.provider": "none", | |||
"cSpell.words": ["embedder", "pytube"] | |||
} |
@@ -0,0 +1,62 @@
# Video Action Recognition Using Transfer Learning and Attention Mechanisms | |||
This project focuses on video action recognition with deep learning, combining transfer learning from pretrained vision-language models (CLIP) with attention mechanisms.
## Getting Started | |||
### 1. Dataset Preparation | |||
1.1. Download the Kinetics dataset: | |||
- Use `save_kinetics_dataset.ipynb` to download the dataset. | |||
- Alternatively, you can use `download_k400.ipynb`. | |||
1.2. Save the dataset: | |||
- Store the downloaded dataset in your Google Drive for easy access. | |||
### 2. Label Preprocessing | |||
2.1. Update Kinetics labels: | |||
- Run `preprocess_kinetics_labels.ipynb`. | |||
- This notebook uses GPT-4 to generate a detailed description for each action label (a sketch of this step follows).
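Conceptually, the enrichment step asks GPT-4 for a short visual description of each label. The sketch below is illustrative (the `describe_action` helper and prompt wording are not from the notebook) and assumes the `openai` Python client with an `OPENAI_API_KEY` in the environment:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def describe_action(label: str) -> str:
    # Ask GPT-4 for a short, visually grounded description of an action label.
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{
            "role": "user",
            "content": f"Describe the visual appearance of the action '{label}' in one sentence.",
        }],
    )
    return response.choices[0].message.content
```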
### 3. Model Training | |||
3.1. Post-pretraining of VideoMAE: | |||
- Execute `postpretrain_VideoMAE_to_CLIP_Space.ipynb`. | |||
- This notebook trains a transformer layer that maps VideoMAE embeddings into the CLIP embedding space (sketched below).
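A minimal sketch of such a mapping module, assuming 384-dim VideoMAE (ViT-S) frame embeddings and 512-dim CLIP (ViT-B/32) text embeddings; the actual architecture and training loss are defined in the notebook:

```python
import torch.nn as nn
import torch.nn.functional as F

class VideoMAEToCLIP(nn.Module):
    """Maps per-frame VideoMAE embeddings into the CLIP embedding space."""

    def __init__(self, videomae_dim=384, clip_dim=512, num_heads=8):
        super().__init__()
        self.encoder = nn.TransformerEncoderLayer(
            d_model=videomae_dim, nhead=num_heads, batch_first=True)
        self.proj = nn.Linear(videomae_dim, clip_dim)

    def forward(self, frame_embeddings):        # (B, T, videomae_dim)
        x = self.encoder(frame_embeddings)      # attend across frames
        return self.proj(x.mean(dim=1))         # (B, clip_dim)

def alignment_loss(video_emb, text_emb):
    # Pull each video embedding toward the CLIP text embedding of its label.
    return 1 - F.cosine_similarity(video_emb, text_emb, dim=-1).mean()
```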
### 4. Testing | |||
4.1. Prepare the test dataset: | |||
- Download the UCF101 dataset. | |||
- Update the UCF101 labels using GPT-4, similar to the Kinetics label preprocessing step. | |||
4.2. Run the test: | |||
- Use `test.ipynb` to evaluate the model's performance. | |||
## Prerequisites | |||
- Python 3.x | |||
- Jupyter Notebook | |||
- PyTorch | |||
- Transformers library | |||
- CLIP model | |||
- VideoMAE model | |||
- Access to GPT-4 API for label preprocessing | |||
- Google Drive (for storing datasets) | |||
## Usage | |||
1. Follow the steps in the "Getting Started" section to prepare your data and train the model. | |||
2. Ensure all datasets are properly saved in your Google Drive. | |||
3. Run the notebooks in the order specified above. | |||
4. For testing, make sure you have the UCF101 dataset prepared and labels updated before running `test.ipynb`. | |||
The model processes multiple frames from a video clip and embeds them into the CLIP space, where actions can be recognized by similarity to label embeddings.
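Recognition then reduces to a nearest-neighbour lookup in CLIP space. Roughly (the function and variable names here are illustrative):

```python
import torch

@torch.no_grad()
def classify(video_emb, class_text_embs, class_names, k=5):
    # Rank action classes by cosine similarity to the video's CLIP-space embedding.
    video_emb = video_emb / video_emb.norm(dim=-1, keepdim=True)
    class_text_embs = class_text_embs / class_text_embs.norm(dim=-1, keepdim=True)
    similarity = video_emb @ class_text_embs.T
    return [class_names[i] for i in similarity.topk(k).indices.tolist()]
```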
## Future Work | |||
- Implement an adaptive frame selection unit | |||
- Extend to more diverse datasets | |||
- Integrate multimodal inputs (e.g., audio) | |||
- Fine-tune hyperparameters |
@@ -0,0 +1,751 @@
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "ZGGB7jSjF1RD", | |||
"outputId": "95aaefdd-a00e-46d9-86f9-556ab6ff53c2" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Downloading...\n", | |||
"From: https://drive.google.com/uc?id=18AzsP1DOBECBL3vjzUhokro40mtu4Ggn\n", | |||
"To: /content/vit_s_k710_dl_from_giant.pth\n", | |||
"100% 44.3M/44.3M [00:00<00:00, 76.7MB/s]\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"!gdown 18AzsP1DOBECBL3vjzUhokro40mtu4Ggn # vit_s_k710_dl_from_giant.pth" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "G9r7PONyCbMO", | |||
"outputId": "6758cef5-cac8-45a9-f424-3525dd38a25b" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Requirement already satisfied: decord in /usr/local/lib/python3.10/dist-packages (0.6.0)\n", | |||
"Requirement already satisfied: deepspeed in /usr/local/lib/python3.10/dist-packages (0.12.2)\n", | |||
"Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (0.7.0)\n", | |||
"Requirement already satisfied: timm==0.4.12 in /usr/local/lib/python3.10/dist-packages (0.4.12)\n", | |||
"Requirement already satisfied: tensorboardX in /usr/local/lib/python3.10/dist-packages (2.6.2.2)\n", | |||
"Requirement already satisfied: mpi4py in /usr/local/lib/python3.10/dist-packages (3.1.5)\n", | |||
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.0)\n", | |||
"Requirement already satisfied: torch>=1.4 in /usr/local/lib/python3.10/dist-packages (from timm==0.4.12) (2.1.0+cu118)\n", | |||
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from timm==0.4.12) (0.16.0+cu118)\n", | |||
"Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from decord) (1.23.5)\n", | |||
"Requirement already satisfied: hjson in /usr/local/lib/python3.10/dist-packages (from deepspeed) (3.1.0)\n", | |||
"Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from deepspeed) (1.11.1.1)\n", | |||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from deepspeed) (23.2)\n", | |||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from deepspeed) (5.9.5)\n", | |||
"Requirement already satisfied: py-cpuinfo in /usr/local/lib/python3.10/dist-packages (from deepspeed) (9.0.0)\n", | |||
"Requirement already satisfied: pydantic in /usr/local/lib/python3.10/dist-packages (from deepspeed) (1.10.13)\n", | |||
"Requirement already satisfied: pynvml in /usr/local/lib/python3.10/dist-packages (from deepspeed) (11.5.0)\n", | |||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from deepspeed) (4.66.1)\n", | |||
"Requirement already satisfied: protobuf>=3.20 in /usr/local/lib/python3.10/dist-packages (from tensorboardX) (3.20.3)\n", | |||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.4)\n", | |||
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.17.3)\n", | |||
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", | |||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", | |||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", | |||
"Requirement already satisfied: tokenizers<0.15,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.14.1)\n", | |||
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.0)\n", | |||
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n", | |||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", | |||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (1.12)\n", | |||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (3.2)\n", | |||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (3.1.2)\n", | |||
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4->timm==0.4.12) (2.1.0)\n", | |||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.1)\n", | |||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", | |||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", | |||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n", | |||
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->timm==0.4.12) (9.4.0)\n", | |||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4->timm==0.4.12) (2.1.3)\n", | |||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.4->timm==0.4.12) (1.3.0)\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"!pip install decord deepspeed einops timm==0.4.12 tensorboardX mpi4py transformers" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "6qwB9RtaEVl5", | |||
"outputId": "43f46c83-0f1d-404e-a1af-b08a5af05f4a" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"/content\n", | |||
"fatal: destination path 'VideoMAEv2' already exists and is not an empty directory.\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"%cd /content\n", | |||
"!git clone https://github.com/OpenGVLab/VideoMAEv2" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "Bwj7eNpwGLDI" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"class_to_label = {}\n", | |||
"\n", | |||
"with open(\"/content/ucfTrainTestlist/classInd.txt\", 'r') as f:\n", | |||
" for line in f:\n", | |||
" label, class_name = line.split()\n", | |||
" class_to_label[class_name] = label\n", | |||
"\n", | |||
"modes = ['train', 'test']\n", | |||
"\n", | |||
"for mode in modes:\n", | |||
" files = [\n", | |||
" f\"/content/ucfTrainTestlist/{mode}list01.txt\",\n", | |||
" f\"/content/ucfTrainTestlist/{mode}list02.txt\",\n", | |||
" f\"/content/ucfTrainTestlist/{mode}list03.txt\",\n", | |||
" ]\n", | |||
"\n", | |||
" output_file_path = f\"/content/UCF101/{'val' if mode == 'test' else mode}.csv\"\n", | |||
"\n", | |||
" with open(output_file_path, 'w') as outfile:\n", | |||
" for file_path in files:\n", | |||
" with open(file_path, 'r') as infile:\n", | |||
" for line in infile:\n", | |||
" line_text = line.strip()\n", | |||
" if mode == 'train':\n", | |||
" video_path, label = line_text.split(' ')\n", | |||
" else:\n", | |||
" video_path = line_text\n", | |||
" class_name = os.path.dirname(line_text)\n", | |||
" label = class_to_label[class_name]\n", | |||
" outfile.write(f\"/content/UCF101/{video_path} {label}\\n\")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "7Tfx84ZrEfFB", | |||
"outputId": "b25afa77-92f1-4b69-c312-e21012794ba3" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"/content/VideoMAEv2\n", | |||
"/content\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"%cd /content/VideoMAEv2\n", | |||
"from dataset.datasets import VideoClsDataset\n", | |||
"\n", | |||
"from models.modeling_finetune import VisionTransformer, _cfg\n", | |||
"from utils import (\n", | |||
" load_state_dict,\n", | |||
")\n", | |||
"from optim_factory import (\n", | |||
" LayerDecayValueAssigner,\n", | |||
" get_parameter_groups,\n", | |||
")\n", | |||
"%cd /content/" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "WG9MzWkQI5tP", | |||
"outputId": "40e59861-2ca7-400b-9d02-43285a68a85b" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"[2023-11-05 21:20:09,304] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"import os\n", | |||
"import sys\n", | |||
"import json\n", | |||
"import warnings\n", | |||
"import math\n", | |||
"import argparse\n", | |||
"import logging\n", | |||
"import random\n", | |||
"import gc\n", | |||
"import tqdm\n", | |||
"\n", | |||
"from collections import OrderedDict\n", | |||
"\n", | |||
"import numpy as np\n", | |||
"import pandas as pd\n", | |||
"import deepspeed\n", | |||
"import torch\n", | |||
"import torch.nn as nn\n", | |||
"import torch.nn.functional as F\n", | |||
"import torch.backends.cudnn as cudnn\n", | |||
"import torch.utils.checkpoint as cp\n", | |||
"from torch.utils.data import Dataset\n", | |||
"from torch.utils.data._utils.collate import default_collate\n", | |||
"from torchvision import transforms\n", | |||
"from timm.models import create_model\n", | |||
"from timm.models.layers import trunc_normal_\n", | |||
"from timm.models.registry import register_model\n", | |||
"from timm.loss import SoftTargetCrossEntropy\n", | |||
"from functools import partial\n", | |||
"from datetime import datetime\n", | |||
"\n", | |||
"from transformers import AutoTokenizer, CLIPModel" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "yFhdWFvhbLwF" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"class EnhancedVisionTransformer(VisionTransformer):\n", | |||
" def get_embeddings(self, x):\n", | |||
" B, _, T, H, W = x.shape\n", | |||
"\n", | |||
" x = self.patch_embed(x)\n", | |||
"\n", | |||
" if self.pos_embed is not None:\n", | |||
" x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()\n", | |||
" x = self.pos_drop(x)\n", | |||
"\n", | |||
" for blk in self.blocks:\n", | |||
" if self.with_cp:\n", | |||
" x = cp.checkpoint(blk, x)\n", | |||
" else:\n", | |||
" x = blk(x)\n", | |||
"\n", | |||
" B, num_patches, embed_dim = x.shape\n", | |||
"\n", | |||
" T = T // self.tubelet_size\n", | |||
" x = x.view(B, T, num_patches // T, embed_dim)\n", | |||
" x = x.reshape(B, T, -1)\n", | |||
"\n", | |||
" return x\n", | |||
"\n", | |||
"@register_model\n", | |||
"def vit_small_patch16_224(pretrained=False, **kwargs):\n", | |||
" model = EnhancedVisionTransformer(\n", | |||
" patch_size=16,\n", | |||
" embed_dim=384,\n", | |||
" depth=12,\n", | |||
" num_heads=6,\n", | |||
" mlp_ratio=4,\n", | |||
" qkv_bias=True,\n", | |||
" norm_layer=partial(nn.LayerNorm, eps=1e-6),\n", | |||
" **kwargs)\n", | |||
" model.default_cfg = _cfg()\n", | |||
" return model" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "gRgLU0IzIcxQ" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"def add_space_before_uppercase(s):\n", | |||
" result = [s[0]]\n", | |||
"\n", | |||
" for char in s[1:]:\n", | |||
" if char.isupper():\n", | |||
" result.append(' ')\n", | |||
" result.append(char)\n", | |||
"\n", | |||
" return ''.join(result)\n", | |||
"\n", | |||
"class EnhancedVideoClsDataset(VideoClsDataset):\n", | |||
" def __getitem__(self, index):\n", | |||
" original_data = super().__getitem__(index)\n", | |||
"\n", | |||
" video_filename = self.dataset_samples[index].split('/')[-1].split('.')[0].split('_')[1]\n", | |||
"\n", | |||
" label = add_space_before_uppercase(video_filename)\n", | |||
"\n", | |||
" return (*original_data, ''.join(label))" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "WhYEWCxiJBMa" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"args = argparse.Namespace()\n", | |||
"args.model = 'vit_small_patch16_224'\n", | |||
"args.data_set = 'UCF101'\n", | |||
"args.nb_classes = 101\n", | |||
"args.data_path = '/content/UCF101/'\n", | |||
"args.finetune = '/content/vit_s_k710_dl_from_giant.pth'\n", | |||
"args.batch_size = 6\n", | |||
"args.input_size = 224\n", | |||
"args.short_side_size = 224\n", | |||
"args.num_frames = 16\n", | |||
"args.sampling_rate = 10\n", | |||
"args.num_sample = 2\n", | |||
"args.num_workers = 2\n", | |||
"args.opt = 'adamw'\n", | |||
"args.opt_eps = 1e-8\n", | |||
"args.opt_betas = [0.9, 0.999]\n", | |||
"args.lr = 1e-3\n", | |||
"args.min_lr = 1e-6\n", | |||
"args.drop = 0.0\n", | |||
"args.attn_drop_rate = 0.0\n", | |||
"args.drop_path = 0.35\n", | |||
"args.clip_grad = None # 5.0\n", | |||
"args.aa = 'rand-m7-n4-mstd0.5-inc1'\n", | |||
"args.layer_decay = 0.92 # 0.9\n", | |||
"args.weight_decay = 0.06 # 0.05\n", | |||
"args.epochs = 5\n", | |||
"\n", | |||
"args.tubelet_size = 2\n", | |||
"args.with_checkpoint = True\n", | |||
"args.train_interpolation = 'bicubic'\n", | |||
"args.reprob = 0.25\n", | |||
"args.remode = 'pixel'\n", | |||
"args.recount = 1\n", | |||
"args.data_root = ''\n", | |||
"\n", | |||
"args.num_segments = 1\n", | |||
"\n", | |||
"args.start_epoch = 0\n", | |||
"\n", | |||
"args.pin_mem = True" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "rhsNVcWuKMm0" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |||
"\n", | |||
"torch.manual_seed(0)\n", | |||
"np.random.seed(0)\n", | |||
"random.seed(0)\n", | |||
"cudnn.benchmark = True" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "eRRzjmFsMM0T" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"def multiple_samples_collate(batch):\n", | |||
" inputs, labels, video_idx, extra_data, video_file_names = zip(*batch)\n", | |||
" inputs = [item for sublist in inputs for item in sublist]\n", | |||
" labels = [item for sublist in labels for item in sublist]\n", | |||
" video_idx = [item for sublist in video_idx for item in sublist]\n", | |||
" inputs, labels, video_idx, extra_data, video_file_names = (\n", | |||
" default_collate(inputs),\n", | |||
" default_collate(labels),\n", | |||
" default_collate(video_idx),\n", | |||
" default_collate(extra_data),\n", | |||
" default_collate(video_file_names),\n", | |||
" )\n", | |||
" return inputs, labels, video_idx, extra_data, video_file_names" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "Twu9hFl_KhbN" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"train_dataset = EnhancedVideoClsDataset(\n", | |||
" anno_path=os.path.join(args.data_path, 'train.csv'),\n", | |||
" data_root=args.data_root,\n", | |||
" clip_len=args.num_frames,\n", | |||
" frame_sample_rate=args.sampling_rate,\n", | |||
" num_segment=1,\n", | |||
" crop_size=args.input_size,\n", | |||
" short_side_size=args.short_side_size,\n", | |||
" mode='train',\n", | |||
" args=args)\n", | |||
"\n", | |||
"\n", | |||
"data_loader_train = torch.utils.data.DataLoader(\n", | |||
" train_dataset,\n", | |||
" batch_size=args.batch_size,\n", | |||
" num_workers=args.num_workers,\n", | |||
" pin_memory=args.pin_mem,\n", | |||
" shuffle=True,\n", | |||
" drop_last=True,\n", | |||
" collate_fn=partial(multiple_samples_collate),\n", | |||
" persistent_workers=True)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "eWxUWHrlLBYN" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"class FrameScorePredictor(nn.Module):\n", | |||
" def __init__(self, create_model_fn, args):\n", | |||
" super().__init__()\n", | |||
"\n", | |||
"\n", | |||
" self.base_model = create_model_fn(\n", | |||
" args.model,\n", | |||
" img_size=args.input_size,\n", | |||
" pretrained=False,\n", | |||
" all_frames=args.num_frames * args.num_segments,\n", | |||
" tubelet_size=args.tubelet_size,\n", | |||
" drop_rate=args.drop,\n", | |||
" drop_path_rate=args.drop_path,\n", | |||
" attn_drop_rate=args.attn_drop_rate,\n", | |||
" drop_block_rate=None,\n", | |||
" with_cp=args.with_checkpoint,\n", | |||
" )\n", | |||
"\n", | |||
" self.base_embedding_dim = self.base_model.embed_dim * 14 ** 2 # TODO: use parameters\n", | |||
" self.embedding_frame_number = args.num_frames // args.tubelet_size\n", | |||
"\n", | |||
" self.score_predictor = nn.Sequential(\n", | |||
" nn.Linear(self.base_embedding_dim * self.embedding_frame_number, self.embedding_frame_number)\n", | |||
" )\n", | |||
"\n", | |||
" def forward(self, x):\n", | |||
" batch_size, num_channels, num_frames, height, width = x.shape\n", | |||
"\n", | |||
" embeddings = self.base_model.get_embeddings(x)\n", | |||
"\n", | |||
" embeddings = embeddings.reshape(batch_size, self.base_embedding_dim * self.embedding_frame_number)\n", | |||
"\n", | |||
" frame_scores = self.score_predictor(embeddings)\n", | |||
" return frame_scores\n", | |||
"\n", | |||
" def load_base_model_state_dict(self, checkpoint_model):\n", | |||
" load_state_dict(self.base_model, checkpoint_model)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "VO4DA_QkKl0v", | |||
"outputId": "a107c6dd-272a-4970-ced6-cd0e16abbaa5" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"size mismatch for head.weight: copying a param with shape torch.Size([710, 384]) from checkpoint, the shape in current model is torch.Size([1000, 384]).\n", | |||
"size mismatch for head.bias: copying a param with shape torch.Size([710]) from checkpoint, the shape in current model is torch.Size([1000]).\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"124" | |||
] | |||
}, | |||
"execution_count": 15, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"model = FrameScorePredictor(create_model, args=args)\n", | |||
"\n", | |||
"checkpoint = torch.load(args.finetune, map_location='cpu')\n", | |||
"\n", | |||
"checkpoint_model = checkpoint['model'] if 'model' in checkpoint else checkpoint['module']\n", | |||
"\n", | |||
"model.load_base_model_state_dict(checkpoint_model)\n", | |||
"\n", | |||
"model = model.to(device)\n", | |||
"\n", | |||
"del checkpoint\n", | |||
"del checkpoint_model\n", | |||
"\n", | |||
"gc.collect()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "_aFo0-9J93FN" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"class PairwiseLoss(torch.nn.Module):\n", | |||
" def __init__(self, margin=0.2):\n", | |||
" super(PairwiseLoss, self).__init__()\n", | |||
" self.margin = margin\n", | |||
"\n", | |||
" def forward(self, scores, labels):\n", | |||
" pairwise_diff = scores.unsqueeze(1) - scores.unsqueeze(0)\n", | |||
"\n", | |||
" positive_mask = (labels.unsqueeze(1) - labels.unsqueeze(0)) > self.margin\n", | |||
"\n", | |||
" positive_diffs = pairwise_diff[positive_mask]\n", | |||
"\n", | |||
" loss = torch.clamp(positive_diffs, min=0).sum()\n", | |||
"\n", | |||
" num_positive_pairs = positive_mask.sum()\n", | |||
" return loss / num_positive_pairs if num_positive_pairs > 0 else loss\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "GVGSAGKlPJye" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n", | |||
"clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "E9nez-e8TOLp" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"all_classes = [add_space_before_uppercase(text) for text in class_to_label.keys()]\n", | |||
"all_classes_inputs = tokenizer(all_classes, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n", | |||
"with torch.no_grad():\n", | |||
" text_embeddings = clip_model.get_text_features(**all_classes_inputs)\n", | |||
"\n", | |||
"all_classes_clip_embedding = {text: emb for text, emb in zip(all_classes, text_embeddings)}\n", | |||
"all_classes_clip_embedding_list = torch.stack(list(all_classes_clip_embedding.values())).to(device)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "9lRGZ1ZuYmEe" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"def compute_accuracies(frame_features, target_index):\n", | |||
" similarities = torch.matmul(frame_features, all_classes_clip_embedding_list.T)\n", | |||
"\n", | |||
" top5_indices = torch.topk(similarities, 5, largest=True).indices\n", | |||
"\n", | |||
" top1 = int(top5_indices[0] == target_index)\n", | |||
" top5 = int(target_index in top5_indices)\n", | |||
"\n", | |||
" return top1, top5" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"background_save": true | |||
}, | |||
"id": "nQ3f69MnGLcW" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"embedding_frame_number = args.num_frames // args.tubelet_size\n", | |||
"ranking_loss_scale = embedding_frame_number**3-embedding_frame_number\n", | |||
"\n", | |||
"# mse_criterion = nn.MSELoss()\n", | |||
"# margin_ranking_loss = nn.MarginRankingLoss(margin=1)\n", | |||
"\n", | |||
"optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n", | |||
"\n", | |||
"gc.collect()\n", | |||
"torch.cuda.empty_cache()\n", | |||
"\n", | |||
"model.train()\n", | |||
"\n", | |||
"bce_loss_sum = 0.0\n", | |||
"mse_loss_sum = 0.0\n", | |||
"rank_loss_sum = 0.0\n", | |||
"combined_loss_sum = 0.0\n", | |||
"print_step = 10\n", | |||
"\n", | |||
"strategies_items = ['random', 'best_pseudo', 'best_predicted']\n", | |||
"\n", | |||
"acc1_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||
"acc5_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||
"acc_counter = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||
"\n", | |||
"for step, (samples, targets, _, _, classes) in enumerate(data_loader_train):\n", | |||
" optimizer.zero_grad()\n", | |||
"\n", | |||
" pseudo_labels = []\n", | |||
" samples = samples.to(device)\n", | |||
" targets = targets.to(device)\n", | |||
" inputs = tokenizer(classes, padding=True, return_tensors=\"pt\").to(device)\n", | |||
" text_features = clip_model.get_text_features(**inputs)\n", | |||
"\n", | |||
" all_frames_features = []\n", | |||
"\n", | |||
" for i in range(samples.shape[0]):\n", | |||
" video_frames = samples[i].permute(1, 0, 2, 3)\n", | |||
" video_frames = video_frames[1::2]\n", | |||
" frames_features = clip_model.get_image_features(pixel_values=video_frames)\n", | |||
" all_frames_features.append(frames_features)\n", | |||
" similarity = torch.matmul(frames_features, text_features[i // args.num_sample])\n", | |||
" pseudo_labels.append(similarity)\n", | |||
" pseudo_labels = torch.stack(pseudo_labels).to(device)\n", | |||
"\n", | |||
" predicted_scores = model(samples)\n", | |||
"\n", | |||
" # mse_loss = mse_criterion(predicted_scores, pseudo_labels)\n", | |||
"\n", | |||
" best_frame_targets = torch.zeros_like(pseudo_labels)\n", | |||
" best_frame_indices = pseudo_labels.argmax(dim=1)\n", | |||
" best_frame_targets[torch.arange(pseudo_labels.size(0)), best_frame_indices] = 1\n", | |||
" bce_loss = F.binary_cross_entropy_with_logits(predicted_scores, best_frame_targets)\n", | |||
"\n", | |||
" # target = (pseudo_labels[1:] > pseudo_labels[:-1]).float() * 2 - 1\n", | |||
" # rank_loss = margin_ranking_loss(predicted_scores[:-1], predicted_scores[1:], target)\n", | |||
"\n", | |||
" # combined_loss = bce_loss + 0.02 * mse_loss + rank_loss\n", | |||
"\n", | |||
" # combined_loss.backward()\n", | |||
" bce_loss.backward()\n", | |||
" optimizer.step()\n", | |||
"\n", | |||
" # mse_loss_sum += mse_loss.item()\n", | |||
" bce_loss_sum += bce_loss.item()\n", | |||
" # rank_loss_sum += rank_loss.item()\n", | |||
" # combined_loss_sum += combined_loss.item()\n", | |||
"\n", | |||
" strategies = {\n", | |||
" 'random': torch.randint(low=0, high=args.num_frames // 2, size=(samples.size(0),)).to(device),\n", | |||
" 'best_pseudo': pseudo_labels.argmax(dim=1),\n", | |||
" 'best_predicted': predicted_scores.argmax(dim=1)\n", | |||
" }\n", | |||
"\n", | |||
" for strategy_name, frame_indices in strategies.items():\n", | |||
" for frames_features, frame_indice, target in zip(all_frames_features, frame_indices, targets):\n", | |||
" frame_features = frames_features[frame_indice]\n", | |||
" acc1, acc5 = compute_accuracies(frame_features, target - 1)\n", | |||
"\n", | |||
" acc1_sum[strategy_name] += acc1\n", | |||
" acc5_sum[strategy_name] += acc5\n", | |||
" acc_counter[strategy_name] += 1\n", | |||
"\n", | |||
" if (step + 1) % print_step == 0:\n", | |||
" print(\n", | |||
" f'Step {step}/{len(data_loader_train)}\\n'\n", | |||
" # f'Average MSE Loss: {mse_loss_sum / print_step}\\n'\n", | |||
" f'Average BCE Loss: {bce_loss_sum / print_step}\\n'\n", | |||
" # f'Average Ranking Loss: {rank_loss_sum / print_step}\\n'\n", | |||
" # f'Average Combined Loss: {combined_loss_sum / print_step}'\n", | |||
" )\n", | |||
"\n", | |||
" for strategy_name in strategies.keys():\n", | |||
" acc1_avg = acc1_sum[strategy_name] / acc_counter[strategy_name]\n", | |||
" acc5_avg = acc5_sum[strategy_name] / acc_counter[strategy_name]\n", | |||
"\n", | |||
" print(f\"{strategy_name} Average Frame Acc@1: {acc1_avg:.2f}%, Acc@5: {acc5_avg:.2f}%\")\n", | |||
"\n", | |||
" # mse_loss_sum = 0.0\n", | |||
" bce_loss_sum = 0.0\n", | |||
" # rank_loss_sum = 0.0\n", | |||
" # combined_loss_sum = 0.0\n", | |||
"\n", | |||
"\n", | |||
" acc1_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||
" acc5_sum = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||
" acc_counter = {strategy_name: 0 for strategy_name in strategies_items}\n", | |||
"\n", | |||
" print('------------------')\n", | |||
" print()\n", | |||
"\n", | |||
"\n", | |||
" gc.collect()\n", | |||
" torch.cuda.empty_cache()" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
"accelerator": "GPU", | |||
"colab": { | |||
"provenance": [] | |||
}, | |||
"kernelspec": { | |||
"display_name": "Python 3", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"name": "python" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 0 | |||
} |
@@ -0,0 +1,396 @@
{ | |||
"nbformat": 4, | |||
"nbformat_minor": 0, | |||
"metadata": { | |||
"colab": { | |||
"provenance": [] | |||
}, | |||
"kernelspec": { | |||
"name": "python3", | |||
"display_name": "Python 3" | |||
}, | |||
"language_info": { | |||
"name": "python" | |||
} | |||
}, | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "OoqkDAv64cdu", | |||
"outputId": "e0d30294-01b0-48d7-8199-d3ade0698c52" | |||
}, | |||
"outputs": [ | |||
{ | |||
"output_type": "stream", | |||
"name": "stdout", | |||
"text": [ | |||
"Downloading...\n", | |||
"From: https://drive.google.com/uc?id=1bLhNoh7VNY3nkVlhQ0TBo-KQzTa-WW8c\n", | |||
"To: /content/enhanced-ucf-101-label-CLIP-embedding.txt\n", | |||
"\r 0% 0.00/667k [00:00<?, ?B/s]\r100% 667k/667k [00:00<00:00, 24.2MB/s]\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"!gdown 1bLhNoh7VNY3nkVlhQ0TBo-KQzTa-WW8c" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"import numpy as np\n", | |||
"from sklearn.metrics.pairwise import cosine_similarity\n", | |||
"from operator import itemgetter\n", | |||
"\n", | |||
"def parse_embedding_line(line):\n", | |||
" parts = line.strip().split('|')\n", | |||
" label = parts[0].strip()\n", | |||
" description = parts[1].strip()\n", | |||
" embedding = np.array([float(x) for x in parts[2].strip(' []').split(',')])\n", | |||
" return label, description, embedding\n", | |||
"\n", | |||
"# Read the file\n", | |||
"with open('/content/enhanced-ucf-101-label-CLIP-embedding.txt', 'r') as f:\n", | |||
" lines = f.readlines()\n", | |||
"\n", | |||
"# Parse the lines\n", | |||
"embeddings = [parse_embedding_line(line) for line in lines]\n", | |||
"\n", | |||
"# Calculate similarity and sort\n", | |||
"similarities = []\n", | |||
"for i in range(len(embeddings)):\n", | |||
" for j in range(i+1, len(embeddings)):\n", | |||
" similarity = cosine_similarity([embeddings[i][2]], [embeddings[j][2]])[0][0]\n", | |||
" similarities.append((embeddings[i][0], embeddings[j][0], similarity))\n", | |||
"\n", | |||
"# Sort by similarity\n", | |||
"similarities.sort(key=itemgetter(2), reverse=True)\n", | |||
"\n", | |||
"# Print the sorted similarities\n", | |||
"for similarity in similarities[:300]:\n", | |||
" print(f\"{similarity[0]} and {similarity[1]}: {similarity[2]}\")\n" | |||
], | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "9_AHfllm7Y1F", | |||
"outputId": "1eb26326-78ae-4640-bb42-9baaacfe81cb" | |||
}, | |||
"execution_count": null, | |||
"outputs": [ | |||
{ | |||
"output_type": "stream", | |||
"name": "stdout", | |||
"text": [ | |||
"BreastStroke and FrontCrawl: 0.9135641566364405\n", | |||
"HighJump and LongJump: 0.9058382170839308\n", | |||
"PushUps and WallPushups: 0.893513324165334\n", | |||
"BoxingPunchingBag and BoxingSpeedBag: 0.8610437315536404\n", | |||
"PlayingCello and PlayingViolin: 0.8607566078247353\n", | |||
"Kayaking and Rafting: 0.8516765812785851\n", | |||
"HandstandPushups and PushUps: 0.8291473360831765\n", | |||
"HammerThrow and ThrowDiscus: 0.8285788251299882\n", | |||
"HighJump and PoleVault: 0.827215227814937\n", | |||
"LongJump and Shotput: 0.8249689845096113\n", | |||
"Shotput and ThrowDiscus: 0.8225986807562835\n", | |||
"BreastStroke and Diving: 0.8198360899246\n", | |||
"HammerThrow and Shotput: 0.8194067921502493\n", | |||
"ParallelBars and UnevenBars: 0.8172351546564418\n", | |||
"HandstandPushups and WallPushups: 0.8063430148554194\n", | |||
"PlayingDhol and PlayingTabla: 0.8044929381515984\n", | |||
"Drumming and PlayingTabla: 0.8036557686356149\n", | |||
"BaseballPitch and TennisSwing: 0.7926770716298019\n", | |||
"HighJump and Shotput: 0.7916822672394149\n", | |||
"FieldHockeyPenalty and SoccerPenalty: 0.7912218906905013\n", | |||
"PlayingFlute and PlayingViolin: 0.7907740407852453\n", | |||
"BodyWeightSquats and Lunges: 0.790197964204252\n", | |||
"Diving and SkyDiving: 0.7900033596110797\n", | |||
"PlayingPiano and PlayingViolin: 0.7898563658083394\n", | |||
"Diving and FrontCrawl: 0.7884088912671918\n", | |||
"CricketBowling and CricketShot: 0.786969527453953\n", | |||
"Kayaking and Rowing: 0.786382273584344\n", | |||
"Haircut and HeadMassage: 0.7813618969695116\n", | |||
"PlayingSitar and PlayingTabla: 0.7788554840779363\n", | |||
"SkateBoarding and Skiing: 0.7762970903271766\n", | |||
"Rafting and Rowing: 0.7752746043074626\n", | |||
"GolfSwing and TennisSwing: 0.7734050524041192\n", | |||
"PullUps and WallPushups: 0.7729951565091107\n", | |||
"HandstandPushups and HandstandWalking: 0.7717042057111059\n", | |||
"HighJump and TrampolineJumping: 0.7710276014756723\n", | |||
"HandstandWalking and SkateBoarding: 0.7684993331159056\n", | |||
"Biking and HorseRiding: 0.768091642690889\n", | |||
"PlayingGuitar and PlayingViolin: 0.766646619086345\n", | |||
"LongJump and TennisSwing: 0.7665787465707985\n", | |||
"RopeClimbing and TrampolineJumping: 0.7659114066820266\n", | |||
"SkyDiving and TrampolineJumping: 0.76511061957433\n", | |||
"PlayingCello and PlayingPiano: 0.7647740312136126\n", | |||
"Diving and HighJump: 0.7629321426272753\n", | |||
"JugglingBalls and YoYo: 0.7610768612305048\n", | |||
"BrushingTeeth and ShavingBeard: 0.7591425280885891\n", | |||
"Diving and LongJump: 0.7574804843984438\n", | |||
"LongJump and TrampolineJumping: 0.7571481857197025\n", | |||
"HighJump and TennisSwing: 0.7566156464932003\n", | |||
"HammerThrow and LongJump: 0.7558300803658048\n", | |||
"Diving and SkateBoarding: 0.7536973390587902\n", | |||
"JumpRope and RopeClimbing: 0.753327828396938\n", | |||
"LongJump and ThrowDiscus: 0.7529842080950556\n", | |||
"Diving and TrampolineJumping: 0.7518814257833246\n", | |||
"Biking and SkateBoarding: 0.7517431292455309\n", | |||
"LongJump and PoleVault: 0.7507306877543533\n", | |||
"BaseballPitch and CricketShot: 0.7499026266438265\n", | |||
"Shotput and TennisSwing: 0.7489834916216616\n", | |||
"PoleVault and Shotput: 0.7488596433999501\n", | |||
"Swing and TrampolineJumping: 0.7487908446460982\n", | |||
"CliffDiving and Diving: 0.7483162287151933\n", | |||
"CricketShot and TableTennisShot: 0.7481777580061397\n", | |||
"Diving and TennisSwing: 0.7471226907881188\n", | |||
"FrisbeeCatch and ThrowDiscus: 0.7443262070746838\n", | |||
"Skiing and Surfing: 0.7431443081264507\n", | |||
"SkateBoarding and SkyDiving: 0.7426388562026282\n", | |||
"PlayingGuitar and PlayingSitar: 0.7421187542614355\n", | |||
"Punch and SkyDiving: 0.741055650388424\n", | |||
"Rafting and SkyDiving: 0.7403432891302099\n", | |||
"CricketShot and SoccerPenalty: 0.7399070298540545\n", | |||
"RopeClimbing and Swing: 0.739812652356636\n", | |||
"SkateBoarding and TrampolineJumping: 0.7398099398134427\n", | |||
"Hammering and Typing: 0.7391169263001942\n", | |||
"HandstandWalking and LongJump: 0.7376174197128719\n", | |||
"HorseRiding and SkateBoarding: 0.7373558640663447\n", | |||
"HandstandWalking and HighJump: 0.7373394124461938\n", | |||
"CricketShot and TennisSwing: 0.7367685620369944\n", | |||
"LongJump and SkateBoarding: 0.7363108945108562\n", | |||
"Diving and HandstandWalking: 0.7361879858427884\n", | |||
"PlayingGuitar and PlayingPiano: 0.7360426048849933\n", | |||
"Hammering and Punch: 0.7355299374133767\n", | |||
"Punch and TennisSwing: 0.7349357411747358\n", | |||
"BalanceBeam and ParallelBars: 0.734626447206699\n", | |||
"PlayingPiano and Typing: 0.7342818000329813\n", | |||
"Drumming and PlayingDhol: 0.7342338545960496\n", | |||
"Diving and Shotput: 0.733686875936528\n", | |||
"PlayingFlute and PlayingPiano: 0.7334535071686105\n", | |||
"FloorGymnastics and PoleVault: 0.7332110651309142\n", | |||
"HorseRace and HorseRiding: 0.7326100067023478\n", | |||
"HandstandWalking and TrampolineJumping: 0.7324930780008992\n", | |||
"Basketball and BasketballDunk: 0.7320448669471811\n", | |||
"HighJump and ThrowDiscus: 0.7312358151246964\n", | |||
"ApplyEyeMakeup and ApplyLipstick: 0.7308442688180419\n", | |||
"RopeClimbing and SkyDiving: 0.7306117431890142\n", | |||
"TennisSwing and VolleyballSpiking: 0.7303710434696256\n", | |||
"Billiards and Bowling: 0.7291336136266361\n", | |||
"Skiing and SkyDiving: 0.7289941358107048\n", | |||
"Diving and Skiing: 0.7279314246716805\n", | |||
"Swing and TennisSwing: 0.7275821346958722\n", | |||
"BaseballPitch and LongJump: 0.7264541939600715\n", | |||
"PlayingCello and PlayingGuitar: 0.7263333055199999\n", | |||
"HammerThrow and JavelinThrow: 0.7257228448556832\n", | |||
"Haircut and Typing: 0.7257034588639546\n", | |||
"RopeClimbing and SkateBoarding: 0.7255599711121894\n", | |||
"HandstandPushups and PullUps: 0.7253599334545812\n", | |||
"Rowing and Skiing: 0.7249357111858055\n", | |||
"TennisSwing and TrampolineJumping: 0.7240195069941607\n", | |||
"HammerThrow and PoleVault: 0.7239577542426181\n", | |||
"LongJump and Skiing: 0.7237669589998139\n", | |||
"BreastStroke and Shotput: 0.7236121929503939\n", | |||
"PlayingCello and PlayingSitar: 0.7232122696549205\n", | |||
"Punch and Swing: 0.7230845724023044\n", | |||
"SkateBoarding and TennisSwing: 0.7228170165773073\n", | |||
"BodyWeightSquats and HandstandPushups: 0.7220405937519339\n", | |||
"FloorGymnastics and StillRings: 0.7219493456486239\n", | |||
"PullUps and PushUps: 0.7213810950131618\n", | |||
"HorseRiding and Skiing: 0.7206407102268063\n", | |||
"SkyDiving and Typing: 0.7202455770918296\n", | |||
"HorseRiding and RopeClimbing: 0.7191262064512949\n", | |||
"SkyDiving and Swing: 0.7180675832908672\n", | |||
"PlayingFlute and PlayingGuitar: 0.7180613359728372\n", | |||
"HammerThrow and HighJump: 0.7179640718224373\n", | |||
"Diving and Punch: 0.7176298910268708\n", | |||
"PlayingViolin and TennisSwing: 0.7174823425623374\n", | |||
"Diving and RopeClimbing: 0.7167796270997888\n", | |||
"Rafting and RopeClimbing: 0.7167763884184823\n", | |||
"TableTennisShot and TennisSwing: 0.7165475089712395\n", | |||
"Biking and Skiing: 0.7164028469479011\n", | |||
"JavelinThrow and PoleVault: 0.7159371318136043\n", | |||
"RopeClimbing and Typing: 0.7155954413924253\n", | |||
"HeadMassage and Typing: 0.7153357826215936\n", | |||
"Rafting and Typing: 0.7138269972502138\n", | |||
"Biking and Rowing: 0.7137891147805048\n", | |||
"Kayaking and Typing: 0.7137610634362122\n", | |||
"BreastStroke and HighJump: 0.7126134869505956\n", | |||
"Diving and Rowing: 0.7125020136521408\n", | |||
"HighJump and SkateBoarding: 0.7117778149802355\n", | |||
"SoccerPenalty and VolleyballSpiking: 0.7111150090759842\n", | |||
"LongJump and SkyDiving: 0.7110640850737046\n", | |||
"PlayingFlute and PlayingSitar: 0.7110108149484404\n", | |||
"BaseballPitch and Shotput: 0.7109615116869581\n", | |||
"PlayingSitar and PlayingViolin: 0.7104350026580752\n", | |||
"HandstandPushups and JumpingJack: 0.7103055239175267\n", | |||
"SkateBoarding and Surfing: 0.7098772586463374\n", | |||
"Punch and RopeClimbing: 0.7086769659664385\n", | |||
"Punch and TrampolineJumping: 0.7078881926354027\n", | |||
"Diving and Surfing: 0.707117497226668\n", | |||
"PlayingPiano and TennisSwing: 0.706919504587053\n", | |||
"Shotput and SkateBoarding: 0.7061596986078766\n", | |||
"PlayingCello and PlayingFlute: 0.7058423238727487\n", | |||
"FloorGymnastics and UnevenBars: 0.7057777115167209\n", | |||
"ParallelBars and StillRings: 0.7055355076859835\n", | |||
"CleanAndJerk and Shotput: 0.7055125658073746\n", | |||
"Kayaking and SkyDiving: 0.704658976975876\n", | |||
"TableTennisShot and VolleyballSpiking: 0.7045804240564966\n", | |||
"Drumming and PlayingPiano: 0.7045712194927514\n", | |||
"Haircut and Hammering: 0.7044598702031014\n", | |||
"Rowing and TennisSwing: 0.7042651727056333\n", | |||
"HorseRiding and Typing: 0.7040858174512816\n", | |||
"Punch and Shotput: 0.704052525603613\n", | |||
"Archery and TennisSwing: 0.7036815296561322\n", | |||
"SkateBoarding and YoYo: 0.7030016415760869\n", | |||
"RopeClimbing and YoYo: 0.7029277827820148\n", | |||
"Punch and YoYo: 0.7023995431749964\n", | |||
"LongJump and Lunges: 0.7018532908775557\n", | |||
"JumpRope and TrampolineJumping: 0.7016024452998298\n", | |||
"Diving and Rafting: 0.7013908045763904\n", | |||
"Bowling and CricketBowling: 0.7013511724501809\n", | |||
"Punch and TaiChi: 0.7008104891179683\n", | |||
"BalanceBeam and UnevenBars: 0.7005837043328924\n", | |||
"SkateBoarding and Typing: 0.7005140816781765\n", | |||
"PlayingViolin and YoYo: 0.7003404861306218\n", | |||
"BasketballDunk and TennisSwing: 0.6999227936958208\n", | |||
"HighJump and VolleyballSpiking: 0.6998266362653691\n", | |||
"FloorGymnastics and ParallelBars: 0.6988077818792078\n", | |||
"PlayingViolin and Punch: 0.6973703128445751\n", | |||
"Punch and Typing: 0.6967775549981869\n", | |||
"Archery and Shotput: 0.6967583917284877\n", | |||
"Billiards and TableTennisShot: 0.6967073608377916\n", | |||
"BreastStroke and TennisSwing: 0.6966987394749857\n", | |||
"Punch and SkateBoarding: 0.6964323711013919\n", | |||
"TennisSwing and ThrowDiscus: 0.6964309660329386\n", | |||
"Biking and TennisSwing: 0.6962187137964084\n", | |||
"BaseballPitch and GolfSwing: 0.6961587111393436\n", | |||
"SkateBoarding and Swing: 0.6960365072022532\n", | |||
"BalanceBeam and FloorGymnastics: 0.6956748992112082\n", | |||
"RopeClimbing and Skiing: 0.6954019941797185\n", | |||
"BasketballDunk and HighJump: 0.6952324557182699\n", | |||
"RopeClimbing and TennisSwing: 0.6952311829515747\n", | |||
"BreastStroke and Rowing: 0.6949829501769483\n", | |||
"Skiing and TennisSwing: 0.6947422694408114\n", | |||
"JumpingJack and TrampolineJumping: 0.6946418413900224\n", | |||
"ParallelBars and PommelHorse: 0.6943619896075937\n", | |||
"BreastStroke and LongJump: 0.6926574264131027\n", | |||
"FloorGymnastics and HandstandPushups: 0.6924127157339901\n", | |||
"Shotput and Skiing: 0.692409162454707\n", | |||
"BodyWeightSquats and JumpingJack: 0.6915140858910102\n", | |||
"Archery and PlayingViolin: 0.6914364939922757\n", | |||
"BasketballDunk and TrampolineJumping: 0.6910720435738926\n", | |||
"Haircut and SkyDiving: 0.6906279918436358\n", | |||
"HorseRiding and Rowing: 0.6906078683829708\n", | |||
"Diving and Kayaking: 0.6904115966869461\n", | |||
"HighJump and SkyDiving: 0.6903223255840882\n", | |||
"HandstandWalking and RopeClimbing: 0.6900177976361846\n", | |||
"Biking and Typing: 0.688663446247988\n", | |||
"Archery and Punch: 0.6885543671849701\n", | |||
"BreastStroke and Surfing: 0.6882605898627008\n", | |||
"Diving and PlayingPiano: 0.6877803133466125\n", | |||
"PlayingPiano and SkateBoarding: 0.6872490075085536\n", | |||
"BaseballPitch and Punch: 0.6871945116845465\n", | |||
"HandstandWalking and UnevenBars: 0.6869197476332304\n", | |||
"CricketShot and GolfSwing: 0.6868492889521802\n", | |||
"Kayaking and Surfing: 0.6865820428631303\n", | |||
"Basketball and TennisSwing: 0.6864463021798647\n", | |||
"Shotput and YoYo: 0.6862387737044783\n", | |||
"PlayingViolin and SkateBoarding: 0.6861837432385506\n", | |||
"SoccerPenalty and TennisSwing: 0.6861144923920035\n", | |||
"BaseballPitch and SkateBoarding: 0.6855318206995243\n", | |||
"Shotput and TrampolineJumping: 0.6854082143018928\n", | |||
"CricketShot and Shotput: 0.6852120255606943\n", | |||
"Kayaking and Skiing: 0.6850696489537216\n", | |||
"Archery and Fencing: 0.6850545206297898\n", | |||
"HighJump and Skiing: 0.6847553551867225\n", | |||
"BaseballPitch and PlayingViolin: 0.6840371878413228\n", | |||
"HorseRiding and Swing: 0.6835796664282903\n", | |||
"Archery and SkyDiving: 0.683544816768237\n", | |||
"PlayingCello and YoYo: 0.6833837981485343\n", | |||
"Rowing and Typing: 0.6833516303139973\n", | |||
"HorseRiding and SkyDiving: 0.6832857317563026\n", | |||
"Lunges and TennisSwing: 0.6827234718924219\n", | |||
"BasketballDunk and LongJump: 0.6825552137357147\n", | |||
"Rowing and Surfing: 0.6820142752266611\n", | |||
"JumpingJack and LongJump: 0.6819915349194994\n", | |||
"Haircut and Punch: 0.6819490061511206\n", | |||
"Haircut and SkateBoarding: 0.6818481724348707\n", | |||
"FloorGymnastics and HighJump: 0.6808541160844553\n", | |||
"Punch and Rafting: 0.6807997071512121\n", | |||
"PlayingPiano and Rowing: 0.680677304303248\n", | |||
"JavelinThrow and ThrowDiscus: 0.6806693704418652\n", | |||
"Rowing and SkateBoarding: 0.6803773211285751\n", | |||
"Hammering and HeadMassage: 0.6802512293515253\n", | |||
"Punch and Rowing: 0.6802002555149966\n", | |||
"BodyWeightSquats and PushUps: 0.6799776131017082\n", | |||
"Diving and Typing: 0.6797726174709029\n", | |||
"Shotput and SkyDiving: 0.6797673354951406\n", | |||
"Hammering and PlayingViolin: 0.6796496669473412\n", | |||
"HandstandWalking and SkyDiving: 0.6794054995491691\n", | |||
"FloorGymnastics and HandstandWalking: 0.6794028936136237\n", | |||
"BasketballDunk and CricketShot: 0.6792758965914687\n", | |||
"Biking and Diving: 0.6790006973364628\n", | |||
"BaseballPitch and CricketBowling: 0.6786719503376799\n", | |||
"LongJump and VolleyballSpiking: 0.6784090342928275\n", | |||
"PullUps and RopeClimbing: 0.678258564371885\n", | |||
"Rafting and Skiing: 0.6781782308964603\n", | |||
"SkyDiving and TennisSwing: 0.6780321924063827\n", | |||
"Biking and Kayaking: 0.6778740166200692\n", | |||
"HorseRiding and Punch: 0.6775237101716206\n", | |||
"Hammering and SkateBoarding: 0.6770651672183081\n", | |||
"Mixing and Typing: 0.6770458400648384\n", | |||
"Shotput and VolleyballSpiking: 0.6769503514713261\n", | |||
"LongJump and Punch: 0.6768260936231084\n", | |||
"Billiards and TennisSwing: 0.676687811162065\n", | |||
"SoccerJuggling and SoccerPenalty: 0.6765064274372669\n", | |||
"StillRings and UnevenBars: 0.6764089115086491\n", | |||
"JugglingBalls and SoccerJuggling: 0.6760817346996303\n", | |||
"BaseballPitch and SoccerPenalty: 0.6759697039170522\n", | |||
"Hammering and HorseRiding: 0.6757606742710943\n", | |||
"Billiards and PlayingPiano: 0.6756887070040185\n", | |||
"Basketball and Shotput: 0.6756210059393115\n", | |||
"TrampolineJumping and Typing: 0.6756177121880147\n", | |||
"BalanceBeam and PommelHorse: 0.6755478243357194\n", | |||
"CliffDiving and SkyDiving: 0.6754688197958925\n", | |||
"BenchPress and HandstandPushups: 0.6753462315331547\n", | |||
"Rowing and Skijet: 0.6749020427705266\n", | |||
"PlayingCello and TennisSwing: 0.6748126664110354\n", | |||
"CricketBowling and TennisSwing: 0.6747936810472541\n", | |||
"TrampolineJumping and UnevenBars: 0.6747665868776003\n", | |||
"BodyWeightSquats and HandstandWalking: 0.6744448381666313\n", | |||
"Hammering and RopeClimbing: 0.6740811403604476\n", | |||
"Lunges and Shotput: 0.6737521021449214\n", | |||
"Drumming and PlayingSitar: 0.6736439263992084\n", | |||
"Skiing and Skijet: 0.6734635666570382\n", | |||
"HandstandWalking and Shotput: 0.6732521658396622\n", | |||
"BandMarching and SkateBoarding: 0.6730512018894181\n", | |||
"HandstandWalking and TennisSwing: 0.672699463584204\n", | |||
"HorseRiding and TennisSwing: 0.6722271368365996\n", | |||
"Biking and RopeClimbing: 0.6721048261242591\n", | |||
"PlayingViolin and Typing: 0.6720289417432124\n", | |||
"HandstandWalking and JumpingJack: 0.671939227917232\n", | |||
"PlayingPiano and PlayingSitar: 0.671923104534426\n", | |||
"HorseRiding and Rafting: 0.6716409889288361\n", | |||
"Biking and LongJump: 0.6713809775815504\n", | |||
"Lunges and TaiChi: 0.6711312367339834\n", | |||
"BandMarching and HorseRiding: 0.6711082186179306\n", | |||
"CricketBowling and Shotput: 0.6705659431969646\n", | |||
"BodyWeightSquats and WallPushups: 0.6705453623257535\n", | |||
"SoccerPenalty and TableTennisShot: 0.6703895759334059\n", | |||
"HandstandWalking and PushUps: 0.6702445259363492\n", | |||
"RopeClimbing and UnevenBars: 0.6699618121733547\n", | |||
"Rowing and Swing: 0.6697346266800764\n", | |||
"PlayingCello and SkateBoarding: 0.6696594880529304\n" | |||
] | |||
} | |||
] | |||
} | |||
] | |||
} |
@@ -0,0 +1,255 @@
import torch | |||
from torch.utils.data import DataLoader | |||
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader | |||
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader | |||
from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader | |||
from dataloaders.dataloader_lsmdc_retrieval import LSMDC_DataLoader | |||
from dataloaders.dataloader_activitynet_retrieval import ActivityNet_DataLoader | |||
from dataloaders.dataloader_didemo_retrieval import DiDeMo_DataLoader | |||
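# Each dataloader_* factory below follows the same pattern: construct the
# dataset for one split, wrap it in a DataLoader (with a DistributedSampler
# and per-GPU batch size for training), and return the loader together with
# the dataset length (plus the sampler for the train split).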
def dataloader_msrvtt_train(args, tokenizer): | |||
msrvtt_dataset = MSRVTT_TrainDataLoader( | |||
csv_path=args.train_csv, | |||
json_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
unfold_sentences=args.expand_msrvtt_sentences, | |||
frame_order=args.train_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) | |||
dataloader = DataLoader( | |||
msrvtt_dataset, | |||
batch_size=args.batch_size // args.n_gpu, | |||
num_workers=args.num_thread_reader, | |||
pin_memory=False, | |||
shuffle=(train_sampler is None), | |||
sampler=train_sampler, | |||
drop_last=True, | |||
) | |||
return dataloader, len(msrvtt_dataset), train_sampler | |||
def dataloader_msrvtt_test(args, tokenizer, subset="test"): | |||
msrvtt_testset = MSRVTT_DataLoader( | |||
csv_path=args.val_csv, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.eval_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
dataloader_msrvtt = DataLoader( | |||
msrvtt_testset, | |||
batch_size=args.batch_size_val, | |||
num_workers=args.num_thread_reader, | |||
shuffle=False, | |||
drop_last=False, | |||
) | |||
return dataloader_msrvtt, len(msrvtt_testset) | |||
def dataloader_msvd_train(args, tokenizer): | |||
msvd_dataset = MSVD_DataLoader( | |||
subset="train", | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.train_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
train_sampler = torch.utils.data.distributed.DistributedSampler(msvd_dataset) | |||
dataloader = DataLoader( | |||
msvd_dataset, | |||
batch_size=args.batch_size // args.n_gpu, | |||
num_workers=args.num_thread_reader, | |||
pin_memory=False, | |||
shuffle=(train_sampler is None), | |||
sampler=train_sampler, | |||
drop_last=True, | |||
) | |||
return dataloader, len(msvd_dataset), train_sampler | |||
def dataloader_msvd_test(args, tokenizer, subset="test"): | |||
msvd_testset = MSVD_DataLoader( | |||
subset=subset, | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.eval_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
    dataloader_msvd = DataLoader(
        msvd_testset,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return dataloader_msvd, len(msvd_testset)
def dataloader_lsmdc_train(args, tokenizer): | |||
lsmdc_dataset = LSMDC_DataLoader( | |||
subset="train", | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.train_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
train_sampler = torch.utils.data.distributed.DistributedSampler(lsmdc_dataset) | |||
dataloader = DataLoader( | |||
lsmdc_dataset, | |||
batch_size=args.batch_size // args.n_gpu, | |||
num_workers=args.num_thread_reader, | |||
pin_memory=False, | |||
shuffle=(train_sampler is None), | |||
sampler=train_sampler, | |||
drop_last=True, | |||
) | |||
return dataloader, len(lsmdc_dataset), train_sampler | |||
def dataloader_lsmdc_test(args, tokenizer, subset="test"): | |||
lsmdc_testset = LSMDC_DataLoader( | |||
subset=subset, | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.eval_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
    dataloader_lsmdc = DataLoader(
        lsmdc_testset,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return dataloader_lsmdc, len(lsmdc_testset)
def dataloader_activity_train(args, tokenizer): | |||
activity_dataset = ActivityNet_DataLoader( | |||
subset="train", | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.train_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
train_sampler = torch.utils.data.distributed.DistributedSampler(activity_dataset) | |||
dataloader = DataLoader( | |||
activity_dataset, | |||
batch_size=args.batch_size // args.n_gpu, | |||
num_workers=args.num_thread_reader, | |||
pin_memory=False, | |||
shuffle=(train_sampler is None), | |||
sampler=train_sampler, | |||
drop_last=True, | |||
) | |||
return dataloader, len(activity_dataset), train_sampler | |||
def dataloader_activity_test(args, tokenizer, subset="test"): | |||
activity_testset = ActivityNet_DataLoader( | |||
subset=subset, | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.eval_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
    dataloader_activity = DataLoader(
        activity_testset,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return dataloader_activity, len(activity_testset)
def dataloader_didemo_train(args, tokenizer): | |||
didemo_dataset = DiDeMo_DataLoader( | |||
subset="train", | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.train_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
train_sampler = torch.utils.data.distributed.DistributedSampler(didemo_dataset) | |||
dataloader = DataLoader( | |||
didemo_dataset, | |||
batch_size=args.batch_size // args.n_gpu, | |||
num_workers=args.num_thread_reader, | |||
pin_memory=False, | |||
shuffle=(train_sampler is None), | |||
sampler=train_sampler, | |||
drop_last=True, | |||
) | |||
return dataloader, len(didemo_dataset), train_sampler | |||
def dataloader_didemo_test(args, tokenizer, subset="test"): | |||
didemo_testset = DiDeMo_DataLoader( | |||
subset=subset, | |||
data_path=args.data_path, | |||
features_path=args.features_path, | |||
max_words=args.max_words, | |||
feature_framerate=args.feature_framerate, | |||
tokenizer=tokenizer, | |||
max_frames=args.max_frames, | |||
frame_order=args.eval_frame_order, | |||
slice_framepos=args.slice_framepos, | |||
) | |||
dataloader_didemo = DataLoader( | |||
didemo_testset, | |||
batch_size=args.batch_size_val, | |||
num_workers=args.num_thread_reader, | |||
shuffle=False, | |||
drop_last=False, | |||
) | |||
return dataloader_didemo, len(didemo_testset) | |||
DATALOADER_DICT = {} | |||
DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test, "test":None} | |||
DATALOADER_DICT["msvd"] = {"train":dataloader_msvd_train, "val":dataloader_msvd_test, "test":dataloader_msvd_test} | |||
DATALOADER_DICT["lsmdc"] = {"train":dataloader_lsmdc_train, "val":dataloader_lsmdc_test, "test":dataloader_lsmdc_test} | |||
DATALOADER_DICT["activity"] = {"train":dataloader_activity_train, "val":dataloader_activity_test, "test":None} | |||
DATALOADER_DICT["didemo"] = {"train":dataloader_didemo_train, "val":dataloader_didemo_test, "test":dataloader_didemo_test} |
@@ -0,0 +1,227 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import unicode_literals | |||
from __future__ import print_function | |||
import os | |||
from torch.utils.data import Dataset | |||
import numpy as np | |||
import json | |||
import math | |||
from dataloaders.rawvideo_util import RawVideoExtractor | |||
class ActivityNet_DataLoader(Dataset): | |||
def __init__( | |||
self, | |||
subset, | |||
data_path, | |||
features_path, | |||
tokenizer, | |||
max_words=30, | |||
feature_framerate=1.0, | |||
max_frames=100, | |||
image_resolution=224, | |||
frame_order=0, | |||
slice_framepos=0, | |||
): | |||
self.data_path = data_path | |||
self.features_path = features_path | |||
self.feature_framerate = feature_framerate | |||
self.max_words = max_words | |||
self.max_frames = max_frames | |||
self.tokenizer = tokenizer | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
self.frame_order = frame_order | |||
assert self.frame_order in [0, 1, 2] | |||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||
self.slice_framepos = slice_framepos | |||
assert self.slice_framepos in [0, 1, 2] | |||
self.subset = subset | |||
assert self.subset in ["train", "val"] | |||
video_id_path_dict = {} | |||
video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json") | |||
video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json") | |||
video_json_path_dict = {} | |||
video_json_path_dict["train"] = os.path.join(self.data_path, "train.json") | |||
video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json") | |||
pseudo_video_id_list, video_id_list = self._get_video_id_single(video_id_path_dict[self.subset]) | |||
pseudo_caption_dict = self._get_captions_single(video_json_path_dict[self.subset]) | |||
print("video id list: {}".format(len(video_id_list))) | |||
print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys()))) | |||
video_dict = {} | |||
for root, dub_dir, video_files in os.walk(self.features_path): | |||
for video_file in video_files: | |||
video_id_ = ".".join(video_file.split(".")[:-1]) | |||
if video_id_ not in video_id_list: | |||
continue | |||
file_path_ = os.path.join(root, video_file) | |||
video_dict[video_id_] = file_path_ | |||
self.video_dict = video_dict | |||
print("video dict: {}".format(len(video_dict))) | |||
self.pseudo_video_id_list = pseudo_video_id_list | |||
self.video_id_list = video_id_list | |||
self.pseudo_caption_dict = pseudo_caption_dict | |||
# Get iterator video ids | |||
self.video_id2idx_dict = {pseudo_video_id: id for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)} | |||
# Get all captions | |||
self.iter2video_pairs_dict = {} | |||
for pseudo_video_id, video_id in zip(self.pseudo_video_id_list, self.video_id_list): | |||
if pseudo_video_id not in self.pseudo_caption_dict or video_id not in self.video_dict: | |||
continue | |||
caption = self.pseudo_caption_dict[pseudo_video_id] | |||
n_caption = len(caption['start']) | |||
for sub_id in range(n_caption): | |||
self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (pseudo_video_id, sub_id) | |||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||
def __len__(self): | |||
return len(self.iter2video_pairs_dict) | |||
    def _get_video_id_from_pseudo(self, pseudo_video_id):
        video_id = pseudo_video_id[2:]
        return video_id
def _get_video_id_single(self, path): | |||
pseudo_video_id_list = [] | |||
video_id_list = [] | |||
print('Loading json: {}'.format(path)) | |||
with open(path, 'r') as f: | |||
json_data = json.load(f) | |||
for pseudo_video_id in json_data: | |||
if pseudo_video_id in pseudo_video_id_list: | |||
print("reduplicate.") | |||
else: | |||
                video_id = self._get_video_id_from_pseudo(pseudo_video_id)
pseudo_video_id_list.append(pseudo_video_id) | |||
video_id_list.append(video_id) | |||
return pseudo_video_id_list, video_id_list | |||
def _get_captions_single(self, path): | |||
pseudo_caption_dict = {} | |||
with open(path, 'r') as f: | |||
json_data = json.load(f) | |||
for pseudo_video_id, v_ in json_data.items(): | |||
pseudo_caption_dict[pseudo_video_id] = {} | |||
duration = v_["duration"] | |||
pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object) | |||
pseudo_caption_dict[pseudo_video_id]["end"] = np.array([int(math.ceil(float(duration)))], dtype=object) | |||
pseudo_caption_dict[pseudo_video_id]["text"] = np.array([" ".join(v_["sentences"])], dtype=object) | |||
return pseudo_caption_dict | |||
def _get_text(self, pseudo_video_id, sub_id): | |||
caption = self.pseudo_caption_dict[pseudo_video_id] | |||
k = 1 | |||
r_ind = [sub_id] | |||
        starts = np.zeros(k, dtype=np.int64)
        ends = np.zeros(k, dtype=np.int64)
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i in range(k): | |||
ind = r_ind[i] | |||
start_, end_ = caption['start'][ind], caption['end'][ind] | |||
words = self.tokenizer.tokenize(caption['text'][ind]) | |||
starts[i], ends[i] = start_, end_ | |||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||
total_length_with_CLS = self.max_words - 1 | |||
if len(words) > total_length_with_CLS: | |||
words = words[:total_length_with_CLS] | |||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||
input_mask = [1] * len(input_ids) | |||
segment_ids = [0] * len(input_ids) | |||
while len(input_ids) < self.max_words: | |||
input_ids.append(0) | |||
input_mask.append(0) | |||
segment_ids.append(0) | |||
assert len(input_ids) == self.max_words | |||
assert len(input_mask) == self.max_words | |||
assert len(segment_ids) == self.max_words | |||
pairs_text[i] = np.array(input_ids) | |||
pairs_mask[i] = np.array(input_mask) | |||
pairs_segment[i] = np.array(segment_ids) | |||
return pairs_text, pairs_mask, pairs_segment, starts, ends | |||
def _get_rawvideo(self, idx, s, e): | |||
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(s)
        # Pair x L x T x 3 x H x W
        video = np.zeros((len(s), self.max_frames, 1, 3,
                          self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
video_path = self.video_dict[idx] | |||
try: | |||
for i in range(len(s)): | |||
start_time = int(s[i]) | |||
end_time = int(e[i]) | |||
start_time = start_time if start_time >= 0. else 0. | |||
end_time = end_time if end_time >= 0. else 0. | |||
if start_time > end_time: | |||
start_time, end_time = end_time, start_time | |||
elif start_time == end_time: | |||
end_time = end_time + 1 | |||
                # Could be optimized by batching all reads of this video
raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) | |||
raw_video_data = raw_video_data['video'] | |||
if len(raw_video_data.shape) > 3: | |||
raw_video_data_clip = raw_video_data | |||
# L x T x 3 x H x W | |||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||
if self.max_frames < raw_video_slice.shape[0]: | |||
if self.slice_framepos == 0: | |||
video_slice = raw_video_slice[:self.max_frames, ...] | |||
elif self.slice_framepos == 1: | |||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||
else: | |||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||
video_slice = raw_video_slice[sample_indx, ...] | |||
else: | |||
video_slice = raw_video_slice | |||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||
slice_len = video_slice.shape[0] | |||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||
if slice_len < 1: | |||
pass | |||
else: | |||
video[i][:slice_len, ...] = video_slice | |||
else: | |||
print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) | |||
except Exception as excep: | |||
print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep)) | |||
raise excep | |||
for i, v_length in enumerate(max_video_length): | |||
video_mask[i][:v_length] = [1] * v_length | |||
return video, video_mask | |||
def __getitem__(self, feature_idx): | |||
pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx] | |||
idx = self.video_id2idx_dict[pseudo_video_id] | |||
pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(pseudo_video_id, sub_id) | |||
video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends) | |||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
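# Hedged illustration of the three slice_framepos strategies used in
# _get_rawvideo, on a dummy sequence of 7 frame indices truncated to
# max_frames = 3 (values chosen only for demonstration):
if __name__ == "__main__":
    frames = np.arange(7)
    max_frames = 3
    head = frames[:max_frames]                                                     # slice_framepos == 0
    tail = frames[-max_frames:]                                                    # slice_framepos == 1
    uniform = frames[np.linspace(0, len(frames) - 1, num=max_frames, dtype=int)]   # slice_framepos == 2
    print(head, tail, uniform)  # -> [0 1 2] [4 5 6] [0 3 6]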
@@ -0,0 +1,222 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import unicode_literals | |||
from __future__ import print_function | |||
import os | |||
from torch.utils.data import Dataset | |||
import numpy as np | |||
import json | |||
from dataloaders.rawvideo_util import RawVideoExtractor | |||
class DiDeMo_DataLoader(Dataset): | |||
def __init__( | |||
self, | |||
subset, | |||
data_path, | |||
features_path, | |||
tokenizer, | |||
max_words=30, | |||
feature_framerate=1.0, | |||
max_frames=100, | |||
image_resolution=224, | |||
frame_order=0, | |||
slice_framepos=0, | |||
): | |||
self.data_path = data_path | |||
self.features_path = features_path | |||
self.feature_framerate = feature_framerate | |||
self.max_words = max_words | |||
self.max_frames = max_frames | |||
self.tokenizer = tokenizer | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
self.frame_order = frame_order | |||
assert self.frame_order in [0, 1, 2] | |||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||
self.slice_framepos = slice_framepos | |||
assert self.slice_framepos in [0, 1, 2] | |||
self.subset = subset | |||
assert self.subset in ["train", "val", "test"] | |||
video_id_path_dict = {} | |||
video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") | |||
video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") | |||
video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") | |||
video_json_path_dict = {} | |||
video_json_path_dict["train"] = os.path.join(self.data_path, "train_data.json") | |||
video_json_path_dict["val"] = os.path.join(self.data_path, "val_data.json") | |||
video_json_path_dict["test"] = os.path.join(self.data_path, "test_data.json") | |||
with open(video_id_path_dict[self.subset], 'r') as fp: | |||
video_ids = [itm.strip() for itm in fp.readlines()] | |||
caption_dict = {} | |||
with open(video_json_path_dict[self.subset], 'r') as f: | |||
json_data = json.load(f) | |||
for itm in json_data: | |||
description = itm["description"] | |||
times = itm["times"] | |||
video = itm["video"] | |||
if video not in video_ids: | |||
continue | |||
# each video is split into 5-second temporal chunks | |||
# average the points from each annotator | |||
start_ = np.mean([t_[0] for t_ in times]) * 5 | |||
end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5 | |||
if video in caption_dict: | |||
caption_dict[video]["start"].append(start_) | |||
caption_dict[video]["end"].append(end_) | |||
caption_dict[video]["text"].append(description) | |||
else: | |||
caption_dict[video] = {} | |||
caption_dict[video]["start"] = [start_] | |||
caption_dict[video]["end"] = [end_] | |||
caption_dict[video]["text"] = [description] | |||
for k_ in caption_dict.keys(): | |||
caption_dict[k_]["start"] = [0] | |||
# trick to save time on obtaining each video length | |||
# [https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md]: | |||
# Some videos are longer than 30 seconds. These videos were truncated to 30 seconds during annotation. | |||
caption_dict[k_]["end"] = [31] | |||
caption_dict[k_]["text"] = [" ".join(caption_dict[k_]["text"])] | |||
video_dict = {} | |||
for root, dub_dir, video_files in os.walk(self.features_path): | |||
for video_file in video_files: | |||
video_id_ = video_file | |||
if video_id_ not in video_ids: | |||
continue | |||
file_path_ = os.path.join(root, video_file) | |||
video_dict[video_id_] = file_path_ | |||
self.caption_dict = caption_dict | |||
self.video_dict = video_dict | |||
video_ids = list(set(video_ids) & set(self.caption_dict.keys()) & set(self.video_dict.keys())) | |||
# Get all captions | |||
self.iter2video_pairs_dict = {} | |||
for video_id in self.caption_dict.keys(): | |||
if video_id not in video_ids: | |||
continue | |||
caption = self.caption_dict[video_id] | |||
n_caption = len(caption['start']) | |||
for sub_id in range(n_caption): | |||
self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (video_id, sub_id) | |||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||
def __len__(self): | |||
return len(self.iter2video_pairs_dict) | |||
def _get_text(self, video_id, sub_id): | |||
caption = self.caption_dict[video_id] | |||
k = 1 | |||
r_ind = [sub_id] | |||
        starts = np.zeros(k, dtype=np.int64)
        ends = np.zeros(k, dtype=np.int64)
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i in range(k): | |||
ind = r_ind[i] | |||
start_, end_ = caption['start'][ind], caption['end'][ind] | |||
words = self.tokenizer.tokenize(caption['text'][ind]) | |||
starts[i], ends[i] = start_, end_ | |||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||
total_length_with_CLS = self.max_words - 1 | |||
if len(words) > total_length_with_CLS: | |||
words = words[:total_length_with_CLS] | |||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||
input_mask = [1] * len(input_ids) | |||
segment_ids = [0] * len(input_ids) | |||
while len(input_ids) < self.max_words: | |||
input_ids.append(0) | |||
input_mask.append(0) | |||
segment_ids.append(0) | |||
assert len(input_ids) == self.max_words | |||
assert len(input_mask) == self.max_words | |||
assert len(segment_ids) == self.max_words | |||
pairs_text[i] = np.array(input_ids) | |||
pairs_mask[i] = np.array(input_mask) | |||
pairs_segment[i] = np.array(segment_ids) | |||
return pairs_text, pairs_mask, pairs_segment, starts, ends | |||
def _get_rawvideo(self, idx, s, e): | |||
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(s)
        # Pair x L x T x 3 x H x W
        video = np.zeros((len(s), self.max_frames, 1, 3,
                          self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
video_path = self.video_dict[idx] | |||
try: | |||
for i in range(len(s)): | |||
start_time = int(s[i]) | |||
end_time = int(e[i]) | |||
start_time = start_time if start_time >= 0. else 0. | |||
end_time = end_time if end_time >= 0. else 0. | |||
if start_time > end_time: | |||
start_time, end_time = end_time, start_time | |||
elif start_time == end_time: | |||
end_time = end_time + 1 | |||
                cache_id = "{}_{}_{}".format(video_path, start_time, end_time)  # currently unused
                # Could be optimized by batching all reads of this video
raw_video_data = self.rawVideoExtractor.get_video_data(video_path, start_time, end_time) | |||
raw_video_data = raw_video_data['video'] | |||
if len(raw_video_data.shape) > 3: | |||
raw_video_data_clip = raw_video_data | |||
# L x T x 3 x H x W | |||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||
if self.max_frames < raw_video_slice.shape[0]: | |||
if self.slice_framepos == 0: | |||
video_slice = raw_video_slice[:self.max_frames, ...] | |||
elif self.slice_framepos == 1: | |||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||
else: | |||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||
video_slice = raw_video_slice[sample_indx, ...] | |||
else: | |||
video_slice = raw_video_slice | |||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||
slice_len = video_slice.shape[0] | |||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||
if slice_len < 1: | |||
pass | |||
else: | |||
video[i][:slice_len, ...] = video_slice | |||
else: | |||
print("video path: {} error. video id: {}, start: {}, end: {}".format(video_path, idx, start_time, end_time)) | |||
        except Exception as excep:
            print("video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(video_path, idx, s, e, excep))
            # Errors are swallowed here; the ActivityNet loader re-raises instead.
for i, v_length in enumerate(max_video_length): | |||
video_mask[i][:v_length] = [1] * v_length | |||
return video, video_mask | |||
def __getitem__(self, feature_idx): | |||
video_id, sub_id = self.iter2video_pairs_dict[feature_idx] | |||
pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(video_id, sub_id) | |||
video, video_mask = self._get_rawvideo(video_id, starts, ends) | |||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
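# Hedged worked example of the DiDeMo time conversion in __init__: annotators
# label 5-second chunks, so averaged chunk indices are scaled by 5 and the end
# chunk is inclusive (hence the +1). The annotation values below are invented
# purely for illustration.
if __name__ == "__main__":
    times = [[1, 2], [1, 3], [2, 2]]  # three annotators' [start, end] chunk indices
    start_ = np.mean([t_[0] for t_ in times]) * 5        # (1 + 1 + 2) / 3 * 5 ≈ 6.67 s
    end_ = (np.mean([t_[1] for t_ in times]) + 1) * 5    # ((2 + 3 + 2) / 3 + 1) * 5 ≈ 16.67 s
    print(start_, end_)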
@@ -0,0 +1,208 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import unicode_literals | |||
from __future__ import print_function | |||
import os | |||
from torch.utils.data import Dataset | |||
import numpy as np | |||
import json | |||
import math | |||
from dataloaders.rawvideo_util import RawVideoExtractor | |||
class LSMDC_DataLoader(Dataset): | |||
"""LSMDC dataset loader.""" | |||
def __init__( | |||
self, | |||
subset, | |||
data_path, | |||
features_path, | |||
tokenizer, | |||
max_words=30, | |||
feature_framerate=1.0, | |||
max_frames=100, | |||
image_resolution=224, | |||
frame_order=0, | |||
slice_framepos=0, | |||
): | |||
self.data_path = data_path | |||
self.features_path = features_path | |||
self.feature_framerate = feature_framerate | |||
self.max_words = max_words | |||
self.max_frames = max_frames | |||
self.tokenizer = tokenizer | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
self.frame_order = frame_order | |||
assert self.frame_order in [0, 1, 2] | |||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||
self.slice_framepos = slice_framepos | |||
assert self.slice_framepos in [0, 1, 2] | |||
self.subset = subset | |||
assert self.subset in ["train", "val", "test"] | |||
video_json_path_dict = {} | |||
video_json_path_dict["train"] = os.path.join(self.data_path, "LSMDC16_annos_training.csv") | |||
video_json_path_dict["val"] = os.path.join(self.data_path, "LSMDC16_annos_val.csv") | |||
video_json_path_dict["test"] = os.path.join(self.data_path, "LSMDC16_challenge_1000_publictect.csv") | |||
# <CLIP_ID>\t<START_ALIGNED>\t<END_ALIGNED>\t<START_EXTRACTED>\t<END_EXTRACTED>\t<SENTENCE> | |||
# <CLIP_ID> is not a unique identifier, i.e. the same <CLIP_ID> can be associated with multiple sentences. | |||
# However, LSMDC16_challenge_1000_publictect.csv has no repeat instances | |||
video_id_list = [] | |||
caption_dict = {} | |||
with open(video_json_path_dict[self.subset], 'r') as fp: | |||
for line in fp: | |||
line = line.strip() | |||
line_split = line.split("\t") | |||
assert len(line_split) == 6 | |||
clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split | |||
caption_dict[len(caption_dict)] = (clip_id, sentence) | |||
                if clip_id not in video_id_list:
                    video_id_list.append(clip_id)
video_dict = {} | |||
for root, dub_dir, video_files in os.walk(self.features_path): | |||
for video_file in video_files: | |||
video_id_ = ".".join(video_file.split(".")[:-1]) | |||
if video_id_ not in video_id_list: | |||
continue | |||
file_path_ = os.path.join(root, video_file) | |||
video_dict[video_id_] = file_path_ | |||
self.video_dict = video_dict | |||
# Get all captions | |||
self.iter2video_pairs_dict = {} | |||
for clip_id, sentence in caption_dict.values(): | |||
if clip_id not in self.video_dict: | |||
continue | |||
self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (clip_id, sentence) | |||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||
def __len__(self): | |||
return len(self.iter2video_pairs_dict) | |||
    # Note: the helpers below appear to be unused leftovers from the ActivityNet loader.
    def _get_video_id_from_pseudo(self, pseudo_video_id):
        video_id = pseudo_video_id[2:]
        return video_id
def _get_video_id_single(self, path): | |||
pseudo_video_id_list = [] | |||
video_id_list = [] | |||
print('Loading json: {}'.format(path)) | |||
with open(path, 'r') as f: | |||
json_data = json.load(f) | |||
for pseudo_video_id in json_data: | |||
if pseudo_video_id in pseudo_video_id_list: | |||
print("reduplicate.") | |||
else: | |||
                video_id = self._get_video_id_from_pseudo(pseudo_video_id)
pseudo_video_id_list.append(pseudo_video_id) | |||
video_id_list.append(video_id) | |||
return pseudo_video_id_list, video_id_list | |||
def _get_captions_single(self, path): | |||
pseudo_caption_dict = {} | |||
with open(path, 'r') as f: | |||
json_data = json.load(f) | |||
for pseudo_video_id, v_ in json_data.items(): | |||
pseudo_caption_dict[pseudo_video_id] = {} | |||
timestamps = v_["timestamps"] | |||
pseudo_caption_dict[pseudo_video_id]["start"] = \ | |||
np.array([int(math.floor(float(itm[0]))) for itm in timestamps], dtype=object) | |||
pseudo_caption_dict[pseudo_video_id]["end"] = \ | |||
np.array([int(math.ceil(float(itm[1]))) for itm in timestamps], dtype=object) | |||
pseudo_caption_dict[pseudo_video_id]["text"] = np.array(v_["sentences"], dtype=object) | |||
return pseudo_caption_dict | |||
def _get_text(self, video_id, caption): | |||
k = 1 | |||
choice_video_ids = [video_id] | |||
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||
words = self.tokenizer.tokenize(caption) | |||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||
total_length_with_CLS = self.max_words - 1 | |||
if len(words) > total_length_with_CLS: | |||
words = words[:total_length_with_CLS] | |||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||
input_mask = [1] * len(input_ids) | |||
segment_ids = [0] * len(input_ids) | |||
while len(input_ids) < self.max_words: | |||
input_ids.append(0) | |||
input_mask.append(0) | |||
segment_ids.append(0) | |||
assert len(input_ids) == self.max_words | |||
assert len(input_mask) == self.max_words | |||
assert len(segment_ids) == self.max_words | |||
pairs_text[i] = np.array(input_ids) | |||
pairs_mask[i] = np.array(input_mask) | |||
pairs_segment[i] = np.array(segment_ids) | |||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||
def _get_rawvideo(self, choice_video_ids): | |||
        video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(choice_video_ids)
        # Pair x L x T x 3 x H x W
        video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
                          self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
try: | |||
for i, video_id in enumerate(choice_video_ids): | |||
video_path = self.video_dict[video_id] | |||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||
raw_video_data = raw_video_data['video'] | |||
if len(raw_video_data.shape) > 3: | |||
raw_video_data_clip = raw_video_data | |||
# L x T x 3 x H x W | |||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||
if self.max_frames < raw_video_slice.shape[0]: | |||
if self.slice_framepos == 0: | |||
video_slice = raw_video_slice[:self.max_frames, ...] | |||
elif self.slice_framepos == 1: | |||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||
else: | |||
sample_indx = np.linspace(0, raw_video_slice.shape[0]-1, num=self.max_frames, dtype=int) | |||
video_slice = raw_video_slice[sample_indx, ...] | |||
else: | |||
video_slice = raw_video_slice | |||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||
slice_len = video_slice.shape[0] | |||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||
if slice_len < 1: | |||
pass | |||
else: | |||
video[i][:slice_len, ...] = video_slice | |||
else: | |||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||
except Exception as excep: | |||
print("Video ids: {}".format(choice_video_ids)) | |||
raise excep | |||
for i, v_length in enumerate(max_video_length): | |||
video_mask[i][:v_length] = [1] * v_length | |||
return video, video_mask | |||
def __getitem__(self, feature_idx): | |||
clip_id, sentence = self.iter2video_pairs_dict[feature_idx] | |||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(clip_id, sentence) | |||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
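# Hedged sketch of parsing one LSMDC annotation line in the tab-separated
# format documented in __init__; the line content is invented for illustration.
if __name__ == "__main__":
    line = "1004_Juno\t10.5\t14.2\t10.0\t15.0\tSomeone walks into the room."
    clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = \
        line.strip().split("\t")
    print(clip_id, "->", sentence)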
@@ -0,0 +1,300 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import unicode_literals | |||
from __future__ import print_function | |||
import os | |||
from torch.utils.data import Dataset | |||
import numpy as np | |||
import pandas as pd | |||
from collections import defaultdict | |||
import json | |||
import random | |||
from dataloaders.rawvideo_util import RawVideoExtractor | |||
class MSRVTT_DataLoader(Dataset): | |||
"""MSRVTT dataset loader.""" | |||
def __init__( | |||
self, | |||
csv_path, | |||
features_path, | |||
tokenizer, | |||
max_words=30, | |||
feature_framerate=1.0, | |||
max_frames=100, | |||
image_resolution=224, | |||
frame_order=0, | |||
slice_framepos=0, | |||
): | |||
self.data = pd.read_csv(csv_path) | |||
self.features_path = features_path | |||
self.feature_framerate = feature_framerate | |||
self.max_words = max_words | |||
self.max_frames = max_frames | |||
self.tokenizer = tokenizer | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
self.frame_order = frame_order | |||
assert self.frame_order in [0, 1, 2] | |||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||
self.slice_framepos = slice_framepos | |||
assert self.slice_framepos in [0, 1, 2] | |||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||
def __len__(self): | |||
return len(self.data) | |||
def _get_text(self, video_id, sentence): | |||
choice_video_ids = [video_id] | |||
n_caption = len(choice_video_ids) | |||
k = n_caption | |||
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||
words = self.tokenizer.tokenize(sentence) | |||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||
total_length_with_CLS = self.max_words - 1 | |||
if len(words) > total_length_with_CLS: | |||
words = words[:total_length_with_CLS] | |||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||
input_mask = [1] * len(input_ids) | |||
segment_ids = [0] * len(input_ids) | |||
while len(input_ids) < self.max_words: | |||
input_ids.append(0) | |||
input_mask.append(0) | |||
segment_ids.append(0) | |||
assert len(input_ids) == self.max_words | |||
assert len(input_mask) == self.max_words | |||
assert len(segment_ids) == self.max_words | |||
pairs_text[i] = np.array(input_ids) | |||
pairs_mask[i] = np.array(input_mask) | |||
pairs_segment[i] = np.array(segment_ids) | |||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||
def _get_rawvideo(self, choice_video_ids): | |||
        video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(choice_video_ids)
        # Pair x L x T x 3 x H x W
        video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
                          self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
for i, video_id in enumerate(choice_video_ids): | |||
            # Fallback originally added for the YouCookII dataset, whose videos may be .webm rather than .mp4
video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) | |||
if os.path.exists(video_path) is False: | |||
video_path = video_path.replace(".mp4", ".webm") | |||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||
raw_video_data = raw_video_data['video'] | |||
if len(raw_video_data.shape) > 3: | |||
raw_video_data_clip = raw_video_data | |||
# L x T x 3 x H x W | |||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||
if self.max_frames < raw_video_slice.shape[0]: | |||
if self.slice_framepos == 0: | |||
video_slice = raw_video_slice[:self.max_frames, ...] | |||
elif self.slice_framepos == 1: | |||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||
else: | |||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||
video_slice = raw_video_slice[sample_indx, ...] | |||
else: | |||
video_slice = raw_video_slice | |||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||
slice_len = video_slice.shape[0] | |||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||
if slice_len < 1: | |||
pass | |||
else: | |||
video[i][:slice_len, ...] = video_slice | |||
else: | |||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||
for i, v_length in enumerate(max_video_length): | |||
video_mask[i][:v_length] = [1] * v_length | |||
return video, video_mask | |||
def __getitem__(self, idx): | |||
video_id = self.data['video_id'].values[idx] | |||
sentence = self.data['sentence'].values[idx] | |||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, sentence) | |||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||
return pairs_text, pairs_mask, pairs_segment, video, video_mask | |||
class MSRVTT_TrainDataLoader(Dataset): | |||
"""MSRVTT train dataset loader.""" | |||
def __init__( | |||
self, | |||
csv_path, | |||
json_path, | |||
features_path, | |||
tokenizer, | |||
max_words=30, | |||
feature_framerate=1.0, | |||
max_frames=100, | |||
unfold_sentences=False, | |||
image_resolution=224, | |||
frame_order=0, | |||
slice_framepos=0, | |||
): | |||
self.csv = pd.read_csv(csv_path) | |||
self.data = json.load(open(json_path, 'r')) | |||
self.features_path = features_path | |||
self.feature_framerate = feature_framerate | |||
self.max_words = max_words | |||
self.max_frames = max_frames | |||
self.tokenizer = tokenizer | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
self.frame_order = frame_order | |||
assert self.frame_order in [0, 1, 2] | |||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||
self.slice_framepos = slice_framepos | |||
assert self.slice_framepos in [0, 1, 2] | |||
self.unfold_sentences = unfold_sentences | |||
self.sample_len = 0 | |||
if self.unfold_sentences: | |||
train_video_ids = list(self.csv['video_id'].values) | |||
self.sentences_dict = {} | |||
for itm in self.data['sentences']: | |||
if itm['video_id'] in train_video_ids: | |||
self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption']) | |||
self.sample_len = len(self.sentences_dict) | |||
else: | |||
num_sentences = 0 | |||
self.sentences = defaultdict(list) | |||
s_video_id_set = set() | |||
for itm in self.data['sentences']: | |||
self.sentences[itm['video_id']].append(itm['caption']) | |||
num_sentences += 1 | |||
s_video_id_set.add(itm['video_id']) | |||
# Use to find the clips in the same video | |||
self.parent_ids = {} | |||
self.children_video_ids = defaultdict(list) | |||
for itm in self.data['videos']: | |||
vid = itm["video_id"] | |||
url_posfix = itm["url"].split("?v=")[-1] | |||
self.parent_ids[vid] = url_posfix | |||
self.children_video_ids[url_posfix].append(vid) | |||
self.sample_len = len(self.csv) | |||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||
def __len__(self): | |||
return self.sample_len | |||
def _get_text(self, video_id, caption=None): | |||
k = 1 | |||
choice_video_ids = [video_id] | |||
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||
if caption is not None: | |||
words = self.tokenizer.tokenize(caption) | |||
else: | |||
words = self._get_single_text(video_id) | |||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||
total_length_with_CLS = self.max_words - 1 | |||
if len(words) > total_length_with_CLS: | |||
words = words[:total_length_with_CLS] | |||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||
input_mask = [1] * len(input_ids) | |||
segment_ids = [0] * len(input_ids) | |||
while len(input_ids) < self.max_words: | |||
input_ids.append(0) | |||
input_mask.append(0) | |||
segment_ids.append(0) | |||
assert len(input_ids) == self.max_words | |||
assert len(input_mask) == self.max_words | |||
assert len(segment_ids) == self.max_words | |||
pairs_text[i] = np.array(input_ids) | |||
pairs_mask[i] = np.array(input_mask) | |||
pairs_segment[i] = np.array(segment_ids) | |||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||
def _get_single_text(self, video_id): | |||
rind = random.randint(0, len(self.sentences[video_id]) - 1) | |||
caption = self.sentences[video_id][rind] | |||
words = self.tokenizer.tokenize(caption) | |||
return words | |||
def _get_rawvideo(self, choice_video_ids): | |||
        video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(choice_video_ids)
        # Pair x L x T x 3 x H x W
        video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
                          self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
for i, video_id in enumerate(choice_video_ids): | |||
            # Fallback originally added for the YouCookII dataset, whose videos may be .webm rather than .mp4
video_path = os.path.join(self.features_path, "{}.mp4".format(video_id)) | |||
if os.path.exists(video_path) is False: | |||
video_path = video_path.replace(".mp4", ".webm") | |||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||
raw_video_data = raw_video_data['video'] | |||
if len(raw_video_data.shape) > 3: | |||
raw_video_data_clip = raw_video_data | |||
# L x T x 3 x H x W | |||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||
if self.max_frames < raw_video_slice.shape[0]: | |||
if self.slice_framepos == 0: | |||
video_slice = raw_video_slice[:self.max_frames, ...] | |||
elif self.slice_framepos == 1: | |||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||
else: | |||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||
video_slice = raw_video_slice[sample_indx, ...] | |||
else: | |||
video_slice = raw_video_slice | |||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||
slice_len = video_slice.shape[0] | |||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||
if slice_len < 1: | |||
pass | |||
else: | |||
video[i][:slice_len, ...] = video_slice | |||
else: | |||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||
for i, v_length in enumerate(max_video_length): | |||
video_mask[i][:v_length] = [1] * v_length | |||
return video, video_mask | |||
def __getitem__(self, idx): | |||
if self.unfold_sentences: | |||
video_id, caption = self.sentences_dict[idx] | |||
else: | |||
video_id, caption = self.csv['video_id'].values[idx], None | |||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) | |||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
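# Hedged illustration of unfold_sentences: when enabled, every
# (video_id, caption) pair becomes its own training sample instead of one
# sample per video. Toy data, invented purely for illustration.
if __name__ == "__main__":
    toy_sentences = [
        {"video_id": "video0", "caption": "a cat jumps"},
        {"video_id": "video0", "caption": "a kitten leaps"},
        {"video_id": "video1", "caption": "a man cooks"},
    ]
    train_video_ids = ["video0", "video1"]
    sentences_dict = {}
    for itm in toy_sentences:
        if itm["video_id"] in train_video_ids:
            sentences_dict[len(sentences_dict)] = (itm["video_id"], itm["caption"])
    print(len(sentences_dict))  # -> 3 samples, versus 2 when folding by video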
@@ -0,0 +1,180 @@ | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import unicode_literals | |||
from __future__ import print_function | |||
import os | |||
from torch.utils.data import Dataset | |||
import numpy as np | |||
import pickle | |||
from dataloaders.rawvideo_util import RawVideoExtractor | |||
class MSVD_DataLoader(Dataset): | |||
"""MSVD dataset loader.""" | |||
def __init__( | |||
self, | |||
subset, | |||
data_path, | |||
features_path, | |||
tokenizer, | |||
max_words=30, | |||
feature_framerate=1.0, | |||
max_frames=100, | |||
image_resolution=224, | |||
frame_order=0, | |||
slice_framepos=0, | |||
): | |||
self.data_path = data_path | |||
self.features_path = features_path | |||
self.feature_framerate = feature_framerate | |||
self.max_words = max_words | |||
self.max_frames = max_frames | |||
self.tokenizer = tokenizer | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
self.frame_order = frame_order | |||
assert self.frame_order in [0, 1, 2] | |||
# 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly. | |||
self.slice_framepos = slice_framepos | |||
assert self.slice_framepos in [0, 1, 2] | |||
self.subset = subset | |||
assert self.subset in ["train", "val", "test"] | |||
video_id_path_dict = {} | |||
video_id_path_dict["train"] = os.path.join(self.data_path, "train_list.txt") | |||
video_id_path_dict["val"] = os.path.join(self.data_path, "val_list.txt") | |||
video_id_path_dict["test"] = os.path.join(self.data_path, "test_list.txt") | |||
caption_file = os.path.join(self.data_path, "raw-captions.pkl") | |||
with open(video_id_path_dict[self.subset], 'r') as fp: | |||
video_ids = [itm.strip() for itm in fp.readlines()] | |||
with open(caption_file, 'rb') as f: | |||
captions = pickle.load(f) | |||
video_dict = {} | |||
for root, dub_dir, video_files in os.walk(self.features_path): | |||
for video_file in video_files: | |||
video_id_ = ".".join(video_file.split(".")[:-1]) | |||
if video_id_ not in video_ids: | |||
continue | |||
file_path_ = os.path.join(root, video_file) | |||
video_dict[video_id_] = file_path_ | |||
self.video_dict = video_dict | |||
self.sample_len = 0 | |||
self.sentences_dict = {} | |||
self.cut_off_points = [] | |||
for video_id in video_ids: | |||
assert video_id in captions | |||
for cap in captions[video_id]: | |||
cap_txt = " ".join(cap) | |||
self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) | |||
self.cut_off_points.append(len(self.sentences_dict)) | |||
        # The variables below support multi-sentence retrieval:
        #   self.cut_off_points: marks where each video's caption block ends (used when computing metrics)
        #   self.sentence_num: total sentence count, used to slice the sentence representations
        #   self.video_num: total video count, used to slice the video representations
        self.multi_sentence_per_video = True    # important flag for evaluation
if self.subset == "val" or self.subset == "test": | |||
self.sentence_num = len(self.sentences_dict) | |||
self.video_num = len(video_ids) | |||
assert len(self.cut_off_points) == self.video_num | |||
print("For {}, sentence number: {}".format(self.subset, self.sentence_num)) | |||
print("For {}, video number: {}".format(self.subset, self.video_num)) | |||
print("Video number: {}".format(len(self.video_dict))) | |||
print("Total Paire: {}".format(len(self.sentences_dict))) | |||
self.sample_len = len(self.sentences_dict) | |||
self.rawVideoExtractor = RawVideoExtractor(framerate=feature_framerate, size=image_resolution) | |||
self.SPECIAL_TOKEN = {"CLS_TOKEN": "<|startoftext|>", "SEP_TOKEN": "<|endoftext|>", | |||
"MASK_TOKEN": "[MASK]", "UNK_TOKEN": "[UNK]", "PAD_TOKEN": "[PAD]"} | |||
def __len__(self): | |||
return self.sample_len | |||
def _get_text(self, video_id, caption): | |||
k = 1 | |||
choice_video_ids = [video_id] | |||
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
for i, video_id in enumerate(choice_video_ids): | |||
words = self.tokenizer.tokenize(caption) | |||
words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words | |||
total_length_with_CLS = self.max_words - 1 | |||
if len(words) > total_length_with_CLS: | |||
words = words[:total_length_with_CLS] | |||
words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]] | |||
input_ids = self.tokenizer.convert_tokens_to_ids(words) | |||
input_mask = [1] * len(input_ids) | |||
segment_ids = [0] * len(input_ids) | |||
while len(input_ids) < self.max_words: | |||
input_ids.append(0) | |||
input_mask.append(0) | |||
segment_ids.append(0) | |||
assert len(input_ids) == self.max_words | |||
assert len(input_mask) == self.max_words | |||
assert len(segment_ids) == self.max_words | |||
pairs_text[i] = np.array(input_ids) | |||
pairs_mask[i] = np.array(input_mask) | |||
pairs_segment[i] = np.array(segment_ids) | |||
return pairs_text, pairs_mask, pairs_segment, choice_video_ids | |||
def _get_rawvideo(self, choice_video_ids): | |||
        video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(choice_video_ids)
        # Pair x L x T x 3 x H x W
        video = np.zeros((len(choice_video_ids), self.max_frames, 1, 3,
                          self.rawVideoExtractor.size, self.rawVideoExtractor.size), dtype=np.float64)
for i, video_id in enumerate(choice_video_ids): | |||
video_path = self.video_dict[video_id] | |||
raw_video_data = self.rawVideoExtractor.get_video_data(video_path) | |||
raw_video_data = raw_video_data['video'] | |||
if len(raw_video_data.shape) > 3: | |||
raw_video_data_clip = raw_video_data | |||
# L x T x 3 x H x W | |||
raw_video_slice = self.rawVideoExtractor.process_raw_data(raw_video_data_clip) | |||
if self.max_frames < raw_video_slice.shape[0]: | |||
if self.slice_framepos == 0: | |||
video_slice = raw_video_slice[:self.max_frames, ...] | |||
elif self.slice_framepos == 1: | |||
video_slice = raw_video_slice[-self.max_frames:, ...] | |||
else: | |||
sample_indx = np.linspace(0, raw_video_slice.shape[0] - 1, num=self.max_frames, dtype=int) | |||
video_slice = raw_video_slice[sample_indx, ...] | |||
else: | |||
video_slice = raw_video_slice | |||
video_slice = self.rawVideoExtractor.process_frame_order(video_slice, frame_order=self.frame_order) | |||
slice_len = video_slice.shape[0] | |||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_len else slice_len | |||
if slice_len < 1: | |||
pass | |||
else: | |||
video[i][:slice_len, ...] = video_slice | |||
else: | |||
print("video path: {} error. video id: {}".format(video_path, video_id)) | |||
for i, v_length in enumerate(max_video_length): | |||
video_mask[i][:v_length] = [1] * v_length | |||
return video, video_mask | |||
def __getitem__(self, idx): | |||
video_id, caption = self.sentences_dict[idx] | |||
pairs_text, pairs_mask, pairs_segment, choice_video_ids = self._get_text(video_id, caption) | |||
video, video_mask = self._get_rawvideo(choice_video_ids) | |||
return pairs_text, pairs_mask, pairs_segment, video, video_mask |
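# Hedged illustration of cut_off_points: each entry marks where one video's
# caption block ends in the flattened sentence list, letting the evaluation
# code regroup sentence representations per video. Toy counts, invented for
# illustration.
if __name__ == "__main__":
    captions_per_video = {"vid1": 2, "vid2": 3}
    cut_off_points, total = [], 0
    for _, n_caps in captions_per_video.items():
        total += n_caps
        cut_off_points.append(total)
    print(cut_off_points)  # -> [2, 5]: vid1 owns sentences [0, 2), vid2 owns [2, 5)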
@@ -0,0 +1,106 @@ | |||
import torch as th | |||
import numpy as np | |||
from PIL import Image | |||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |||
import cv2 | |||
class RawVideoExtractorCV2(): | |||
    def __init__(self, centercrop=False, size=224, framerate=-1):
self.centercrop = centercrop | |||
self.size = size | |||
self.framerate = framerate | |||
self.transform = self._transform(self.size) | |||
def _transform(self, n_px): | |||
return Compose([ | |||
Resize(n_px, interpolation=Image.BICUBIC), | |||
CenterCrop(n_px), | |||
lambda image: image.convert("RGB"), | |||
ToTensor(), | |||
Normalize((0.48145466, 0.4578275, 0.40821073), | |||
(0.26862954, 0.26130258, 0.27577711)), | |||
]) | |||
def video_to_tensor(self, video_file, preprocess, sample_fp=0, start_time=None, end_time=None): | |||
if start_time is not None or end_time is not None: | |||
assert isinstance(start_time, int) and isinstance(end_time, int) \ | |||
and start_time > -1 and end_time > start_time | |||
assert sample_fp > -1 | |||
        # Sample sample_fp frames per second; sample_fp == 0 keeps the native framerate.
cap = cv2.VideoCapture(video_file) | |||
frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |||
fps = int(cap.get(cv2.CAP_PROP_FPS)) | |||
total_duration = (frameCount + fps - 1) // fps | |||
start_sec, end_sec = 0, total_duration | |||
if start_time is not None: | |||
start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration | |||
cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps)) | |||
interval = 1 | |||
if sample_fp > 0: | |||
interval = fps // sample_fp | |||
else: | |||
sample_fp = fps | |||
if interval == 0: | |||
interval = 1 | |||
inds = [ind for ind in np.arange(0, fps, interval)] | |||
assert len(inds) >= sample_fp | |||
inds = inds[:sample_fp] | |||
ret = True | |||
images, included = [], [] | |||
for sec in np.arange(start_sec, end_sec + 1): | |||
if not ret: | |||
break | |||
sec_base = int(sec * fps) | |||
for ind in inds: | |||
cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind) | |||
ret, frame = cap.read() | |||
if not ret: | |||
break | |||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |||
images.append(preprocess( | |||
Image.fromarray(frame_rgb).convert("RGB"))) | |||
cap.release() | |||
if len(images) > 0: | |||
video_data = th.tensor(np.stack(images)) | |||
else: | |||
video_data = th.zeros(1) | |||
return {'video': video_data} | |||
def get_video_data(self, video_path, start_time=None, end_time=None): | |||
image_input = self.video_to_tensor( | |||
video_path, self.transform, sample_fp=self.framerate, start_time=start_time, end_time=end_time) | |||
return image_input | |||
def process_raw_data(self, raw_video_data): | |||
tensor_size = raw_video_data.size() | |||
tensor = raw_video_data.view(-1, 1, | |||
tensor_size[-3], tensor_size[-2], tensor_size[-1]) | |||
return tensor | |||
def process_frame_order(self, raw_video_data, frame_order=0): | |||
# 0: ordinary order; 1: reverse order; 2: random order. | |||
if frame_order == 0: | |||
pass | |||
elif frame_order == 1: | |||
reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1) | |||
raw_video_data = raw_video_data[reverse_order, ...] | |||
elif frame_order == 2: | |||
random_order = np.arange(raw_video_data.size(0)) | |||
np.random.shuffle(random_order) | |||
raw_video_data = raw_video_data[random_order, ...] | |||
return raw_video_data | |||
# Default raw video frame extractor, based on OpenCV (cv2).
RawVideoExtractor = RawVideoExtractorCV2 |
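# Hedged usage sketch: extract frames at 1 fps and reshape them with
# process_raw_data. "example.mp4" is a placeholder path; the demo only runs
# if such a file exists.
if __name__ == "__main__":
    import os
    path = "example.mp4"
    if os.path.exists(path):
        extractor = RawVideoExtractor(framerate=1, size=224)
        data = extractor.get_video_data(path)
        if data["video"].dim() > 3:
            clip = extractor.process_raw_data(data["video"])  # L x 1 x 3 x H x W
            print(clip.shape)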
@@ -0,0 +1,104 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"!pip install transformers\n", | |||
"!pip install pytube" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"!git clone https://github.com/abreza/clip-video-embedder\n", | |||
"%cd clip-video-embedder/" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"from utils.video_loader import download_video_from_youtube\n", | |||
"\n", | |||
"video_id = \"4uw4co69JUQ\"\n", | |||
"video_path = download_video_from_youtube('4uw4co69JUQ', './videos')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"from transformers import CLIPModel, CLIPProcessor\n", | |||
"\n", | |||
"model = CLIPModel.from_pretrained(\"clip-vit-large-patch14\")\n", | |||
"processor = CLIPProcessor.from_pretrained(\"clip-vit-large-patch14\")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"import numpy as np\n", | |||
"\n", | |||
"from modules.frame_sampler.uniform_sampler import uniform_sample_frames\n", | |||
"from modules.frame_sampler.aggregate_embedding import aggregate_embeddings\n", | |||
"\n", | |||
"num_frames = 10\n", | |||
"video_frames = uniform_sample_frames(video_path, num_frames)\n", | |||
"\n", | |||
"frame_embeddings = []\n", | |||
"for frame in video_frames:\n", | |||
" inputs = processor(images=frame, return_tensors=\"pt\")\n", | |||
" outputs = model.get_image_features(**inputs)\n", | |||
" frame_embedding = outputs.detach().numpy()\n", | |||
" frame_embeddings.append(frame_embedding)\n", | |||
"\n", | |||
"video_embedding = aggregate_embeddings(frame_embeddings, strategy=\"mean\")\n", | |||
"\n", | |||
"# Compare the video embedding to a given text\n", | |||
"texts = [\"a dog playing in the park\", \"playing basketball\"]\n", | |||
"for text in texts:\n", | |||
" text_inputs = processor(text=text, return_tensors=\"pt\", padding=True)\n", | |||
" text_outputs = model.get_text_features(**text_inputs)\n", | |||
" text_embedding = text_outputs.detach().numpy()\n", | |||
" \n", | |||
" similarity = np.inner(video_embedding, text_embedding)\n", | |||
" print(text, similarity)" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "venv", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.10.10" | |||
}, | |||
"orig_nbformat": 4 | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 2 | |||
} |
@@ -0,0 +1,31 @@ | |||
from utils.dotdict import DotNotationDict | |||
from dataloaders.data_dataloaders import DATALOADER_DICT | |||
from transformers import CLIPTokenizer | |||
def main(): | |||
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
train_args = DotNotationDict({ | |||
"train_csv": 'MSRVTT_train.9k.csv', | |||
"data_path": 'MSRVTT_data.json', | |||
"features_path": 'MSRVTT_Videos', | |||
"max_words": 32, | |||
"feature_framerate": 1, | |||
"max_frames": 100, | |||
"expand_msrvtt_sentences": True, | |||
"train_frame_order": 0, | |||
"slice_framepos": 2, | |||
"batch_size": 128, | |||
"n_gpu": 1, | |||
"num_thread_reader": 0, | |||
}) | |||
train_dataloader, train_length = DATALOADER_DICT['msrvtt']["train"]( | |||
train_args, tokenizer) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,46 @@ | |||
import numpy as np | |||
from transformers import CLIPProcessor, CLIPModel | |||
class VideoEmbedder: | |||
def __init__(self, model_name="openai/clip-vit-large-patch14"): | |||
self.model = CLIPModel.from_pretrained(model_name) | |||
self.processor = CLIPProcessor.from_pretrained(model_name) | |||
def uniform_sampling(self, frames, num_samples): | |||
frame_count = len(frames) | |||
indices = np.linspace(0, frame_count - 1, num_samples, dtype=int) | |||
sampled_frames = [frames[i] for i in indices] | |||
return sampled_frames | |||
def get_frame_embeddings(self, frames): | |||
inputs = self.processor(images=frames, return_tensors="pt") | |||
outputs = self.model.get_image_features(**inputs) | |||
frame_embeddings = outputs.detach().numpy() | |||
return frame_embeddings | |||
def aggregate_embeddings(self, embeddings, strategy="mean"): | |||
if strategy == "mean": | |||
return np.mean(embeddings, axis=0) | |||
elif strategy == "average": | |||
return np.average(embeddings, axis=0) | |||
else: | |||
raise ValueError("Invalid aggregation strategy") | |||
def embed_video(self, frames, num_samples, aggregation_strategy="mean"): | |||
sampled_frames = self.uniform_sampling(frames, num_samples) | |||
frame_embeddings = self.get_frame_embeddings(sampled_frames) | |||
aggregated_embedding = self.aggregate_embeddings( | |||
frame_embeddings, aggregation_strategy) | |||
return aggregated_embedding | |||
def get_text_embedding(self, text): | |||
text_inputs = self.processor( | |||
text=text, return_tensors="pt", padding=True) | |||
text_outputs = self.model.get_text_features(**text_inputs) | |||
text_embedding = text_outputs.detach().numpy() | |||
return text_embedding | |||
def similarity(self, video_embedding, text_embedding): | |||
similarity_score = np.dot(video_embedding, text_embedding.T) | |||
return similarity_score |
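# Hedged usage sketch, assuming `frames` is a list of PIL images; the dummy
# frames below are blank and exist only to make the example self-contained.
# The first run downloads the CLIP weights.
if __name__ == "__main__":
    from PIL import Image

    embedder = VideoEmbedder()
    frames = [Image.new("RGB", (224, 224)) for _ in range(30)]  # dummy frames
    video_emb = embedder.embed_video(frames, num_samples=10)
    text_emb = embedder.get_text_embedding("a dog playing in the park")
    print(embedder.similarity(video_emb, text_emb))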
@@ -0,0 +1,8 @@ | |||
import numpy as np | |||
def aggregate_embeddings(embeddings, strategy="mean"): | |||
if strategy == "mean": | |||
return np.mean(embeddings, axis=0) | |||
else: | |||
raise ValueError(f"Unknown aggregation strategy: {strategy}") |
@@ -0,0 +1,41 @@ | |||
import timm
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel


class SaliencyNet(nn.Module):
    """Student: predicts a per-frame saliency score from pixels alone."""

    def __init__(self):
        super(SaliencyNet, self).__init__()
        # num_classes=0 makes timm's forward return pooled features of shape
        # (B, num_features) rather than an unpooled 4-D feature map, which
        # would not fit the linear head below.
        self.model = timm.create_model('resnet50', pretrained=True, num_classes=0)
        self.fc1 = nn.Linear(self.model.num_features, 512)
        self.fc2 = nn.Linear(512, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, images):
        features = self.model(images)  # (B, 2048) pooled backbone features
        x = self.fc1(features)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x


class CLIPTeacher(nn.Module):
    """Teacher: scores frames by their CLIP similarity to the descriptions."""

    def __init__(self):
        super(CLIPTeacher, self).__init__()
        self.model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14')
        self.processor = CLIPProcessor.from_pretrained(
            'openai/clip-vit-large-patch14')

    def forward(self, images, descriptions):
        inputs = self.processor(
            text=descriptions, images=images, return_tensors="pt", padding=True
        )
        # The processor returns CPU tensors; move them to the model's device.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        # logits_per_image: (num_images, num_texts); average over descriptions
        # to get one relevance score per frame.
        logits_per_image = outputs.logits_per_image
        average_scores = logits_per_image.mean(dim=1)
        return average_scores
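

# Distillation sketch (hypothetical wiring): the frozen CLIPTeacher scores how
# well each frame matches the text descriptions, and SaliencyNet learns to
# reproduce those scores from pixels alone, e.g.:
#
#   student, teacher = SaliencyNet(), CLIPTeacher()
#   scores = student(torch.randn(8, 3, 224, 224))  # (8, 1), values in (0, 1)
#
# Note the scale mismatch: CLIP logits are unbounded while the student emits
# sigmoid outputs, so teacher scores may need normalizing before the MSE loss.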
@@ -0,0 +1,53 @@ | |||
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_salient_frame_sampler(teacher, student, train_dataloader: DataLoader, val_dataloader: DataLoader, epochs: int, optimizer, device):
    # The teacher is frozen; it only supplies target scores.
    teacher.eval()
    teacher.to(device)
    student.to(device)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        student.train()
        running_loss = 0.0
        for frames, descriptions in train_dataloader:
            # Descriptions are raw strings consumed by the CLIP processor,
            # so only the frame tensors are moved to the device.
            frames = frames.to(device)
            student_scores = student(frames).squeeze()
            with torch.no_grad():
                teacher_scores = teacher(frames, descriptions)
            optimizer.zero_grad()
            loss = criterion(student_scores, teacher_scores)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_dataloader)
        print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {train_loss:.4f}")
        # Validation pass: no gradients, student in eval mode.
        student.eval()
        with torch.no_grad():
            running_val_loss = 0.0
            for val_frames, val_descriptions in val_dataloader:
                val_frames = val_frames.to(device)
                val_student_scores = student(val_frames).squeeze()
                val_teacher_scores = teacher(val_frames, val_descriptions)
                val_loss = criterion(val_student_scores, val_teacher_scores)
                running_val_loss += val_loss.item()
        val_loss = running_val_loss / len(val_dataloader)
        print(
            f"Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss:.4f}")
@@ -0,0 +1,20 @@ | |||
import cv2
import numpy as np
from PIL import Image


def uniform_sample_frames(video_path, num_frames):
    """Decode `num_frames` evenly spaced RGB frames as PIL images."""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_idxs = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = []
    for idx in frame_idxs:
        # Seek directly to the target frame instead of decoding sequentially.
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB for PIL/CLIP.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
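

# Usage sketch (path is illustrative):
#   frames = uniform_sample_frames("/content/example.mp4", num_frames=8)
#   # -> list of up to 8 PIL RGB images, evenly spaced across the video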
@@ -0,0 +1,193 @@ | |||
{ | |||
"nbformat": 4, | |||
"nbformat_minor": 0, | |||
"metadata": { | |||
"colab": { | |||
"provenance": [] | |||
}, | |||
"kernelspec": { | |||
"name": "python3", | |||
"display_name": "Python 3" | |||
}, | |||
"language_info": { | |||
"name": "python" | |||
} | |||
}, | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "QGkNRWwCX4Z4" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"!wget https://storage.googleapis.com/deepmind-media/Datasets/kinetics700_2020.tar.gz" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"!tar -xf kinetics700_2020.tar.gz" | |||
], | |||
"metadata": { | |||
"id": "87GXSaprYGxe" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"import pandas as pd\n", | |||
"\n", | |||
"df = pd.read_csv('/content/kinetics700_2020/train.csv')\n", | |||
"\n", | |||
"unique_labels = df['label'].unique()\n", | |||
"\n", | |||
"for i, label in enumerate(unique_labels):\n", | |||
" print(f\"{i}- {label}\")" | |||
], | |||
"metadata": { | |||
"id": "wJ4USlBtZWi9" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"def merge_files(real_labels_file, gpt_labels_file):\n", | |||
" with open(real_labels_file, 'r') as real_file, open(gpt_labels_file, 'r') as gpt_file:\n", | |||
" real_labels = real_file.readlines()\n", | |||
" gpt_labels = gpt_file.readlines()\n", | |||
"\n", | |||
" merged_labels = []\n", | |||
" for real, gpt in zip(real_labels, gpt_labels):\n", | |||
" real_label = real.split('-')[1].strip()\n", | |||
" gpt_label = gpt.split('-')[1].strip()\n", | |||
"\n", | |||
" # Merge the labels and add to the list\n", | |||
" merged_labels.append(f'{real_label}: {gpt_label}')\n", | |||
"\n", | |||
" return merged_labels\n", | |||
"\n", | |||
"real_labels_file = '/content/real-labels.txt'\n", | |||
"gpt_labels_file = '/content/chatgpt-labels.txt'\n", | |||
"\n", | |||
"merged_labels = merge_files(real_labels_file, gpt_labels_file)" | |||
], | |||
"metadata": { | |||
"id": "CUMXgnGkDXXB" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"!pip install transformers" | |||
], | |||
"metadata": { | |||
"id": "gvWmOILaIUI1" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"from transformers import CLIPProcessor, CLIPModel\n", | |||
"\n", | |||
"model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n", | |||
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")" | |||
], | |||
"metadata": { | |||
"id": "gPfdGYiHIMss" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"with open('enhanced-kinetics-label-CLIP-embedding.txt', 'w') as write_file:\n", | |||
" for label in merged_labels:\n", | |||
" text_inputs = processor(text=label, return_tensors=\"pt\", padding=True)\n", | |||
" text_features = model.get_text_features(**text_inputs).float()\n", | |||
" # format output to only have 4 digits after decimal point\n", | |||
" formatted_features = ['{:.4f}'.format(val) for val in text_features[0].tolist()]\n", | |||
" write_file.write(label.replace(':', ' |') + ' | ' + str(formatted_features) + '\\n')" | |||
], | |||
"metadata": { | |||
"id": "x84uh8PdIdXI" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"!wget https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip --no-check-certificate\n", | |||
"!unzip UCF101TrainTestSplits-RecognitionTask.zip" | |||
], | |||
"metadata": { | |||
"colab": { | |||
"base_uri": "https://localhost:8080/" | |||
}, | |||
"id": "jbg7PJSXO9so", | |||
"outputId": "f2e84ec0-37ae-49c3-a732-d35529d38eed" | |||
}, | |||
"execution_count": null, | |||
"outputs": [ | |||
{ | |||
"output_type": "stream", | |||
"name": "stdout", | |||
"text": [ | |||
"--2023-08-02 13:01:00-- https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip\n", | |||
"Resolving www.crcv.ucf.edu (www.crcv.ucf.edu)... 132.170.214.127\n", | |||
"Connecting to www.crcv.ucf.edu (www.crcv.ucf.edu)|132.170.214.127|:443... connected.\n", | |||
"WARNING: cannot verify www.crcv.ucf.edu's certificate, issued by ‘CN=InCommon RSA Server CA,OU=InCommon,O=Internet2,L=Ann Arbor,ST=MI,C=US’:\n", | |||
" Unable to locally verify the issuer's authority.\n", | |||
"HTTP request sent, awaiting response... 200 OK\n", | |||
"Length: 113943 (111K) [application/zip]\n", | |||
"Saving to: ‘UCF101TrainTestSplits-RecognitionTask.zip’\n", | |||
"\n", | |||
"\r UCF101Tra 0%[ ] 0 --.-KB/s \rUCF101TrainTestSpli 100%[===================>] 111.27K --.-KB/s in 0.05s \n", | |||
"\n", | |||
"2023-08-02 13:01:00 (2.03 MB/s) - ‘UCF101TrainTestSplits-RecognitionTask.zip’ saved [113943/113943]\n", | |||
"\n", | |||
"Archive: UCF101TrainTestSplits-RecognitionTask.zip\n", | |||
" creating: ucfTrainTestlist/\n", | |||
" inflating: ucfTrainTestlist/classInd.txt \n", | |||
" inflating: ucfTrainTestlist/testlist01.txt \n", | |||
" inflating: ucfTrainTestlist/testlist02.txt \n", | |||
" inflating: ucfTrainTestlist/testlist03.txt \n", | |||
" inflating: ucfTrainTestlist/trainlist01.txt \n", | |||
" inflating: ucfTrainTestlist/trainlist02.txt \n", | |||
" inflating: ucfTrainTestlist/trainlist03.txt \n" | |||
] | |||
} | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"source": [ | |||
"with open('enhanced-ucf-101-label-CLIP-embedding.txt', 'w') as write_file:\n", | |||
" with open('ucf101-labels.txt', 'r') as ucf101_labels_file:\n", | |||
" ucf101_labels = ucf101_labels_file.readlines()\n", | |||
" for label in ucf101_labels:\n", | |||
" text_inputs = processor(text=label, return_tensors=\"pt\", padding=True)\n", | |||
" text_features = model.get_text_features(**text_inputs).float()\n", | |||
" # format output to only have 4 digits after decimal point\n", | |||
" formatted_features = ['{:.4f}'.format(val) for val in text_features[0].tolist()]\n", | |||
" write_file.write(label.replace(':', ' |') + ' | ' + str(formatted_features) + '\\n')" | |||
], | |||
"metadata": { | |||
"id": "jAQgg2nMgIE6" | |||
}, | |||
"execution_count": null, | |||
"outputs": [] | |||
} | |||
] | |||
} |
@@ -0,0 +1,7 @@ | |||
pytube | |||
timm | |||
transformers | |||
torchvision | |||
@@ -0,0 +1,76 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"background_save": true | |||
}, | |||
"id": "ucqaqtLhPY6T", | |||
"outputId": "ceea603c-8b61-4c19-cdc3-1de8062ecfc4" | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Mounted at /content/drive\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from google.colab import drive\n", | |||
"drive.mount('/content/drive')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"id": "BBOM9m-EQ7so" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"# !pip install opendatalab\n", | |||
"# !odl login\n", | |||
"# %cd /content/drive/MyDrive/Kinetics\n", | |||
"# !odl get Kinetics_700-2020" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": { | |||
"colab": { | |||
"background_save": true | |||
}, | |||
"id": "MTa17FVePkqM" | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"%cd /content/\n", | |||
"\n", | |||
"!git clone https://github.com/cvdfoundation/kinetics-dataset\n", | |||
"\n", | |||
"%cd /content/drive/MyDrive/Kinetics\n", | |||
"!bash /content/kinetics-dataset/k700_2020_downloader.sh\n", | |||
"!bash /content/kinetics-dataset/k700_2020_extractor.sh" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
"colab": { | |||
"provenance": [] | |||
}, | |||
"kernelspec": { | |||
"display_name": "Python 3", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"name": "python" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 0 | |||
} |
@@ -0,0 +1,22 @@ | |||
from datasets.msrvtt.dataset import MSRVTTDataset | |||
from modules.clip_video_embedder import VideoEmbedder | |||
def main(): | |||
dataset = MSRVTTDataset() | |||
video_embedder = VideoEmbedder() | |||
for idx in range(len(dataset)): | |||
frames, captions = dataset[idx] | |||
video_embedding = video_embedder.embed_video(frames, num_samples=10) | |||
for caption in captions: | |||
text_embedding = video_embedder.get_text_embedding(caption) | |||
similarity_score = video_embedder.similarity( | |||
video_embedding, text_embedding) | |||
print( | |||
f"Similarity between video and caption '{caption}': {similarity_score}") | |||
if __name__ == "__main__": | |||
main() |
@@ -0,0 +1,11 @@ | |||
class DotNotationDict(dict): | |||
"""Enables dot notation access to dictionary attributes.""" | |||
    def __getattr__(self, key):
        return self.get(key)
def __setattr__(self, key, value): | |||
self.__setitem__(key, value) | |||
def __delattr__(self, key): | |||
self.__delitem__(key) |
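

# Usage sketch:
#   args = DotNotationDict({"batch_size": 128})
#   args.batch_size  # -> 128, via attribute access
#   args.lr = 1e-4   # equivalent to args["lr"] = 1e-4
#   args.missing     # -> None rather than KeyError, since get() is used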
@@ -0,0 +1,8 @@ | |||
import torch | |||
def uniform_sample_frames(frames, num_samples):
    # linspace spans the clip evenly and, unlike integer-step slicing,
    # never produces a zero step when the clip has fewer frames than
    # num_samples.
    total_frames = frames.shape[0]
    indices = torch.linspace(0, total_frames - 1, num_samples).long()
    return frames[indices]
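

# Usage sketch: pick 8 evenly spaced frames from a (T, C, H, W) clip tensor.
#   clip = uniform_sample_frames(torch.randn(100, 3, 224, 224), num_samples=8)
#   clip.shape  # -> torch.Size([8, 3, 224, 224])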
@@ -0,0 +1,21 @@ | |||
import os | |||
import pytube | |||
def download_video_from_youtube(video_id_or_url, destination_path=None, video_name=None):
    # Accept either a bare video id or a full URL.
    video_url = video_id_or_url if "://" in video_id_or_url else f'https://youtu.be/{video_id_or_url}'
    video_id = pytube.extract.video_id(video_url)
    destination_path = destination_path if destination_path else '/content/'
    video_name = video_name if video_name else f'{video_id}.mp4'
    video_path = os.path.join(destination_path, video_name)
    # Skip the download if the file is already cached locally.
    if os.path.isfile(video_path):
        return video_path
    print(f'Downloading YouTube video {video_id}.')
    pytube.YouTube(video_url).streams.get_highest_resolution().download(
        destination_path, filename=video_name)
    print('Download complete.')
    return video_path
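

# Usage sketch (video id and path are illustrative):
#   path = download_video_from_youtube("abc123XYZ_w", destination_path="/content/videos")
#   # A second call returns the cached file path without re-downloading.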