1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- import os
- import json
- import torch
- from torchvision.io import read_image
- from torch.utils.data import Dataset
- from torchvision import transforms
-
class VideoDescriptionDataloader(Dataset):
    """Map-style dataset yielding frames, precomputed features and GT scores per video.

    Despite its name this is a ``torch.utils.data.Dataset`` (not a
    ``DataLoader``); wrap it in a ``DataLoader`` for batching.

    Configuration read from ``cfg``:
        * ``cfg.paths[data_type]['frames']`` -- directory containing one
          sub-directory of frame images per video id.
        * ``cfg.paths[data_type]['frame_features']`` and
          ``cfg.paths[data_type]['text_features']`` -- directories containing
          one ``<video_id>.pt`` file per video.
        * ``cfg.paths[data_type]['gt_scores']`` -- JSON file mapping each
          video id to its scores (presumably a list with one score per frame;
          verify against the data-preparation code).
        * ``cfg.normalization`` -- ``'min_max'`` or ``'binary'``.
        * ``cfg.binary_percentage`` -- fraction of scores labelled 1 by the
          ``'binary'`` normalization.
    """

    def __init__(self, cfg, data_type):
        """Read paths/options for ``data_type`` and load the GT score file.

        Args:
            cfg: Configuration object (see class docstring).
            data_type: Split key into ``cfg.paths`` (e.g. ``'train'`` or
                ``'valid'``).

        Raises:
            ValueError: If ``cfg.normalization`` is neither ``'min_max'`` nor
                ``'binary'`` and ``data_type`` is not ``'valid'``.
        """
        split_paths = cfg.paths[data_type]
        self.frames_dir = split_paths['frames']
        self.frame_features_dir = split_paths['frame_features']
        self.text_features_dir = split_paths['text_features']
        self.gt_scores_file_path = split_paths['gt_scores']
        self.transform = self._get_transforms()
        self.binary_percentage = cfg.binary_percentage

        # The 'valid' split always gets min-max scores, whatever normalization
        # the config selects, so validation targets are never binarized.
        if cfg.normalization == 'min_max' or data_type == 'valid':
            self.normalize_scores = self._min_max_normalize_scores
        elif cfg.normalization == 'binary':
            self.normalize_scores = self._to_binary_labels
        else:
            raise ValueError(f"Unsupported normalization method: {cfg.normalization}")

        with open(self.gt_scores_file_path, "r", encoding="utf-8") as f:
            self.gt_scores = json.load(f)
        # Item order follows the key order of the JSON object.
        self.video_ids = list(self.gt_scores.keys())

    def __len__(self):
        """Return the number of videos listed in the GT score file."""
        return len(self.video_ids)

    def __getitem__(self, idx):
        """Return ``(frames, text_features, frame_features, gt_score)`` for item ``idx``.

        ``frames`` stacks every image of the video (sorted filename order)
        after ``self.transform``; the two feature tensors are loaded from
        ``<video_id>.pt`` files; ``gt_score`` is the float32 score tensor
        after ``self.normalize_scores``.
        """
        video_id = self.video_ids[idx]

        frames_tensor = self._load_frames(video_id)
        text_features = self._load_features(self.text_features_dir, video_id)
        frame_features = self._load_features(self.frame_features_dir, video_id)

        gt_score = torch.tensor(self.gt_scores[video_id], dtype=torch.float32)
        gt_score = self.normalize_scores(gt_score)

        return frames_tensor, text_features, frame_features, gt_score

    def _load_frames(self, video_id):
        """Read, transform and stack all frame images of ``video_id``.

        Raises:
            RuntimeError: If the video's frame directory contains no files.
        """
        video_frames_dir = os.path.join(self.frames_dir, video_id)
        # NOTE(review): plain lexicographic order; frame files are assumed to
        # be zero-padded (e.g. 0001.jpg) so this equals temporal order -- confirm.
        frame_files = sorted(os.listdir(video_frames_dir))
        if not frame_files:
            raise RuntimeError(f"No frames found in {video_frames_dir}")
        # read_image returns integer pixel tensors; self.transform converts
        # them to floats in [0, 1] before normalizing.
        return torch.stack([
            self.transform(read_image(os.path.join(video_frames_dir, frame_file)))
            for frame_file in frame_files
        ])

    @staticmethod
    def _load_features(features_dir, video_id):
        """Load the ``<video_id>.pt`` file from ``features_dir`` with storages mapped to CPU."""
        # map_location='cpu' keeps tensors that were saved from a GPU on the
        # CPU, which is what DataLoader workers and default collation expect.
        # NOTE: torch.load unpickles its input -- only load trusted files.
        return torch.load(os.path.join(features_dir, f'{video_id}.pt'), map_location='cpu')

    @staticmethod
    def _get_transforms():
        """Return the per-frame transform: integer image -> float in [0, 1] -> ImageNet-normalized."""
        return transforms.Compose([
            # The ImageNet mean/std below are defined for pixels in [0, 1];
            # ConvertImageDtype rescales integer images into that range.
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _min_max_normalize_scores(score):
        """Linearly rescale ``score`` to [0, 1].

        A constant tensor maps to 0.5 everywhere; an empty tensor is returned
        as an empty copy (``Tensor.min`` would raise on it).
        """
        if score.numel() == 0:
            return score.clone()
        min_val, max_val = score.min(), score.max()

        if min_val != max_val:
            return (score - min_val) / (max_val - min_val)
        return torch.full_like(score, 0.5)

    def _to_binary_labels(self, score):
        """Label the top ``self.binary_percentage`` fraction of ``score`` 1, the rest 0.

        At least one element is labelled 1 for non-empty input; ties at the
        cut-off are resolved by ``argsort`` order.
        """
        num_top_elements = max(int(len(score) * self.binary_percentage), 1)
        sorted_indices = score.argsort(descending=True)
        binary_labels = torch.zeros_like(score)
        binary_labels[sorted_indices[:num_top_elements]] = 1
        return binary_labels
|