import json
import os

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_image


class VideoDescriptionDataloader(Dataset):
    """Dataset yielding frames, precomputed features and scores per video.

    Each item corresponds to one key of the ground-truth scores JSON file.
    For that key the dataset reads every file in the matching frames
    directory, loads ``<video_id>.pt`` from the text-feature and
    frame-feature directories, and normalizes the ground-truth scores.
    """

    def __init__(self, cfg, data_type: str):
        """Resolve paths, pick the score normalizer and load the scores.

        Args:
            cfg: Configuration object exposing ``paths`` (indexed by
                ``data_type`` and then by ``'frames'``, ``'frame_features'``,
                ``'text_features'`` and ``'gt_scores'``), ``normalization``
                and ``binary_percentage``.
            data_type: Key selecting the split inside ``cfg.paths``. The
                value ``'valid'`` always uses min-max normalization,
                whatever ``cfg.normalization`` says.

        Raises:
            ValueError: If ``cfg.normalization`` is neither ``'min_max'`` nor
                ``'binary'`` and ``data_type`` is not ``'valid'``.
        """
        split_paths = cfg.paths[data_type]
        self.frames_dir = split_paths['frames']
        self.frame_features_dir = split_paths['frame_features']
        self.text_features_dir = split_paths['text_features']
        self.gt_scores_file_path = split_paths['gt_scores']

        self.transform = self._get_transforms()
        # Only consumed by _to_binary_labels, but read unconditionally.
        self.binary_percentage = cfg.binary_percentage

        # The 'valid' split short-circuits this check, so it never reaches
        # the binary branch or the ValueError below.
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")

        with open(self.gt_scores_file_path, "r", encoding="utf-8") as f:
            self.gt_scores = json.load(f)

        # The JSON keys double as video ids and as per-video file/dir names.
        self.video_ids = list(self.gt_scores.keys())

    def __len__(self) -> int:
        """Return the number of videos listed in the ground-truth file."""
        return len(self.video_ids)

    def __getitem__(self, idx: int) -> tuple:
        """Load one video's frames, features and normalized scores.

        Args:
            idx: Position of the video in ``self.video_ids``.

        Returns:
            A tuple ``(frames_tensor, text_features, frame_features,
            gt_score)``. ``frames_tensor`` stacks the transformed frames in
            sorted filename order; the two feature entries are whatever
            ``torch.load`` returns for the video's ``.pt`` files; ``gt_score``
            is the normalized float32 score tensor.
        """
        video_id = self.video_ids[idx]
        video_frames_dir = os.path.join(self.frames_dir, video_id)

        # Frame order is the lexicographic order of the file names, so
        # un-padded numeric names (frame_10 before frame_2) would be
        # misordered. Every directory entry is treated as an image.
        frame_files = sorted(os.listdir(video_frames_dir))

        # NOTE(review): torchvision documents read_image as returning uint8
        # values in [0, 255]; .float() does not rescale, while the mean/std in
        # _get_transforms look like statistics for [0, 1] inputs. Confirm
        # whether a division by 255 is intended before changing this.
        frames_tensor = torch.stack([
            self.transform(
                read_image(os.path.join(video_frames_dir, frame_file)).float()
            )
            for frame_file in frame_files
        ])

        # SECURITY: torch.load unpickles its input. Only point the feature
        # directories at trusted files.
        text_features = torch.load(
            os.path.join(self.text_features_dir, f'{video_id}.pt')
        )
        frame_features = torch.load(
            os.path.join(self.frame_features_dir, f'{video_id}.pt')
        )

        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
        gt_score = self.normalize_scores(gt_score)

        return frames_tensor, text_features, frame_features, gt_score

    @staticmethod
    def _get_transforms():
        """Build the per-frame transform: channel-wise mean/std normalization."""
        return transforms.Compose([
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _min_max_normalize_scores(score: torch.Tensor) -> torch.Tensor:
        """Rescale ``score`` linearly so its minimum is 0 and its maximum is 1.

        Args:
            score: Tensor of raw scores.

        Returns:
            A tensor shaped like ``score``. When all values are equal the
            result is filled with 0.5; an empty input yields an empty result.
        """
        # min()/max() raise on a tensor with no elements; return an empty
        # result instead, matching what _to_binary_labels does for empty input.
        if score.numel() == 0:
            return score.clone()

        min_val, max_val = score.min(), score.max()
        if min_val != max_val:
            return (score - min_val) / (max_val - min_val)
        # Constant scores carry no ranking; avoid dividing by zero.
        return torch.full_like(score, 0.5)

    def _to_binary_labels(self, score: torch.Tensor) -> torch.Tensor:
        """Mark the highest-scoring entries with 1 and all others with 0.

        Args:
            score: Tensor of raw scores, indexed along its first dimension.

        Returns:
            A tensor shaped like ``score`` holding 1 at the positions of the
            top ``int(len(score) * self.binary_percentage)`` scores (at least
            one position is requested) and 0 elsewhere.
        """
        # int() truncates; max(..., 1) keeps at least one positive label.
        num_top_elements = max(int(len(score) * self.binary_percentage), 1)
        sorted_indices = score.argsort(descending=True)
        binary_labels = torch.zeros_like(score)
        binary_labels[sorted_indices[:num_top_elements]] = 1
        return binary_labels