imagenet_loader.py

import os
import pickle
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS

from .content_loader import ContentLoader
from ...utils.bb_generator import generate_bb_map

if TYPE_CHECKING:
    from ...configs.imagenet_configs import ImagenetConfigs


def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose:
    if data_specification == 'train':
        return transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if data_specification in ['test', 'val']:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
        ])
    raise ValueError('Unknown data specification: {}'.format(data_specification))

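# Note: the same Compose object is applied both to the RGB image (in __get_x) and to its
# bounding-box map (in __get_bbox), so for the 'train' split the RandomResizedCrop and
# RandomHorizontalFlip draws must match between the two calls; __handle_random_state below
# takes care of that. Illustrative use (the image path is hypothetical):
#   transform = _get_image_transform('val', 224)
#   x = transform(default_loader('/path/to/some_image.JPEG'))  # float tensor, 3 x 224 x 224

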
def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]:
    cache_path = os.path.join('.cache/', cache_name)
    print(cache_path)
    if os.path.isfile(cache_path):
        print('Loading cached samples from {}'.format(cache_path))
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    print('Creating cache for {}'.format(cache_name))
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    _, class_to_idx = find_classes(imagenet_dir)
    samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS)
    with open(cache_path, 'wb') as f:
        pickle.dump(samples, f)
    return samples

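# The cache stores the (image_path, class_index) tuples produced by torchvision's
# make_dataset. Note that the cache file is keyed only by cache_name (e.g. 'imagenet.train'),
# not by imagenet_dir, so the .cache directory must be cleared manually if the dataset
# location changes.

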
class BBoxField(Enum):
    ImageId = 'ImageId'
    PredictionString = 'PredictionString'


def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame:
    return pd.read_csv(os.path.join(imagenet_root, f'LOC_{data_specification}_solution.csv'))

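# LOC_<split>_solution.csv is assumed to follow the ImageNet localization-challenge format:
# an ImageId column and a PredictionString column whose value is a whitespace-separated
# sequence of "<class> <x_min> <y_min> <x_max> <y_max>" groups; extract_bb in __get_bbox
# below keeps only the four coordinates of each group.

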
class ImagenetLoader(ContentLoader):

    def __init__(self, conf: 'ImagenetConfigs', data_specification: str):
        super().__init__(conf, data_specification)
        imagenet_dir = os.path.join(conf.data_separation, data_specification)
        self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir)
        self.__transform = _get_image_transform(data_specification, conf.input_size)
        self.__bboxes = load_bbox_df(conf.data_separation, data_specification)
        self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {}

    def get_samples_names(self):
        """ Sample names must be unique; they could be either the image file names or their
        paths. The paths are used here; it makes no difference. """
        return np.array([path for path, _ in self.__samples])

    def get_samples_labels(self):
        return np.array([label for _, label in self.__samples])

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop]

    def get_placeholder_name_to_fill_function_dict(self):
        """ Returns a dictionary mapping the placeholder names supported by this content loader
        to the functions used for filling them. The functions receive the information about the
        current batch prepared by the data loader (e.g. the indices of the samples, or, if a
        sample has many elements, the indices of the chosen elements) and return one array per
        placeholder name according to the received batch information (see the usage sketch at
        the end of this module).
        IMPORTANT: it is better to use a fixed prefix in the placeholder names so it is clear
        which content loader they belong to. """
        return {
            'x': self.__get_x,
            'y': self.__get_y,
            'bbox': self.__get_bbox,
            'interpretations': self.__get_bbox,
        }

    def __get_x(self, samples_inds: np.ndarray) -> torch.Tensor:
        def get_image(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            sample = default_loader(path)
            # Record the RNG state under label 'x' so __get_bbox can replay the same
            # random transform for this sample.
            self.__handle_random_state(index, 'x')
            return self.__transform(sample)
        return torch.stack(
            tuple([get_image(index) for index in samples_inds]),
            dim=0)

    def __get_y(self, samples_inds: np.ndarray) -> np.ndarray:
        return self.get_samples_labels()[samples_inds]

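    # __handle_random_state keeps the random augmentation consistent between an image and its
    # bounding-box map: the first call for a sample index (or a repeated call with the same
    # label, i.e. a fresh pass) snapshots the global torch RNG state before the transform is
    # applied; a following call with a different label ('bb' after 'x', or vice versa) restores
    # that snapshot, so RandomResizedCrop/RandomHorizontalFlip draw the same values.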
    def __handle_random_state(self, idx: int, label: str) -> None:
        if idx not in self.__rng_states or self.__rng_states[idx][0] == label:
            self.__rng_states[idx] = (label, torch.get_rng_state())
        else:
            torch.set_rng_state(self.__rng_states[idx][1])

    def __get_bbox(self, sample_inds: np.ndarray) -> torch.Tensor:
        def extract_bb(prediction_string: str) -> np.ndarray:
            # PredictionString is groups of 5 tokens: <class> <x_min> <y_min> <x_max> <y_max>
            splitted = prediction_string.split()
            n = len(splitted) // 5
            return np.array([[float(b) for b in splitted[5 * i + 1: 5 * i + 5]] for i in range(n)])\
                .reshape((-1, 2, 2))

        def make_bb(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            image_id = os.path.basename(path).split('.')[0]
            image_size = np.array(Image.open(path).size)[::-1]  # (H, W)
            bboxes = self.__bboxes[self.__bboxes[BBoxField.ImageId.value] == image_id][BBoxField.PredictionString.value]
            if len(bboxes) == 0:
                # No annotation for this image: return an all-NaN map.
                return torch.zeros(1, self.conf.input_size, self.conf.input_size) * np.nan
            bboxes = bboxes.values[0]
            bboxes = extract_bb(bboxes)[..., ::-1] / image_size  # (N, 2, 2), normalized (y, x)
            start_points = bboxes[:, 0]  # (N, 2)
            end_points = bboxes[:, 1]  # (N, 2)
            bb_map = generate_bb_map(start_points, end_points, tuple(image_size))
            bb_map = Image.fromarray(bb_map)
            # Replay the RNG state recorded in __get_x so the map gets the same crop/flip.
            self.__handle_random_state(index, 'bb')
            return self.__transform(bb_map)

        return torch.stack([make_bb(i) for i in sample_inds], dim=0)
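

# Minimal usage sketch (illustrative only; `conf` is assumed to be an ImagenetConfigs
# instance whose data_separation points at the ImageNet root and whose input_size is set,
# which is how the constructor above reads it; the fill functions are called here with
# arrays of sample indices, matching their signatures):
#
#   loader = ImagenetLoader(conf, 'val')
#   names = loader.get_samples_names()            # unique image paths
#   labels = loader.get_samples_labels()          # class indices aligned with names
#   fillers = loader.get_placeholder_name_to_fill_function_dict()
#   inds = np.array([0, 1, 2])
#   x = fillers['x'](inds)                        # (3, 3, input_size, input_size) float tensor
#   y = fillers['y'](inds)                        # (3,) class indices
#   bb = fillers['bbox'](inds)                    # (3, 1, input_size, input_size) bbox maps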