You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

thyroid_dataset.py 3.2KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import numpy as np
  3. from PIL import Image
  4. from torch.utils.data import Dataset
  5. from config import Config
  6. from fragment_splitter import CustomFragmentLoader
  7. from transformation import get_transformation
  8. from utils import show_and_wait
  9. class ThyroidDataset(Dataset):
  10. def __init__(self, image_paths_labels_list, class_to_index, transform=None, force_to_size_with_padding=512):
  11. super().__init__()
  12. self.class_to_idx_dict = class_to_index
  13. self.force_to_size_with_padding = force_to_size_with_padding
  14. self.transform = transform
  15. self.samples = self._make_dataset(image_paths_labels_list)
  16. self.class_weights = self._calculate_class_weights(image_paths_labels_list)
  17. def _calculate_class_weights(self, image_paths_labels_list):
  18. class_counts = {}
  19. for image_path, (label, slide) in image_paths_labels_list:
  20. class_counts[label] = class_counts.get(label, 0) + 1
  21. class_weights = [
  22. (self.class_to_idx_dict.get(c, None), len(image_paths_labels_list) / (len(class_counts) * v)) for c, v
  23. in
  24. class_counts.items()]
  25. class_weights.sort()
  26. return [item[1] for item in class_weights]
  27. def _make_dataset(self, image_paths_labels_list):
  28. images = []
  29. for image_path, (label, slide) in image_paths_labels_list:
  30. if not os.path.exists(os.path.abspath(image_path)):
  31. raise (RuntimeError(f"{image_path} not found."))
  32. item = (image_path, (self.class_to_idx_dict.get(label, "Unknown label"), slide))
  33. images.append(item)
  34. return images
  35. def __len__(self):
  36. return len(self.samples)
  37. def __getitem__(self, index):
  38. path, target = self.samples[index]
  39. image = Image.open(path)
  40. image = image.convert('RGB')
  41. image = self.add_margin(image)
  42. image = np.array(image)
  43. if self.transform is not None:
  44. # show_and_wait(image, name=f"./transformations/{index}-original", wait=False, save=True)
  45. image = self.transform(image=image)['image']
  46. # image_show = np.moveaxis(image.cpu().detach().numpy(), 0, -1)
  47. # show_and_wait(image_show, name=f"./transformations/{index}-transformed", save=True)
  48. else:
  49. transform = get_transformation(augmentation="min")
  50. image = transform(image=image)['image']
  51. return image, target
  52. def add_margin(self, pil_img):
  53. width, height = pil_img.size
  54. new_width = self.force_to_size_with_padding
  55. new_height = self.force_to_size_with_padding
  56. result = Image.new("RGB", (new_width, new_height), (0, 0, 0))
  57. top_padding = (new_height - height) // 2
  58. left_padding = (new_width - width) // 2
  59. result.paste(pil_img, (left_padding, top_padding))
  60. return result
  61. if __name__ == '__main__':
  62. class_idx_dict = Config.class_idx_dict
  63. datasets_folder = ["stanford_tissue_microarray", "papsociaty"]
  64. train, val, test = CustomFragmentLoader(datasets_folder).load_image_path_and_labels_and_split()
  65. train_ds = ThyroidDataset(train, class_idx_dict)
  66. test_ds = ThyroidDataset(test, class_idx_dict)
  67. val_ds = ThyroidDataset(val, class_idx_dict)
  68. res = train_ds.__getitem__(0)
  69. print(res)