from enum import Enum
import os
import glob
from typing import Callable, TYPE_CHECKING, Union

import imageio
import cv2
import numpy as np
import pandas as pd

from .content_loader import ContentLoader

if TYPE_CHECKING:
    from ...configs.celeba_configs import CelebAConfigs


class CelebATag(Enum):
    """Column names of the CelebA metadata table.

    Each member's value is the exact column header looked up in the loaded
    metadata. The members cover the binary attributes, the bounding box, the
    partition column and the facial landmarks.
    """
    # NOTE(review): the published CelebA attribute list also has a
    # 'Heavy_Makeup' column, which has no member here -- confirm whether the
    # omission is intentional before relying on this enum being exhaustive.
    FiveOClockShadowTag = '5_o_Clock_Shadow'
    ArchedEyebrowsTag = 'Arched_Eyebrows'
    AttractiveTag = 'Attractive'
    BagsUnderEyesTag = 'Bags_Under_Eyes'
    BaldTag = 'Bald'
    BangsTag = 'Bangs'
    BigLipsTag = 'Big_Lips'
    BigNoseTag = 'Big_Nose'
    BlackHairTag = 'Black_Hair'
    BlondHairTag = 'Blond_Hair'
    BlurryTag = 'Blurry'
    BrownHairTag = 'Brown_Hair'
    BushyEyebrowsTag = 'Bushy_Eyebrows'
    ChubbyTag = 'Chubby'
    DoubleChinTag = 'Double_Chin'
    EyeglassesTag = 'Eyeglasses'
    GoateeTag = 'Goatee'
    GrayHairTag = 'Gray_Hair'
    HighCheekbonesTag = 'High_Cheekbones'
    MaleTag = 'Male'
    MouthSlightlyOpenTag = 'Mouth_Slightly_Open'
    MustacheTag = 'Mustache'
    NarrowEyesTag = 'Narrow_Eyes'
    NoBeardTag = 'No_Beard'
    OvalFaceTag = 'Oval_Face'
    PaleSkinTag = 'Pale_Skin'
    PointyNoseTag = 'Pointy_Nose'
    RecedingHairlineTag = 'Receding_Hairline'
    RosyCheeksTag = 'Rosy_Cheeks'
    SideburnsTag = 'Sideburns'
    SmilingTag = 'Smiling'
    StraightHairTag = 'Straight_Hair'
    WavyHairTag = 'Wavy_Hair'
    WearingEarringsTag = 'Wearing_Earrings'
    WearingHatTag = 'Wearing_Hat'
    WearingLipstickTag = 'Wearing_Lipstick'
    WearingNecklaceTag = 'Wearing_Necklace'
    WearingNecktieTag = 'Wearing_Necktie'
    YoungTag = 'Young'
    # Bounding box columns.
    BoundingBoxX = 'x_1'
    BoundingBoxY = 'y_1'
    BoundingBoxW = 'width'
    BoundingBoxH = 'height'
    # Partition column.
    Partition = 'partition'
    # Facial landmark columns.
    LeftEyeX = 'lefteye_x'
    LeftEyeY = 'lefteye_y'
    RightEyeX = 'righteye_x'
    RightEyeY = 'righteye_y'
    NoseX = 'nose_x'
    NoseY = 'nose_y'
    LeftMouthX = 'leftmouth_x'
    LeftMouthY = 'leftmouth_y'
    RightMouthX = 'rightmouth_x'
    RightMouthY = 'rightmouth_y'


class SampleField(Enum):
    """Non-tag columns of the metadata table used by ``CelebALoader``."""
    # Image path (relative to ``conf.data_root`` in metadata mode, full path
    # in directory mode).
    ImageId = 'image_id'
    # Name of the split a row belongs to; used to filter rows and then dropped.
    Specification = 'specification'


class CelebALoader(ContentLoader):
    """Content loader serving CelebA images and metadata columns.

    ``data_specification`` selects the samples in one of two modes:

    * Metadata mode (no path separator in the string): the rows of the
      tab-separated ``conf.dataset_metadata`` file whose ``specification``
      column equals it. Image ids are joined with ``conf.data_root``.
    * Directory mode (string contains a path separator): every ``*.jpg`` in
      that directory, in sorted order. There is no metadata, so tag getters
      return zeros.
    """

    def __init__(self, conf: 'CelebAConfigs', data_specification: str):
        """Load the sample table selected by ``data_specification``.

        Args:
            conf: CelebA configuration. ``dataset_metadata``, ``data_root``,
                ``input_size`` and ``main_tag`` are read from it.
            data_specification: split name or image directory (see class doc).
        """
        super().__init__(conf, data_specification)
        self.conf = conf
        # True in metadata mode; _load_metadata switches it off for directory mode.
        self._datasep = True
        self._samples_metadata = self._load_metadata(data_specification)
        self._data_root = conf.data_root
        self._warning_count = 0

    def _load_metadata(self, data_specification: str) -> pd.DataFrame:
        """Build the per-sample table for ``data_specification``.

        Returns:
            A DataFrame with a fresh 0..N-1 index. In directory mode it only
            has the ``image_id`` column (full file paths).
        """
        if '/' in data_specification or os.sep in data_specification:
            self._datasep = False
            # glob's order is filesystem-dependent; sort so the sample order
            # (and thus any index-based split/sampling) is reproducible.
            image_paths = sorted(glob.glob(os.path.join(data_specification, '*.jpg')))
            return pd.DataFrame({SampleField.ImageId.value: image_paths})

        metadata: pd.DataFrame = pd.read_csv(self.conf.dataset_metadata, sep='\t')
        spec_column = SampleField.Specification.value
        metadata = metadata[metadata[spec_column] == data_specification]
        metadata = metadata.drop(spec_column, axis=1)
        # drop=True discards the old index instead of turning it into an
        # 'index' column (which could clash with a real column of that name).
        return metadata.reset_index(drop=True)

    def get_samples_names(self):
        """Return the ``image_id`` value of every sample, in sample order."""
        return self._samples_metadata[SampleField.ImageId.value].values

    def get_samples_labels(self):
        """Return one label per sample.

        Without ``conf.main_tag`` every sample gets the dummy label 1, since
        CelebA samples carry many labels and none is singled out. With a main
        tag, that column's values are returned. In directory mode there is no
        such column, so zeros are returned, matching the tag getters.
        """
        n_samples = len(self._samples_metadata)
        if self.conf.main_tag is None:
            return np.ones((n_samples,))
        if not self._datasep:
            return np.zeros((n_samples,))
        return self._samples_metadata[self.conf.main_tag.value].values

    def drop_samples(self, drop_mask: np.ndarray) -> None:
        """Remove the samples whose entry in ``drop_mask`` is truthy.

        ``drop_mask`` is aligned with the current sample order. The survivors
        keep their relative order and get a fresh 0..N-1 index.
        """
        keep_mask = np.logical_not(drop_mask)
        self._samples_metadata = self._samples_metadata[keep_mask].reset_index(drop=True)

    def get_placeholder_name_to_fill_function_dict(self):
        """Map each supported placeholder name to the function that fills it.

        Every fill function receives an ``np.ndarray`` of positional sample
        indices (the current batch) and returns an array with one entry per
        index:

        * ``'x'``: the image batch, see ``_get_x``.
        * the name of each ``CelebATag`` member (e.g. ``'SmilingTag'``): the
          values of that metadata column.

        IMPORTANT: placeholder names should preferably share a fixed prefix so
        it is clear which content loader they belong to.
        """
        return {
            'x': self._get_x,
            **{tag.name: self._generate_tag_getter(tag) for tag in CelebATag},
        }

    def _get_x(self, samples_inds: np.ndarray) -> np.ndarray:
        """Load, resize and scale the images of a batch.

        Returns:
            Float array of shape (B, 3, input_size, input_size). Pixel values
            are divided by 255 (so in [0, 1], assuming 8-bit images).
        """
        size = self.conf.input_size
        if len(samples_inds) == 0:
            # np.stack rejects an empty sequence; return a well-shaped empty batch.
            return np.zeros((0, 3, size, size), dtype=float)
        images = np.stack([self._read_image(index) for index in samples_inds], axis=0)  # B I I 3
        images = images.transpose((0, 3, 1, 2))  # B 3 I I
        return images.astype(float) / 255.

    def _read_image(self, sample_index: int) -> np.ndarray:
        """Read one sample's image as a (input_size, input_size, 3) array."""
        filepath = self._samples_metadata.iloc[sample_index][SampleField.ImageId.value]
        if self._datasep:
            # In metadata mode image ids are joined with the dataset root.
            filepath = os.path.join(self._data_root, filepath)
        image = np.asarray(imageio.imread(filepath))  # H W [C]
        # Normalise channel count to 3 so batches stack and transpose cleanly.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        elif image.shape[2] == 4:
            image = image[:, :, :3]  # drop alpha
        image = cv2.resize(image, (self.conf.input_size, self.conf.input_size))  # I I 3
        return image

    def _generate_tag_getter(self, tag: CelebATag) -> Callable[[np.ndarray], np.ndarray]:
        """Build the fill function returning column ``tag.value`` for a batch.

        In directory mode there is no metadata, so the getter returns zeros.
        """
        def get_tag(samples_inds: np.ndarray) -> np.ndarray:
            if not self._datasep:
                return np.zeros(len(samples_inds))
            return self._samples_metadata.iloc[samples_inds][tag.value].values

        return get_tag