You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 3.2KB

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from config import Config
from fragment_splitter import CustomFragmentLoader
from transformation import get_transformation
from utils import show_and_wait
  9. class ThyroidDataset(Dataset):
  10. def __init__(self, image_paths_labels_list, class_to_index, transform=None, force_to_size_with_padding=512):
  11. super().__init__()
  12. self.class_to_idx_dict = class_to_index
  13. self.force_to_size_with_padding = force_to_size_with_padding
  14. self.transform = transform
  15. self.samples = self._make_dataset(image_paths_labels_list)
  16. self.class_weights = self._calculate_class_weights(image_paths_labels_list)
  17. def _calculate_class_weights(self, image_paths_labels_list):
  18. class_counts = {}
  19. for image_path, (label, slide) in image_paths_labels_list:
  20. class_counts[label] = class_counts.get(label, 0) + 1
  21. class_weights = [
  22. (self.class_to_idx_dict.get(c, None), len(image_paths_labels_list) / (len(class_counts) * v)) for c, v
  23. in
  24. class_counts.items()]
  25. class_weights.sort()
  26. return [item[1] for item in class_weights]
  27. def _make_dataset(self, image_paths_labels_list):
  28. images = []
  29. for image_path, (label, slide) in image_paths_labels_list:
  30. if not os.path.exists(os.path.abspath(image_path)):
  31. raise (RuntimeError(f"{image_path} not found."))
  32. item = (image_path, (self.class_to_idx_dict.get(label, "Unknown label"), slide))
  33. images.append(item)
  34. return images
  35. def __len__(self):
  36. return len(self.samples)
  37. def __getitem__(self, index):
  38. path, target = self.samples[index]
  39. image =
  40. image = image.convert('RGB')
  41. image = self.add_margin(image)
  42. image = np.array(image)
  43. if self.transform is not None:
  44. # show_and_wait(image, name=f"./transformations/{index}-original", wait=False, save=True)
  45. image = self.transform(image=image)['image']
  46. # image_show = np.moveaxis(image.cpu().detach().numpy(), 0, -1)
  47. # show_and_wait(image_show, name=f"./transformations/{index}-transformed", save=True)
  48. else:
  49. transform = get_transformation(augmentation="min")
  50. image = transform(image=image)['image']
  51. return image, target
  52. def add_margin(self, pil_img):
  53. width, height = pil_img.size
  54. new_width = self.force_to_size_with_padding
  55. new_height = self.force_to_size_with_padding
  56. result ="RGB", (new_width, new_height), (0, 0, 0))
  57. top_padding = (new_height - height) // 2
  58. left_padding = (new_width - width) // 2
  59. result.paste(pil_img, (left_padding, top_padding))
  60. return result
  61. if __name__ == '__main__':
  62. class_idx_dict = Config.class_idx_dict
  63. datasets_folder = ["stanford_tissue_microarray", "papsociaty"]
  64. train, val, test = CustomFragmentLoader(datasets_folder).load_image_path_and_labels_and_split()
  65. train_ds = ThyroidDataset(train, class_idx_dict)
  66. test_ds = ThyroidDataset(test, class_idx_dict)
  67. val_ds = ThyroidDataset(val, class_idx_dict)
  68. res = train_ds.__getitem__(0)
  69. print(res)