import os
import pickle
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets.folder import default_loader, find_classes, make_dataset, IMG_EXTENSIONS

from .content_loader import ContentLoader
from ...utils.bb_generator import generate_bb_map

if TYPE_CHECKING:
    from ...configs.imagenet_configs import ImagenetConfigs


def _get_image_transform(data_specification: str, input_size: int) -> transforms.Compose:
    """Build the torchvision preprocessing pipeline for one data split.

    Args:
        data_specification: ``'train'`` selects the randomised pipeline
            (random resized crop + random horizontal flip); ``'test'`` and
            ``'val'`` select the deterministic one (resize to 256 + center crop).
        input_size: Side length passed to the crop transform.

    Returns:
        The composed transform, ending in ``ToTensor``.

    Raises:
        ValueError: If ``data_specification`` is none of the values above.
    """
    if data_specification == 'train':
        return transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    if data_specification in ('test', 'val'):
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
        ])

    raise ValueError('Unknown data specification: {}'.format(data_specification))


def load_samples_cache(cache_name: str, imagenet_dir: str) -> List[Tuple[str, int]]:
    """Return the ``(path, class_index)`` samples of ``imagenet_dir``, cached on disk.

    The sample list is pickled to ``.cache/<cache_name>`` (relative to the
    current working directory) the first time it is built and read back from
    there on later calls.

    Args:
        cache_name: File name of the cache entry inside ``.cache/``.
        imagenet_dir: Directory scanned with torchvision's ``find_classes`` /
            ``make_dataset`` when no cache file exists.

    Returns:
        The list of ``(image_path, class_index)`` tuples.
    """
    cache_path = os.path.join('.cache/', cache_name)
    print(cache_path)

    if os.path.isfile(cache_path):
        print('Loading cached samples from {}'.format(cache_path))
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only safe as long as the .cache directory is trusted.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('Creating cache for {}'.format(cache_name))
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    _, class_to_idx = find_classes(imagenet_dir)
    samples = make_dataset(imagenet_dir, class_to_idx, IMG_EXTENSIONS)
    with open(cache_path, 'wb') as f:
        pickle.dump(samples, f)
    return samples


class BBoxField(Enum):
    """Column names read from the bounding-box solution CSV."""

    ImageId = 'ImageId'
    PredictionString = 'PredictionString'


def load_bbox_df(imagenet_root: str, data_specification: str) -> pd.DataFrame:
    """Read ``LOC_<data_specification>_solution.csv`` from ``imagenet_root``."""
    return pd.read_csv(os.path.join(imagenet_root, f'LOC_{data_specification}_solution.csv'))


class ImagenetLoader(ContentLoader):
    """Content loader serving ImageNet images, labels and bounding-box maps."""

    def __init__(self, conf: 'ImagenetConfigs', data_specification: str):
        """Load the sample list, the image transform and the bounding-box table.

        Args:
            conf: Configuration object; ``data_separation`` (dataset root) and
                ``input_size`` (crop side length) are read from it.
            data_specification: Split name; also the sub-directory of the
                dataset root that holds the images.
        """
        super().__init__(conf, data_specification)
        imagenet_dir = os.path.join(conf.data_separation, data_specification)
        self.__samples = load_samples_cache(f'imagenet.{data_specification}', imagenet_dir)
        # Kept so that placeholder tensors match the spatial size the
        # transform produces.
        self.__input_size = conf.input_size
        self.__transform = _get_image_transform(data_specification, self.__input_size)
        self.__bboxes = load_bbox_df(conf.data_separation, data_specification)
        # sample index -> (placeholder label that saved it, torch RNG state)
        self.__rng_states: Dict[int, Tuple[str, torch.Tensor]] = {}

    def get_samples_names(self):
        """Return the sample names as an array; the image paths are used as names."""
        return np.array([path for path, _ in self.__samples])

    def get_samples_labels(self):
        """Return the class index of every sample as an array."""
        return np.array([label for _, label in self.__samples])

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Remove every sample whose entry in ``drop_mask`` is truthy.

        Args:
            drop_mask: One flag per current sample, in sample order.
        """
        self.__samples = [sample for sample, drop in zip(self.__samples, drop_mask) if not drop]
        # RNG snapshots are keyed by sample index; indices shift after a drop,
        # so stale snapshots would be replayed for the wrong sample.
        self.__rng_states.clear()

    def get_placeholder_name_to_fill_function_dict(self):
        """Map each supported placeholder name to the function that fills it.

        Every fill function receives the indices of the samples in the current
        batch and returns the batched values for its placeholder.

        Returns:
            A dict with the keys ``'x'`` (images), ``'y'`` (labels), ``'bbox'``
            and ``'interpretations'`` (both bounding-box maps).
        """
        return {
            'x': self.__get_x,
            'y': self.__get_y,
            'bbox': self.__get_bbox,
            'interpretations': self.__get_bbox,
        }

    def __get_x(self, samples_inds: np.ndarray) -> torch.Tensor:
        """Load and transform the images at ``samples_inds``, stacked on dim 0."""

        def get_image(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            sample = default_loader(path)
            # Synchronise the RNG with the matching bbox map before the
            # (possibly random) transform runs.
            self.__handle_random_state(index, 'x')
            return self.__transform(sample)

        return torch.stack([get_image(index) for index in samples_inds], dim=0)

    def __get_y(self, samples_inds: np.ndarray) -> np.ndarray:
        """Return the class indices of the samples at ``samples_inds``."""
        return self.get_samples_labels()[samples_inds]

    def __handle_random_state(self, idx: int, label: str) -> None:
        """Make the image and the bbox map of one sample share a torch RNG state.

        If no state is recorded for ``idx``, or the recorded one was saved by
        the same ``label``, the current global torch RNG state is recorded.
        Otherwise (the other placeholder recorded it) that state is restored,
        so the transform that follows starts from the same RNG state.

        Args:
            idx: Sample index.
            label: Identifier of the calling placeholder (``'x'`` or ``'bb'``).
        """
        recorded = self.__rng_states.get(idx)
        if recorded is None or recorded[0] == label:
            self.__rng_states[idx] = (label, torch.get_rng_state())
        else:
            torch.set_rng_state(recorded[1])

    def __get_bbox(self, sample_inds: np.ndarray) -> torch.Tensor:
        """Build the transformed bounding-box maps for ``sample_inds``.

        Samples without an entry in the bounding-box table get an all-NaN map
        of shape ``(1, input_size, input_size)``.

        Returns:
            The per-sample maps stacked along dim 0.
        """

        def extract_bb(prediction_string: str) -> np.ndarray:
            # The string is a flat sequence of 5-token groups; the first token
            # of each group is skipped and the other four are parsed as floats,
            # giving an array of shape (N, 2, 2).
            tokens = prediction_string.split()
            n_boxes = len(tokens) // 5
            boxes = [
                [float(value) for value in tokens[5 * i + 1: 5 * i + 5]]
                for i in range(n_boxes)
            ]
            return np.array(boxes).reshape((-1, 2, 2))

        def make_bb(index) -> torch.Tensor:
            path, _ = self.__samples[index]
            image_id = os.path.basename(path).split('.')[0]

            id_column = self.__bboxes[BBoxField.ImageId.value]
            matches = self.__bboxes[id_column == image_id][BBoxField.PredictionString.value]
            if len(matches) == 0:
                # Same shape as a transformed map so torch.stack succeeds.
                return torch.full((1, self.__input_size, self.__input_size), float('nan'))

            # Only the size is needed; close the file handle right away.
            with Image.open(path) as image:
                image_size = np.array(image.size)[::-1]  # PIL size reversed

            # Reverse each point to match image_size's order, then normalise.
            boxes = extract_bb(matches.values[0])[..., ::-1] / image_size  # N 2 2
            start_points = boxes[:, 0]  # N 2
            end_points = boxes[:, 1]  # N 2

            bb_map = generate_bb_map(start_points, end_points, tuple(image_size))
            bb_map = Image.fromarray(bb_map)
            # Replay the RNG state used for the image so random transforms match.
            self.__handle_random_state(index, 'bb')
            return self.__transform(bb_map)

        return torch.stack([make_bb(i) for i in sample_inds], dim=0)