You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

t2v_dataloader.py 2.7KB

4 months ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import os
  2. import json
  3. import torch
  4. from torchvision.io import read_image
  5. from torch.utils.data import Dataset
  6. from torchvision import transforms
  7. class VideoDescriptionDataloader(Dataset):
  8. def __init__(self, cfg, data_type):
  9. self.frames_dir = cfg.paths[data_type]['frames']
  10. self.frame_features_dir = cfg.paths[data_type]['frame_features']
  11. self.text_features_dir = cfg.paths[data_type]['text_features']
  12. self.gt_scores_file_path = cfg.paths[data_type]['gt_scores']
  13. self.transform = self._get_transforms()
  14. self.binary_percentage = cfg.binary_percentage
  15. if cfg.normalization == 'min_max' or data_type=='valid':
  16. self.normalize_scores = self._min_max_normalize_scores
  17. elif cfg.normalization == 'binary':
  18. self.normalize_scores = self._to_binary_labels
  19. else:
  20. raise ValueError(f"Unsupported normalization method: {cfg.normalization}")
  21. with open(self.gt_scores_file_path, "r") as f:
  22. self.gt_scores = json.load(f)
  23. self.video_ids = list(self.gt_scores.keys())
  24. def __len__(self):
  25. return len(self.video_ids)
  26. def __getitem__(self, idx):
  27. video_id = self.video_ids[idx]
  28. video_frames_dir = os.path.join(self.frames_dir, self.video_ids[idx])
  29. frames_tensor = torch.stack([self.transform(read_image(os.path.join(video_frames_dir, frame_file)).float())
  30. for frame_file in sorted(os.listdir(video_frames_dir))])
  31. text_features = torch.load(os.path.join(self.text_features_dir, f'{self.video_ids[idx]}.pt'))
  32. frame_features = torch.load(os.path.join(self.frame_features_dir, f'{self.video_ids[idx]}.pt'))
  33. gt_score = torch.tensor(self.gt_scores[self.video_ids[idx]], dtype=torch.float32)
  34. gt_score = self.normalize_scores(gt_score)
  35. return frames_tensor, text_features, frame_features, gt_score
  36. @staticmethod
  37. def _get_transforms():
  38. return transforms.Compose([transforms.Normalize([0.485, 0.456, 0.406],
  39. [0.229, 0.224, 0.225])])
  40. @staticmethod
  41. def _min_max_normalize_scores(score):
  42. min_val, max_val = score.min(), score.max()
  43. if min_val != max_val:
  44. return (score - min_val) / (max_val - min_val)
  45. return torch.full_like(score, 0.5)
  46. def _to_binary_labels(self, score):
  47. num_top_elements = max(int(len(score) * self.binary_percentage) , 1)
  48. sorted_indices = score.argsort(descending=True)
  49. binary_labels = torch.zeros_like(score)
  50. binary_labels[sorted_indices[:num_top_elements]] = 1
  51. return binary_labels