You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

celeba_loader.py 6.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  from enum import Enum
  import glob
  import os
  from typing import Callable, TYPE_CHECKING, Union

  import cv2
  import imageio
  import numpy as np
  import pandas as pd

  from .content_loader import ContentLoader

  if TYPE_CHECKING:
      from ...configs.celeba_configs import CelebAConfigs
  class CelebATag(Enum):
      """Metadata columns that can be requested per sample.

      Each member's value is the name of a column in the samples' metadata
      table. ``CelebALoader`` registers one placeholder per member, keyed by the
      member's *name*, and fills it from the column named by the member's
      *value*.
      """

      # Face attribute columns.
      # NOTE(review): the published CelebA attribute list also contains
      # 'Heavy_Makeup', which has no member here -- confirm the omission is
      # intentional.
      FiveOClockShadowTag = '5_o_Clock_Shadow'
      ArchedEyebrowsTag = 'Arched_Eyebrows'
      AttractiveTag = 'Attractive'
      BagsUnderEyesTag = 'Bags_Under_Eyes'
      BaldTag = 'Bald'
      BangsTag = 'Bangs'
      BigLipsTag = 'Big_Lips'
      BigNoseTag = 'Big_Nose'
      BlackHairTag = 'Black_Hair'
      BlondHairTag = 'Blond_Hair'
      BlurryTag = 'Blurry'
      BrownHairTag = 'Brown_Hair'
      BushyEyebrowsTag = 'Bushy_Eyebrows'
      ChubbyTag = 'Chubby'
      DoubleChinTag = 'Double_Chin'
      EyeglassesTag = 'Eyeglasses'
      GoateeTag = 'Goatee'
      GrayHairTag = 'Gray_Hair'
      HighCheekbonesTag = 'High_Cheekbones'
      MaleTag = 'Male'
      MouthSlightlyOpenTag = 'Mouth_Slightly_Open'
      MustacheTag = 'Mustache'
      NarrowEyesTag = 'Narrow_Eyes'
      NoBeardTag = 'No_Beard'
      OvalFaceTag = 'Oval_Face'
      PaleSkinTag = 'Pale_Skin'
      PointyNoseTag = 'Pointy_Nose'
      RecedingHairlineTag = 'Receding_Hairline'
      RosyCheeksTag = 'Rosy_Cheeks'
      SideburnsTag = 'Sideburns'
      SmilingTag = 'Smiling'
      StraightHairTag = 'Straight_Hair'
      WavyHairTag = 'Wavy_Hair'
      WearingEarringsTag = 'Wearing_Earrings'
      WearingHatTag = 'Wearing_Hat'
      WearingLipstickTag = 'Wearing_Lipstick'
      WearingNecklaceTag = 'Wearing_Necklace'
      WearingNecktieTag = 'Wearing_Necktie'
      YoungTag = 'Young'

      # Bounding box columns.
      BoundingBoxX = 'x_1'
      BoundingBoxY = 'y_1'
      BoundingBoxW = 'width'
      BoundingBoxH = 'height'

      # Partition column.
      Partition = 'partition'

      # Landmark coordinate columns.
      LeftEyeX = 'lefteye_x'
      LeftEyeY = 'lefteye_y'
      RightEyeX = 'righteye_x'
      RightEyeY = 'righteye_y'
      NoseX = 'nose_x'
      NoseY = 'nose_y'
      LeftMouthX = 'leftmouth_x'
      LeftMouthY = 'leftmouth_y'
      RightMouthX = 'rightmouth_x'
      RightMouthY = 'rightmouth_y'
  class SampleField(Enum):
      """Names of the bookkeeping columns in the samples' metadata table."""

      # Column holding each sample's image file path.
      ImageId = 'image_id'
      # Column compared against the requested data specification when the
      # metadata file is filtered; it is dropped from the table afterwards.
      Specification = 'specification'
  class CelebALoader(ContentLoader):
      """Content loader serving CelebA images and their metadata columns.

      The samples come from one of two sources, selected by ``data_specification``:

      * a string without ``'/'``: the rows of the tab-separated file
        ``conf.dataset_metadata`` whose specification column equals that string;
        image ids are then joined onto ``conf.data_root`` when read.
      * a string containing ``'/'``: treated as a directory whose ``*.jpg`` files
        are the samples. No tag columns exist in this mode, so every tag getter
        returns zeros.
      """

      def __init__(self, conf: 'CelebAConfigs', data_specification: str):
          """Build the samples table for ``data_specification``.

          Args:
              conf: configuration object; this class reads ``dataset_metadata``,
                  ``data_root``, ``input_size`` and ``main_tag`` from it.
              data_specification: either a specification name to select from the
                  metadata file, or a directory path (must contain ``'/'``).
          """
          super().__init__(conf, data_specification)
          self.conf = conf
          # True while image ids are relative to conf.data_root. It has to be set
          # before _load_metadata runs, because that call switches it off in
          # directory mode.
          self._datasep = True
          self._samples_metadata = self._load_metadata(data_specification)
          self._data_root = conf.data_root
          self._warning_count = 0

      def _load_metadata(self, data_specification: str) -> pd.DataFrame:
          """Return the samples table selected by ``data_specification``.

          In directory mode the table has a single image-id column holding the
          paths of the ``*.jpg`` files found; otherwise it is the matching slice of
          the metadata file without the specification column, re-indexed from 0.
          """
          if '/' in data_specification:
              self._datasep = False
              # glob returns files in arbitrary, filesystem-dependent order;
              # sorting keeps the sample order reproducible across runs.
              image_paths = sorted(glob.glob(os.path.join(data_specification, '*.jpg')))
              return pd.DataFrame({SampleField.ImageId.value: image_paths})

          metadata: pd.DataFrame = pd.read_csv(self.conf.dataset_metadata, sep='\t')
          spec_column = SampleField.Specification.value
          metadata = metadata[metadata[spec_column] == data_specification]
          metadata = metadata.drop(columns=spec_column)
          # drop=True discards the old index directly instead of materialising it
          # as an 'index' column that then has to be removed again.
          return metadata.reset_index(drop=True)

      def get_samples_names(self):
          """Return the image ids of all current samples as an array."""
          return self._samples_metadata[SampleField.ImageId.value].values

      def get_samples_labels(self):
          r"""Return one label per sample.

          When ``conf.main_tag`` is None there is no single label (the dataset is
          multi-label), so a dummy array of ones is returned; otherwise the values
          of the column named by ``conf.main_tag`` are returned.
          """
          if self.conf.main_tag is None:
              return np.ones((len(self._samples_metadata),))
          return self._samples_metadata[self.conf.main_tag.value].values

      def drop_samples(self, drop_mask: np.ndarray) -> None:
          """Remove the samples whose entry in ``drop_mask`` is true."""
          self._samples_metadata = self._samples_metadata[np.logical_not(drop_mask)]

      def get_placeholder_name_to_fill_function_dict(self):
          """Return a mapping from placeholder name to the function that fills it.

          The supported placeholders are ``'x'`` (the images) and one entry per
          ``CelebATag`` member, keyed by the member's name. Each function receives
          an array of sample indices for the current batch and returns one array
          for that placeholder.
          """
          fill_functions = {'x': self._get_x}
          for tag in CelebATag:
              fill_functions[tag.name] = self._generate_tag_getter(tag)
          return fill_functions

      def _get_x(self, samples_inds: np.ndarray) -> np.ndarray:
          """Load the images at ``samples_inds`` as a float batch scaled by 1/255."""
          images = tuple(self._read_image(index) for index in samples_inds)
          images = np.stack(images, axis=0)  # B I I 3
          images = images.transpose((0, 3, 1, 2))  # B 3 I I
          images = images.astype(float) / 255.
          return images

      def _read_image(self, sample_index: int) -> np.ndarray:
          """Read one image and resize it to ``conf.input_size`` on both sides.

          ``sample_index`` is positional (``iloc``), so it stays valid after
          ``drop_samples`` has removed rows.
          """
          filepath = self._samples_metadata.iloc[sample_index][SampleField.ImageId.value]
          if self._datasep:
              # Image ids from the metadata file are relative to the data root;
              # in directory mode they are already usable paths.
              filepath = os.path.join(self._data_root, filepath)
          image = imageio.imread(filepath)  # H W 3
          image = cv2.resize(image, (self.conf.input_size, self.conf.input_size))  # I I 3
          return image

      def _generate_tag_getter(self, tag: CelebATag) -> Callable[[np.ndarray], np.ndarray]:
          """Return a fill function that yields the ``tag`` column for a batch.

          The returned function takes the batch's sample indices. In directory mode
          there is no metadata to read from, so it returns zeros of matching length.
          """
          def get_tag(samples_inds: np.ndarray) -> np.ndarray:
              if not self._datasep:
                  return np.zeros(len(samples_inds))
              return self._samples_metadata.iloc[samples_inds][tag.value].values

          return get_tag