import os
from typing import TYPE_CHECKING, Dict, List, Tuple
import pickle
from PIL import Image
from enum import Enum

import pandas as pd
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS

from .content_loader import ContentLoader
from ...utils.bb_generator import generate_bb_map

if TYPE_CHECKING:
    from ...configs.imagenet_configs import ImagenetConfigs


def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose:
    if data_specification == 'train':
        return transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if data_specification in ['test', 'val']:
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
        ])
    raise ValueError('Unknown data specification: {}'.format(data_specification))


def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]:
    cache_path = os.path.join('.cache/', cache_name)
    if os.path.isfile(cache_path):
        print('Loading cached samples from {}'.format(cache_path))
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('Creating cache for {}'.format(cache_name))
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    _, class_to_idx = find_classes(imagenet_dir)
    samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS)
    with open(cache_path, 'wb') as f:
        pickle.dump(samples, f)
    return samples
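
# Note (illustrative only): `samples` follows torchvision's (path, class_index)
# convention, e.g. [('<imagenet_dir>/n01440764/some_image.JPEG', 0), ...]; the paths
# shown here are placeholders, not actual dataset contents.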


class BBoxField(Enum):
    ImageId = 'ImageId'
    PredictionString = 'PredictionString'


def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame:
    return pd.read_csv(os.path.join(imagenet_root, f'LOC_{data_specification}_solution.csv'))
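
# The CSV layout assumed above matches the Kaggle ImageNet localization files
# (LOC_train_solution.csv / LOC_val_solution.csv): an `ImageId` column plus a
# `PredictionString` column holding space-separated groups of
# "<class_id> <x_min> <y_min> <x_max> <y_max>", one group per bounding box, which is
# what the extract_bb helper in __get_bbox parses below.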


class ImagenetLoader(ContentLoader):

    def __init__(self, conf: 'ImagenetConfigs', data_specification: str):
        super().__init__(conf, data_specification)
        imagenet_dir = os.path.join(conf.data_separation, data_specification)
        self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir)
        self.__transform = _get_image_transform(data_specification, conf.input_size)
        self.__bboxes = load_bbox_df(conf.data_separation, data_specification)
        self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {}

    def get_samples_names(self):
        """Sample names must be unique; the image file paths are used as the names here."""
        return np.array([path for path, _ in self.__samples])

    def get_samples_labels(self):
        return np.array([label for _, label in self.__samples])

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop]

    def get_placeholder_name_to_fill_function_dict(self):
        """Returns a dictionary mapping the placeholder names this content loader supports
        to the functions used for filling them. Each function receives the batch
        information provided by the data loader (e.g. the indices of the chosen samples,
        or, for samples with multiple elements, the indices of the chosen elements) and
        returns one array per placeholder name according to the received batch information.
        IMPORTANT: it is better to use a fixed prefix in the placeholder names, so it is
        clear which content loader they belong to."""
        return {
            'x': self.__get_x,
            'y': self.__get_y,
            'bbox': self.__get_bbox,
            'interpretations': self.__get_bbox,
        }
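
    # Hedged sketch (the surrounding data-loader framework lives elsewhere in the project
    # and is only assumed here): the framework presumably looks up each requested
    # placeholder and calls its filler with the indices of the current batch, e.g.
    #   fillers = loader.get_placeholder_name_to_fill_function_dict()
    #   batch = {name: fill(sample_inds) for name, fill in fillers.items()}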

    def __get_x(self, samples_inds: np.ndarray) -> torch.Tensor:

        def get_image(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            sample = default_loader(path)
            self.__handle_random_state(index, 'x')
            return self.__transform(sample)

        return torch.stack([get_image(index) for index in samples_inds], dim=0)

    def __get_y(self, samples_inds: np.ndarray) -> np.ndarray:
        return self.get_samples_labels()[samples_inds]

    def __handle_random_state(self, idx: int, label: str) -> None:
        # One torch RNG state is shared per sample between the 'x' and 'bb' placeholders so
        # that the random transforms (crop, flip) applied to the image and to its
        # bounding-box map stay spatially aligned: the first placeholder filled for a
        # sample records the current state, the other one restores it.
        if idx not in self.__rng_states or self.__rng_states[idx][0] == label:
            self.__rng_states[idx] = (label, torch.get_rng_state())
        else:
            torch.set_rng_state(self.__rng_states[idx][1])
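
    # Minimal standalone illustration of the save/restore trick above (torch's global RNG):
    #   state = torch.get_rng_state()
    #   a = torch.rand(3)
    #   torch.set_rng_state(state)
    #   b = torch.rand(3)   # equal to `a`, so paired random transforms come out identical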

    def __get_bbox(self, samples_inds: np.ndarray) -> torch.Tensor:

        def extract_bb(prediction_string: str) -> np.ndarray:
            # Each group of 5 tokens is "<class_id> <x_min> <y_min> <x_max> <y_max>";
            # keep only the 4 coordinates of every group.
            splitted = prediction_string.split()
            n = len(splitted) // 5
            return np.array([[float(b) for b in splitted[5 * i + 1: 5 * i + 5]] for i in range(n)])\
                .reshape((-1, 2, 2))

        def make_bb(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            image_id = os.path.basename(path).split('.')[0]
            image_size = np.array(Image.open(path).size)[::-1]  # (H, W); PIL's size is (W, H)
            bboxes = self.__bboxes[self.__bboxes[BBoxField.ImageId.value] == image_id][BBoxField.PredictionString.value]
            if len(bboxes) == 0:
                # No annotation for this image: return an all-NaN map.
                return torch.full((1, self.conf.input_size, self.conf.input_size), float('nan'))
            bboxes = bboxes.values[0]
            bboxes = extract_bb(bboxes)[..., ::-1] / image_size  # N x 2 x 2, normalized (row, col)
            start_points = bboxes[:, 0]  # N x 2
            end_points = bboxes[:, 1]  # N x 2
            bb_map = generate_bb_map(start_points, end_points, tuple(image_size))
            bb_map = Image.fromarray(bb_map)
            # Reuse the RNG state recorded when 'x' was filled so the same crop/flip is applied.
            self.__handle_random_state(index, 'bb')
            return self.__transform(bb_map)

        return torch.stack([make_bb(i) for i in samples_inds], dim=0)
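

# Hypothetical usage sketch (not part of the original module). Assuming an
# ImagenetConfigs object `conf` whose `data_separation` points at the ImageNet root
# (containing the split folders and the LOC_*_solution.csv files) and whose
# `input_size` is set, e.g. to 224:
#
#   loader = ImagenetLoader(conf, 'val')
#   fillers = loader.get_placeholder_name_to_fill_function_dict()
#   inds = np.arange(4)
#   x = fillers['x'](inds)       # float tensor, 4 x 3 x input_size x input_size, in [0, 1]
#   y = fillers['y'](inds)       # integer class labels, shape (4,)
#   bb = fillers['bbox'](inds)   # bounding-box maps transformed consistently with x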