In [None]:
# Mount Google Drive (Colab only). The training loop later in this notebook
# saves checkpoints under /content/drive/MyDrive, so the mount must exist first.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install decord deepspeed einops timm==0.4.12 tensorboardX mpi4py

In [None]:
import os
import sys
import json
import warnings
import math
import argparse
import logging
import random
import gc
import tqdm

from collections import OrderedDict

import numpy as np
import pandas as pd
import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.utils.checkpoint as cp
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision import transforms
from timm.models import create_model
from timm.models.layers import trunc_normal_
from timm.models.registry import register_model
from timm.loss import SoftTargetCrossEntropy
from functools import partial
from datetime import datetime

[2023-08-11 09:16:10,313] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [None]:
%cd /content
!git clone https://github.com/OpenGVLab/VideoMAEv2

/content
fatal: destination path 'VideoMAEv2' already exists and is not an empty directory.


In [None]:
%cd /content/VideoMAEv2
from dataset import video_transforms
from dataset.loader import get_video_loader
from dataset.random_erasing import RandomErasing
from dataset.datasets import spatial_sampling, tensor_normalize
from utils import (
    load_state_dict,
    multiple_samples_collate,
)
from models.modeling_finetune import PatchEmbed, Block
from optim_factory import (
    LayerDecayValueAssigner,
    get_parameter_groups,
)
%cd /content/

/content/VideoMAEv2
/content


In [None]:
# Fetch the pretrained checkpoint from Google Drive by file id, skipping the
# download when the file is already present.
# NOTE(review): the existence check is relative to the current directory
# (/content after the previous cell's `%cd /content/`), whereas args.finetune
# uses the absolute path /content/vit_g_hybrid_pt_1200e.pth; presumably the
# gdown file id resolves to that filename — confirm.
if not os.path.exists("vit_g_hybrid_pt_1200e.pth"):
    !gdown 1fqbZe2BHz2W6WDQnXIpZkEQVuIiiaf4-

In [None]:
class VideoClsBaseDataset(Dataset):
    """Base video-classification dataset driven by a space-delimited annotation file.

    Each row of ``anno_path`` (no header) holds ``<video path> <label>``; the
    path is joined onto ``data_root``. Subclasses are expected to override
    :meth:`load_video` so it returns a frame buffer that is empty on failure;
    :meth:`load_and_check_video` relies on that contract.
    """

    def __init__(self,
                 anno_path,
                 data_root='',
                 clip_len=8,
                 frame_sample_rate=2,
                 short_side_size=256,
                 num_segment=1,
                 sparse_sample=False,
                 args=None):
        self.anno_path = anno_path
        self.data_root = data_root
        self.clip_len = clip_len
        self.frame_sample_rate = frame_sample_rate
        self.short_side_size = short_side_size
        self.num_segment = num_segment
        self.sparse_sample = sparse_sample
        self.args = args

        self.video_loader = get_video_loader()

        cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
        self.dataset_samples = list(
            cleaned[0].apply(lambda row: os.path.join(self.data_root, row)))
        self.label_array = list(cleaned.values[:, 1])

    def load_and_check_video(self, index, max_retries=100):
        """Load sample ``index``; on failure, retry with randomly chosen samples.

        Args:
            index: position in ``dataset_samples`` to try first.
            max_retries: how many random replacement samples to try before
                giving up (previously this looped forever when no video loaded).

        Returns:
            ``(buffer, sample_path, index)`` for the sample actually loaded,
            which may differ from the requested one.

        Raises:
            RuntimeError: if the first attempt and ``max_retries`` retries all
                produce an empty buffer.
        """
        sample = self.dataset_samples[index]
        buffer = self.load_video(sample)
        retries = 0
        while len(buffer) == 0:
            warnings.warn("Video {} not correctly loaded.".format(sample))
            if retries >= max_retries:
                raise RuntimeError(
                    f"Could not load any video after {retries} retries "
                    f"(last attempted: {sample}).")
            retries += 1
            index = np.random.randint(self.__len__())
            sample = self.dataset_samples[index]
            buffer = self.load_video(sample)
        return buffer, sample, index

    def load_video(self, fname, sample_rate_scale=1):
        """Open ``fname`` with the configured video loader.

        Returns:
            ``(reader, num_frames)`` on success, ``(None, 0)`` if opening fails
            (the error is printed, not raised). ``sample_rate_scale`` is unused
            here; it exists so subclasses can share the signature.
        """
        try:
            vr = self.video_loader(fname)
        except Exception as e:
            print(f"Failed to load video from {fname} with error {e}!")
            # Report the frame count as an int, consistent with the success path.
            return None, 0
        length = len(vr)
        return vr, length

    def __len__(self):
        return len(self.dataset_samples)

In [None]:
class VideoClsTrainDataset(VideoClsBaseDataset):
    """Inherits from VideoClsBaseDataset. Used for training video classification models.

    Each item is ``args.num_sample`` independently augmented views of one clip.
    Reads these fields from ``args``: reprob, num_sample, aa,
    train_interpolation, data_set, remode, recount.
    """

    def __init__(self, crop_size=224, *args, **kwargs):
        """Initializes the training dataset object.

        Args:
            crop_size: side length of the square output crop, used by both the
                RandAugment transform and the spatial sampling step.
            *args, **kwargs: forwarded to ``VideoClsBaseDataset.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.crop_size = crop_size
        self.random_erasure = self.args.reprob > 0

    def __getitem__(self, index):
        """Return ``(frames, labels, indices, [])`` for ``args.num_sample`` views.

        The returned indices refer to the sample actually loaded, which may
        differ from ``index`` if that video could not be read.
        """
        buffer, sample, index = self.load_and_check_video(index)

        frame_list, label_list, index_list = [], [], []
        for _ in range(self.args.num_sample):
            frame_list.append(self._augment_frame(buffer))
            label_list.append(self.label_array[index])
            index_list.append(index)
        return frame_list, label_list, index_list, []

    def _augment_frame(self, buffer):
        """Applies the augmentation transformations to the frames (returns C T H W)."""
        augment_transform = video_transforms.create_random_augment(
            input_size=(self.crop_size, self.crop_size),
            auto_augment=self.args.aa,
            interpolation=self.args.train_interpolation,
        )

        buffer = [transforms.ToPILImage()(frame) for frame in buffer]
        buffer = augment_transform(buffer)
        buffer = [transforms.ToTensor()(img) for img in buffer]
        buffer = self._buffer_processing(buffer)
        buffer = self._spatial_sampling(buffer)
        if self.random_erasure:
            buffer = self._apply_random_erasure(buffer)

        return buffer

    def _buffer_processing(self, buffer):
        """Stacks per-frame C H W tensors and normalizes with ImageNet mean/std."""
        buffer = torch.stack(buffer).permute(0, 2, 3, 1)  # T H W C
        buffer = tensor_normalize(buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # T H W C
        return buffer.permute(3, 0, 1, 2)  # C T H W

    def _spatial_sampling(self, buffer):
        """Applies spatial sampling on the buffer, cropping to ``self.crop_size``."""
        return spatial_sampling(
            buffer,
            spatial_idx=-1,
            min_scale=256,
            max_scale=320,
            # Use the dataset's own crop size so it always agrees with the
            # RandAugment input_size used in _augment_frame.
            crop_size=self.crop_size,
            # Horizontal flipping is disabled for SSV2.
            random_horizontal_flip=False if self.args.data_set == 'SSV2' else True,
            inverse_uniform_sampling=False,
            aspect_ratio=[0.75, 1.3333],
            scale=[0.08, 1.0],
            motion_shift=False,
        )

    def _apply_random_erasure(self, buffer):
        """Applies random erasure on the buffer (frames are erased in T C H W layout)."""
        erase_transform = RandomErasing(
            self.args.reprob,
            mode=self.args.remode,
            max_count=self.args.recount,
            num_splits=self.args.recount,
            device="cpu",
        )
        buffer = buffer.permute(1, 0, 2, 3)  # C T H W -> T C H W
        buffer = erase_transform(buffer)
        return buffer.permute(1, 0, 2, 3)  # T C H W -> C T H W

    def load_video(self, sample, sample_rate_scale=1):
        """Decode the sampled frame indices of ``sample``.

        Returns:
            A numpy frame buffer, or ``[]`` when the video cannot be opened or
            has fewer frames than ``num_segment`` (an empty buffer makes
            ``load_and_check_video`` retry with another sample).
        """
        video_reader, length = super().load_video(sample)
        # Too-short videos would give zero-length segments whose indices are
        # clipped to -1, so treat them like unreadable videos.
        if video_reader is None or length < self.num_segment:
            return []

        all_index = self._generate_temporal_segments_indices(length, sample_rate_scale)
        video_reader.seek(0)
        return video_reader.get_batch(all_index).asnumpy()

    def _generate_temporal_segments_indices(self, length, sample_rate_scale=1):
        """Generates frame indices for all ``num_segment`` equal-length segments.

        ``sample_rate_scale`` keeps every n-th index of the result (1 = all).
        """
        converted_len = int(self.clip_len * self.frame_sample_rate)
        segment_length = length // self.num_segment
        indices = []
        for i in range(self.num_segment):
            indices.extend(
                self._generate_segment_indices(segment_length, converted_len, i))
        return indices[::int(sample_rate_scale)]

    def _generate_segment_indices(self, segment_length, converted_len, segment_number):
        """Generates ``clip_len`` indices for a single segment.

        Short segments (``segment_length <= converted_len``) are spread over the
        whole segment and padded with the last frame; longer ones sample a
        random window of ``converted_len`` frames with evenly spaced indices.
        Indices are offset by ``segment_number * segment_length``.
        """
        if segment_length <= converted_len:
            index = np.linspace(
                0, segment_length, num=segment_length // self.frame_sample_rate)
            index = np.concatenate(
                (index,
                 np.ones(self.clip_len - segment_length // self.frame_sample_rate)
                 * segment_length))
        else:
            end_idx = np.random.randint(converted_len, segment_length)
            start_idx = end_idx - converted_len
            index = np.linspace(start_idx, end_idx, num=self.clip_len)

        index = np.clip(index, 0, segment_length - 1).astype(np.int64)
        return list(index + segment_number * segment_length)

In [None]:
LOG_10000 = math.log(10000.0)

def get_sinusoid_encoding_table(n_position, d_hid):
    ''' Sinusoid position encoding table.

    Returns a float32 tensor of shape (1, n_position, d_hid). For position p,
    column 2i holds sin(p / 10000^(2i / d_hid)) and column 2i+1 holds the
    cosine of the same angle. Odd ``d_hid`` is supported: the last sine
    column simply has no cosine partner.
    '''
    position = torch.arange(n_position, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_hid, 2, dtype=torch.float) *
                         -(LOG_10000 / d_hid))

    # (n_position, ceil(d_hid / 2)) angles, one per sine column.
    angles = position * div_term

    sinusoid_table = torch.zeros(n_position, d_hid)
    sinusoid_table[:, 0::2] = torch.sin(angles)  # dim 2i
    # There are only d_hid // 2 cosine columns, one fewer than sine columns when
    # d_hid is odd, so use just the matching angles (also a no-op for d_hid == 1).
    sinusoid_table[:, 1::2] = torch.cos(angles[:, :d_hid // 2])  # dim 2i+1

    return sinusoid_table.unsqueeze(0)

In [None]:
class VisionTransformer(nn.Module):
    """ Video Vision Transformer that maps a clip to an embedding.

    ``PatchEmbed`` tokenises the input into ``patch_embed.num_patches`` tubelets,
    ``depth`` transformer ``Block``s plus one extra ``additional_block`` process
    them, the tokens are mean-pooled, normalised by ``final_norm`` and projected
    to ``clip_embed_dim`` by ``head``.
    """

    def __init__(self,
                img_size=224,
                patch_size=16,
                in_chans=3,
                embed_dim=768,
                depth=12,
                num_heads=12,
                mlp_ratio=4.,
                qkv_bias=False,
                qk_scale=None,
                drop_rate=0.,
                attn_drop_rate=0.,
                drop_path_rate=0.,
                norm_layer=nn.LayerNorm,
                init_values=0.,
                use_learnable_pos_emb=False,
                all_frames=16,
                tubelet_size=2,
                with_cp=False,
                cos_attn=False,
                clip_embed_dim=768):
        super().__init__()

        self.initialize_params(
            embed_dim, tubelet_size, img_size, patch_size, in_chans, all_frames,
            use_learnable_pos_emb, drop_rate, norm_layer,
            with_cp, clip_embed_dim)

        self.build_model(depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate,
                         attn_drop_rate, drop_path_rate, init_values, cos_attn)

        self.initialize_weights()

    def initialize_params(self, embed_dim, tubelet_size, img_size, patch_size,
                          in_chans, all_frames, use_learnable_pos_emb, drop_rate,
                          norm_layer, with_cp, clip_embed_dim):
        """Create the patch embedding, positional embedding, final norm and head."""
        self.embed_dim = embed_dim
        self.num_features = embed_dim
        self.tubelet_size = tubelet_size
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, num_frames=all_frames, tubelet_size=tubelet_size)
        self.with_cp = with_cp
        self.norm_layer = norm_layer
        self.init_pos_embed(use_learnable_pos_emb, self.patch_embed.num_patches, embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.final_norm = self.norm_layer(embed_dim)
        self.head = nn.Linear(self.embed_dim, clip_embed_dim)

    def init_pos_embed(self, use_learnable_pos_emb, num_patches, embed_dim):
        """Create ``self.pos_embed`` of shape (1, num_patches, embed_dim).

        A learnable embedding is an ``nn.Parameter``. The fixed sinusoidal table
        is registered as a non-persistent buffer: it moves with ``.to()`` and
        ``.half()`` together with the module, but is neither saved to nor
        expected in state dicts (same as the previous plain-tensor attribute).
        """
        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.pos_embed, std=.02)
        else:
            self.register_buffer(
                'pos_embed',
                get_sinusoid_encoding_table(num_patches, embed_dim),
                persistent=False)

    def build_model(self, depth, num_heads, mlp_ratio, qkv_bias, qk_scale,
                    drop_rate, attn_drop_rate, drop_path_rate, init_values, cos_attn):
        """Build ``depth`` blocks with linearly increasing drop-path, plus one extra block."""
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=self.norm_layer, init_values=init_values,
                cos_attn=cos_attn
            ) for i in range(depth)
        ])
        # Extra block on top of the backbone; it uses the maximum drop-path rate.
        self.additional_block = Block(
            dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
            attn_drop=attn_drop_rate, drop_path=drop_path_rate,
            norm_layer=self.norm_layer, init_values=init_values,
            cos_attn=cos_attn
        )

    def initialize_weights(self):
        """Apply ``_init_weights`` to every submodule."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return the ``head`` projection of the pooled features."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def forward_features(self, x):
        """Return mean-pooled, ``final_norm``-normalised token features."""
        x = self.patch_embed(x)
        # Broadcast the (1, N, D) table over the batch, matching x's device and
        # dtype (so fp16/bf16 activations are not promoted to fp32). No detach:
        # a learnable pos_embed must keep receiving gradients, and the
        # sinusoidal buffer does not require grad anyway.
        x = x + self.pos_embed.to(device=x.device, dtype=x.dtype)
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = cp.checkpoint(blk, x) if self.with_cp else blk(x)
        x = cp.checkpoint(self.additional_block, x) if self.with_cp else self.additional_block(x)
        return self.final_norm(x.mean(1))

    def get_num_layers(self):
        """Number of backbone blocks (excluding ``additional_block``)."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}


@register_model
def vit_giant_patch14_224(pretrained=False, **kwargs):
    """ViT-giant/14 factory registered with timm's model registry.

    Fixes the giant architecture (1408-dim, 40 blocks, 16 heads, MLP ratio
    48/11, qkv bias, LayerNorm eps 1e-6); all other VisionTransformer options
    come from ``kwargs``. ``pretrained`` is accepted for registry compatibility
    and ignored.
    """
    giant_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(
        embed_dim=1408,
        depth=40,
        num_heads=16,
        patch_size=14,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=giant_norm,
        **kwargs,
    )
    return model

In [None]:
# All run configuration in one place; later cells read (and some mutate) `args`.
args = argparse.Namespace(
    # --- model and data locations ---
    model='vit_giant_patch14_224',
    data_set='kinetics',
    nb_classes=101,
    data_path='/content/kinetics/',
    finetune='/content/vit_g_hybrid_pt_1200e.pth',
    log_dir='/content/vit_g_hybrid_pt_1200e_kinetics_ft',
    output_dir='/content/',
    # --- batching and input geometry ---
    batch_size=6,
    input_size=224,
    short_side_size=224,
    num_frames=16,
    sampling_rate=4,
    num_sample=2,
    num_workers=10,
    # --- optimisation and regularisation ---
    opt='adamw',
    opt_eps=1e-8,
    opt_betas=[0.9, 0.999],
    lr=1e-3,
    min_lr=1e-6,
    drop=0.0,
    attn_drop_rate=0.0,
    drop_path=0.35,
    clip_grad=None,  # alternative: 5.0
    aa='rand-m7-n4-mstd0.5-inc1',
    layer_decay=0.92,  # alternative: 0.9
    weight_decay=0.06,  # alternative: 0.05
    epochs=5,
    # --- video tokenisation and augmentation ---
    tubelet_size=2,
    with_checkpoint=True,
    train_interpolation='bicubic',
    reprob=0.25,
    remode='pixel',
    recount=1,
    data_root='',
    num_segments=1,
    # --- run control ---
    start_epoch=0,
    pin_mem=True,
)

In [None]:
# Write the DeepSpeed JSON config consumed by deepspeed.initialize(args=args, ...).
args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")

ds_config = {
    "train_batch_size": args.batch_size,
    "train_micro_batch_size_per_gpu": args.batch_size,
    "steps_per_print": 1000,
    "gradient_clipping": 0.0 if args.clip_grad is None else args.clip_grad,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "bias_correction": True,
            "betas": args.opt_betas,
            "eps": args.opt_eps,
            # DeepSpeed reads adam_w_mode from optimizer.params; placed at the
            # optimizer top level it is silently ignored.
            "adam_w_mode": True,
        },
    },
    # NOTE(review): fp16 is enabled here while the training step runs its
    # forward pass under bf16 autocast; confirm the intended precision.
    "fp16": {
        "enabled": True,
        "loss_scale": 0,  # 0 selects dynamic loss scaling
        "initial_scale_power": 7,
        "loss_scale_window": 128,
    },
}

with open(args.deepspeed_config, mode="w") as writer:
    json.dump(ds_config, writer, indent=2)

In [None]:
def _configure_file_logger(name, filename, fmt, handler_level=None):
    """Return logger ``name`` writing to ``filename`` with format ``fmt``.

    Any FileHandler already attached to this logger for the same file is
    removed and closed first, so re-running this cell does not duplicate every
    log line. The logger is set to INFO and does not propagate to the root logger.
    """
    log = logging.getLogger(name)
    target = os.path.abspath(filename)
    stale = [h for h in log.handlers
             if isinstance(h, logging.FileHandler) and h.baseFilename == target]
    for old in stale:
        log.removeHandler(old)
        old.close()

    handler = logging.FileHandler(filename)
    if handler_level is not None:
        handler.setLevel(handler_level)
    handler.setFormatter(logging.Formatter(fmt))
    log.addHandler(handler)
    log.setLevel(logging.INFO)
    log.propagate = False
    return log


# Training progress log (timestamped).
logger = _configure_file_logger(
    __name__, 'model_training.log',
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handler_level=logging.INFO)

# Per-sample classification log (message only).
classification_logger = _configure_file_logger(
    'classification_log', 'classification_log.log', '%(message)s')

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed every RNG the notebook draws from (torch, numpy, stdlib) with 0.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)
del _seed_fn

# Let cuDNN autotune convolution algorithms: faster, but not bit-reproducible.
cudnn.benchmark = True

In [None]:
# Collect garbage (e.g. a model from a previous run of this cell) before building a new one.
gc.collect()

model_kwargs = dict(
    img_size=args.input_size,
    pretrained=False,
    all_frames=args.num_frames * args.num_segments,
    tubelet_size=args.tubelet_size,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    attn_drop_rate=args.attn_drop_rate,
    drop_block_rate=None,
    with_cp=args.with_checkpoint,
)
model = create_model(args.model, **model_kwargs)

In [None]:
train_dataset = VideoClsTrainDataset(
    anno_path=os.path.join(args.data_path, 'train.csv'),
    data_root=args.data_root,
    clip_len=args.num_frames,
    frame_sample_rate=args.sampling_rate,
    # Must agree with the model, which is built for num_frames * num_segments frames.
    num_segment=args.num_segments,
    crop_size=args.input_size,
    short_side_size=args.short_side_size,
    args=args)


# Each dataset item holds args.num_sample views; multiple_samples_collate flattens them.
data_loader_train = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    shuffle=True,
    drop_last=True,
    collate_fn=partial(multiple_samples_collate, fold=False),
    # persistent_workers=True raises ValueError when num_workers == 0.
    persistent_workers=args.num_workers > 0)

In [None]:
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoint files.
checkpoint = torch.load(args.finetune, map_location='cpu')

# Weights are nested under 'model' (preferred) or 'module'; if neither key is
# present, treat the file itself as a bare state dict instead of raising KeyError.
checkpoint_model = next(
    (checkpoint[key] for key in ('model', 'module') if key in checkpoint),
    checkpoint)

In [None]:
def _strip_checkpoint_prefix(key):
    """Drop one leading 'backbone.' or 'encoder.' prefix (checked in that order)."""
    for prefix in ('backbone.', 'encoder.'):
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


# Keep encoder weights only: skip every decoder entry and remove wrapper
# prefixes so the keys line up with the fine-tuning model's parameter names.
new_dict = OrderedDict(
    (_strip_checkpoint_prefix(key), value)
    for key, value in checkpoint_model.items()
    if 'decoder.' not in key
)

checkpoint_model = new_dict

In [None]:
load_state_dict(model, checkpoint_model)

model.to(device)

model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)


num_training_steps_per_epoch = len(train_dataset) // args.batch_size

# Linear LR scaling rule (reference batch size 256). Scale from the unscaled
# values remembered on first execution so that re-running this cell does not
# compound the scaling of args.lr / args.min_lr.
if not hasattr(args, 'base_lr'):
    args.base_lr, args.base_min_lr = args.lr, args.min_lr
args.lr = args.base_lr * args.batch_size / 256
args.min_lr = args.base_min_lr * args.batch_size / 256

# Layer-wise LR decay: scale layer_decay ** (num_layers + 1 - i) for layer id i,
# so deeper layers get larger learning-rate multipliers.
num_layers = model_without_ddp.get_num_layers()
assigner = LayerDecayValueAssigner(
    list(args.layer_decay**(num_layers + 1 - i)
          for i in range(num_layers + 2)))

skip_weight_decay_list = model.no_weight_decay()

optimizer_params = get_parameter_groups(
    model, args.weight_decay, skip_weight_decay_list,
    assigner.get_layer_id,
    assigner.get_scale)
model, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=optimizer_params,
)

Weights of VisionTransformer not initialized from pretrained model: ['final_norm.weight', 'final_norm.bias', 'head.weight', 'head.bias', 'additional_block.norm1.weight', 'additional_block.norm1.bias', 'additional_block.attn.q_bias', 'additional_block.attn.v_bias', 'additional_block.attn.qkv.weight', 'additional_block.attn.proj.weight', 'additional_block.attn.proj.bias', 'additional_block.norm2.weight', 'additional_block.norm2.bias', 'additional_block.mlp.fc1.weight', 'additional_block.mlp.fc1.bias', 'additional_block.mlp.fc2.weight', 'additional_block.mlp.fc2.bias']
Weights from pretrained model not used in VisionTransformer: ['mask_token', 'norm.weight', 'norm.bias']
Param groups = {
  "layer_0_decay": {
    "weight_decay": 0.06,
    "params": [
      "patch_embed.proj.weight"
    ],
    "lr_scale": 0.03275675867289639
  },
  "layer_0_no_decay": {
    "weight_decay": 0.0,
    "params": [
      "patch_embed.proj.bias"
    ],
    "lr_scale": 0.03275675867289639
  },
  "layer_41_no_decay

Using /root/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu118/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


Time to load fused_adam op: 0.10823392868041992 seconds
[2023-08-11 09:16:33,723] [INFO] [logging.py:96:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adam as basic optimizer
[2023-08-11 09:16:33,765] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = FusedAdam
[2023-08-11 09:16:33,767] [INFO] [logging.py:96:log_dist] [Rank 0] Creating fp16 optimizer with dynamic loss scale


Loading extension module fused_adam...


[2023-08-11 09:16:33,852] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = adam
[2023-08-11 09:16:33,854] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using client LR scheduler
[2023-08-11 09:16:33,854] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = None
[2023-08-11 09:16:33,856] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001], mom=[[0.9, 0.999], [0

In [None]:
# Cosine decay from args.lr towards args.min_lr, one value per training iteration.
schedule_iters = np.arange(args.epochs * num_training_steps_per_epoch)
num_schedule_iters = len(schedule_iters)
lr_amplitude = 0.5 * (args.lr - args.min_lr)

lr_schedule_values = np.array([
    args.min_lr + lr_amplitude * (1 + math.cos(math.pi * step / num_schedule_iters))
    for step in schedule_iters
])

In [None]:
labels_embedding_path = os.path.join(args.data_path, 'labels-CLIP-embedding.txt')

# Each non-blank line is "<label> | <description> | <JSON list of floats>".
# Row i of the stacked tensor is the embedding from the (i+1)-th non-blank line;
# training indexes it with (label - 1), so lines must be ordered by label.
label_to_embedding = []
with open(labels_embedding_path, 'r', encoding='utf-8') as file:
    for line in file:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        # Split the label off the left and the embedding off the right so a
        # description containing ' | ' cannot break the unpacking.
        label, rest = line.split(' | ', 1)
        description, raw_embedding = rest.rsplit(' | ', 1)
        embedding = json.loads(raw_embedding)
        label_to_embedding.append(torch.tensor(embedding, dtype=torch.float32))

if not label_to_embedding:
    raise ValueError(f"No label embeddings found in {labels_embedding_path}")

label_to_embedding_tensor = torch.stack(label_to_embedding).to(device)
mean_embedding = label_to_embedding_tensor.mean(dim=0)
num_classes = label_to_embedding_tensor.size(0)

In [None]:
def update_optimizer_params(optimizer, it, lr_schedule_values):
    """Refresh learning rate and weight decay of every optimizer param group.

    Args:
        optimizer: object exposing ``param_groups`` (dicts with "lr",
            "lr_scale" and "weight_decay" keys).
        it: iteration index into ``lr_schedule_values``.
        lr_schedule_values: indexable per-iteration base learning rates, or
            None to leave each group's "lr" untouched.

    Each group's lr becomes ``lr_schedule_values[it] * lr_scale``. Groups with
    positive weight decay are reset to the notebook-level ``args.weight_decay``;
    groups without decay are left as they are.
    """
    has_schedule = lr_schedule_values is not None
    for group in optimizer.param_groups:
        if has_schedule:
            group["lr"] = lr_schedule_values[it] * group["lr_scale"]
        if not (group["weight_decay"] > 0):
            continue
        group["weight_decay"] = args.weight_decay


def _rank_weighted_distance_loss(outputs, targets, label_embeddings):
    """Compute the per-sample ranking-weighted distance loss.

    Args:
        outputs: (B, D) predicted embeddings.
        targets: (B,) zero-based class indices into ``label_embeddings``.
        label_embeddings: (C, D) one embedding per class.

    Returns:
        ``(loss_elements, ranks)``, both of shape (B,). ``ranks[b]`` is the
        0-based position of the true class when classes are sorted by squared
        distance to ``outputs[b]`` (0 = nearest).
    """
    # Squared Euclidean distance from each prediction to each class: (B, C).
    diffs = (outputs.unsqueeze(1) - label_embeddings.unsqueeze(0)).pow(2).sum(2)
    sorted_diffs, sorted_indices = torch.sort(diffs, dim=1)
    ranks = (sorted_indices == targets.unsqueeze(1)).nonzero(as_tuple=True)[1]
    target_diffs = diffs[torch.arange(diffs.shape[0], device=diffs.device), targets]
    # NOTE(review): sorted_diffs[b, ranks[b]] equals target_diffs[b], so each
    # element reduces to 2 * rank * target_diff, which is zero (and gradient-free)
    # once the true class is nearest. Kept as-is; confirm this is the intended objective.
    loss_elements = ((2 * ranks.float() + 1) * target_diffs
                     - sorted_diffs.gather(1, ranks.unsqueeze(1)).squeeze(1))
    return loss_elements, ranks


def train_epoch(epoch, model, data_loader_train, optimizer, args, print_steps,
                accumulate_grad_steps=1, log_details=False, lr_schedule_values=None):
    """Train ``model`` for one epoch against the fixed label embeddings.

    Uses notebook globals ``device``, ``label_to_embedding_tensor``, ``logger``,
    ``classification_logger`` and ``update_optimizer_params``.

    Args:
        epoch: zero-based epoch index; offsets the LR-schedule position.
        model: a DeepSpeed engine (anything exposing ``backward`` and ``step``)
            or a plain ``nn.Module``.
        data_loader_train: yields ``(samples, targets, ids, _)`` batches.
        optimizer: optimizer whose ``param_groups`` carry "lr_scale" and
            "weight_decay".
        args: unused here; kept for interface compatibility.
        print_steps: average and log the loss every this many steps.
        accumulate_grad_steps: micro-batches per optimizer step on the plain
            PyTorch path (a DeepSpeed engine follows its own config instead).
        log_details: if True, write one line per sample to ``classification_logger``.
        lr_schedule_values: per-iteration base learning rates. Defaults to the
            notebook-level ``lr_schedule_values`` if it exists; without any
            schedule the optimizer's learning rates are left unchanged.

    Exits via ``sys.exit(1)`` on a non-finite loss, before the weights are updated.
    """
    if lr_schedule_values is None:
        lr_schedule_values = globals().get('lr_schedule_values')

    # A DeepSpeed engine must run backward/step itself so fp16 loss scaling and
    # gradient unscaling/clipping are applied; calling loss.backward() and
    # optimizer.step() directly bypasses them.
    use_engine = hasattr(model, 'backward') and hasattr(model, 'step')

    running_loss_sum = 0.0
    start_steps = epoch * len(data_loader_train)

    model.train()

    for step, (samples, targets, ids, _) in tqdm.tqdm(enumerate(data_loader_train), total=len(data_loader_train)):
        if lr_schedule_values is not None:
            # Apply the cosine schedule times each group's layer-decay lr_scale.
            it = min(start_steps + step, len(lr_schedule_values) - 1)
            update_optimizer_params(optimizer, it, lr_schedule_values)

        samples = samples.to(device, non_blocking=True)
        # Shift to 0-based rows of label_to_embedding_tensor (labels are
        # presumably 1-based in the annotation file).
        targets = targets.to(device, non_blocking=True) - 1

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            outputs = model(samples).view(samples.shape[0], -1)
            loss_elements, ranks = _rank_weighted_distance_loss(
                outputs, targets, label_to_embedding_tensor)
            loss = loss_elements.mean()

        loss_value = loss.item()

        if log_details:
            # One host transfer per tensor instead of per-element .item() calls;
            # each line reports that sample's own loss rather than the batch mean.
            for video_id, label, rank, sample_loss in zip(
                    ids.tolist(), targets.tolist(), ranks.tolist(),
                    loss_elements.detach().tolist()):
                classification_logger.info(f"Video ID {video_id}: True Label {label}: Rank: {rank}: Loss: {sample_loss}")

        # Stop before backward/step so a NaN/inf loss never reaches the weights.
        if not math.isfinite(loss_value):
            logger.error(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        if use_engine:
            model.backward(loss)
            model.step()
        else:
            # Average over the accumulation window so the effective step size
            # does not grow with accumulate_grad_steps.
            (loss / accumulate_grad_steps).backward()
            if (step + 1) % accumulate_grad_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        running_loss_sum += loss_value

        if (step + 1) % print_steps == 0:
            avg_loss = running_loss_sum / print_steps
            tqdm.tqdm.write(f'Epoch {epoch}, Step {step + 1}: Avg Loss: {avg_loss:.4f}')
            logger.info(f'Epoch {epoch}, Step {step + 1}: Avg Loss: {avg_loss:.4f}')
            running_loss_sum = 0.0



def save_model(epoch, model, optimizer, path):
    """Save model and optimizer state for ``epoch`` into directory ``path``.

    The file is named ``model_epoch_<epoch>_<YYYYmmdd_HHMMSS>.pth`` and holds
    the keys 'epoch', 'model_state_dict' and 'optimizer_state_dict'.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    target_file = os.path.join(path, f"model_epoch_{epoch}_{stamp}.pth")

    payload = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(payload, target_file)

In [None]:
logger = logging.getLogger(__name__)

save_path = "/content/drive/MyDrive/MSE"
model.train(True)
os.makedirs(save_path, exist_ok=True)


def _train_and_checkpoint(epoch_idx):
    """Run one epoch, then save a checkpoint and log completion."""
    # Release memory cached from the previous epoch before starting the next.
    gc.collect()
    torch.cuda.empty_cache()
    optimizer.zero_grad()
    train_epoch(epoch_idx, model, data_loader_train, optimizer, args,
                print_steps=50, accumulate_grad_steps=1, log_details=True)
    save_model(epoch_idx, model, optimizer, save_path)
    logger.info(f'Completed epoch {epoch_idx}')


for epoch in range(args.start_epoch, args.epochs):
    _train_and_checkpoint(epoch)