
data.py

import os
from time import time

import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import DataLoader

from kernel.pre_processer import PreProcessing

def apply_median_filter(input_matrix, kernel_size=5):
    # Apply a median filter (good at suppressing salt-and-pepper noise).
    filtered_matrix = cv2.medianBlur(input_matrix.astype(np.uint8), kernel_size)
    return filtered_matrix.astype(np.float32)


def apply_gaussian_filter(input_matrix, kernel_size=(7, 7), sigma=0):
    # Apply a Gaussian blur; sigma=0 lets OpenCV derive it from the kernel size.
    smoothed_matrix = cv2.GaussianBlur(input_matrix, kernel_size, sigma)
    return smoothed_matrix.astype(np.float32)


def img_enhance(img2, over_coef=0.8, under_coef=0.7):
    # Unsharp-mask-style sharpening: denoise, subtract a fraction of the
    # blurred image, then clip low/high values per slice.
    img2 = apply_median_filter(img2)
    img_blur = apply_gaussian_filter(img2)
    img2 = img2 - 0.8 * img_blur
    # Per-slice statistics; keepdims so they broadcast against (S, H, W).
    img_mean = np.mean(img2, axis=(1, 2), keepdims=True)
    img_max = np.amax(img2, axis=(1, 2), keepdims=True)
    # Upper clipping point, interpolated between the per-slice mean and max.
    val = (img_max - img_mean) * over_coef + img_mean
    # Raise values below mean * under_coef to that floor...
    img2 = (img2 < img_mean * under_coef).astype(np.float32) * img_mean * under_coef + (
        img2 >= img_mean * under_coef
    ).astype(np.float32) * img2
    # ...and cap values above val.
    img2 = (img2 <= val).astype(np.float32) * img2 + (img2 > val).astype(np.float32) * val
    return img2
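
# Illustrative example (assumed shapes): enhance a stack of two 256x256 slices.
#   slices = (np.random.rand(2, 256, 256) * 255).astype(np.float32)
#   enhanced = img_enhance(slices, over_coef=0.4, under_coef=0.5)
#   enhanced.shape  # -> (2, 256, 256)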

def normalize_and_pad(x, img_size):
    """Normalize pixel values and pad to a square input."""
    pixel_mean = torch.tensor([[[[123.675]], [[116.28]], [[103.53]]]])
    pixel_std = torch.tensor([[[[58.395]], [[57.12]], [[57.375]]]])
    # Normalize colors
    x = (x - pixel_mean) / pixel_std
    # Pad right/bottom so (h, w) becomes (img_size, img_size)
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
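
# Note: the mean/std above are Segment Anything's defaults (ImageNet statistics
# on a 0-255 pixel scale). Illustrative padding example:
#   x = torch.zeros(1, 3, 768, 1024)
#   normalize_and_pad(x, 1024).shape  # -> torch.Size([1, 3, 1024, 1024])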

def preprocess(img_enhanced, img_enhance_times=1, over_coef=0.4, under_coef=0.5):
    # Shift each slice to start at 0, then scale to [0, 1].
    img_enhanced -= torch.amin(img_enhanced, dim=(1, 2), keepdim=True)
    img_max = torch.amax(img_enhanced, dim=(1, 2), keepdim=True)
    img_max[img_max == 0] = 1  # avoid division by zero on empty slices
    img_enhanced = img_enhanced / img_max
    # CLAHE expects a channel dimension: (S, H, W) -> (S, 1, H, W).
    img_enhanced = img_enhanced.unsqueeze(1)
    img_enhanced = PreProcessing.CLAHE(img_enhanced, clip_limit=9.0, grid_size=(4, 4))
    img_enhanced = img_enhanced[0]
    # Optional extra sharpening, currently disabled:
    # for i in range(img_enhance_times):
    #     img_enhanced = img_enhance(img_enhanced.astype(np.float32), over_coef=over_coef, under_coef=under_coef)
    # Rescale to [0, 255] and convert to uint8.
    img_enhanced -= torch.amin(img_enhanced, dim=(1, 2), keepdim=True)
    large_img = (
        img_enhanced / torch.amax(img_enhanced, dim=(1, 2), keepdim=True) * 255
    ).type(torch.uint8)
    return large_img
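
# Shape contract (illustrative; CLAHE comes from the project-local
# kernel.pre_processer module):
#   out = preprocess(torch.rand(1, 512, 512))
#   out.shape, out.dtype  # -> torch.Size([1, 512, 512]), torch.uint8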

def prepare(large_img, target_image_size):
    # Replicate the grayscale channel to 3 channels: (S, H, W) -> (S, 3, H, W).
    large_img = rearrange(large_img, "S H W -> S 1 H W")
    large_img = torch.cat([large_img, large_img, large_img], dim=1).float()
    # Resize so the longest side matches the model input, then normalize and pad.
    transform = ResizeLongestSide(target_image_size)
    large_img = transform.apply_image_torch(large_img)
    large_img = normalize_and_pad(large_img, target_image_size)
    return large_img

def process_single_image(image_path, target_image_size):
    # Load the image: PNG/JPG as grayscale, anything else as a NumPy array.
    if image_path.endswith((".png", ".jpg")):
        data = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    else:
        data = np.load(image_path)
    x = rearrange(data, "H W -> 1 H W")
    x = torch.tensor(x)
    # Apply preprocessing and resize/pad for the model.
    x = preprocess(x)
    x = prepare(x, target_image_size)
    return x
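
# Usage sketch (hedged; the checkpoint path mirrors the __main__ block below):
#   sam = sam_model_registry["vit_h"](checkpoint="checkpoints/sam_vit_h_4b8939.pth")
#   x = process_single_image("scan.npy", target_image_size=1024)  # (1, 3, 1024, 1024)
#   with torch.no_grad():
#       emb = sam.image_encoder(x)  # (1, 256, 64, 64) ViT-H image embedding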

class PanDataset:
    def __init__(
        self,
        images_dirs,
        labels_dirs,
        datasets,
        target_image_size,
        slice_per_image,
        train=True,
        ratio=0.9,
        augmentation=None,
    ):
        self.data_set_names = []
        self.labels_path = []
        self.images_path = []
        for labels_dir, images_dir, dataset_name in zip(
            labels_dirs, images_dirs, datasets
        ):
            # Sort before slicing so the train/val split is deterministic
            # (os.listdir returns entries in arbitrary order).
            label_files = sorted(os.listdir(labels_dir))
            image_files = sorted(os.listdir(images_dir))
            label_split = int(len(label_files) * ratio)
            image_split = int(len(image_files) * ratio)
            if train:
                label_files = label_files[:label_split]
                image_files = image_files[:image_split]
            else:
                label_files = label_files[label_split:]
                image_files = image_files[image_split:]
            # Store the full dataset name per sample; __getitem__ compares
            # against names like "NIH_PNG".
            self.data_set_names.extend([dataset_name for _ in label_files])
            self.labels_path.extend(
                [os.path.join(labels_dir, item) for item in label_files]
            )
            self.images_path.extend(
                [os.path.join(images_dir, item) for item in image_files]
            )
        self.target_image_size = target_image_size
        self.datasets = datasets
        self.slice_per_image = slice_per_image
        self.augmentation = augmentation
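
    # Split sketch (illustrative): with 100 label files and ratio=0.9, the
    # train=True split takes the first 90 sorted files and train=False the
    # remaining 10, so the two splits are disjoint and reproducible.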

    def __getitem__(self, idx):
        data = np.load(self.images_path[idx])
        raw_data = data
        labels = np.load(self.labels_path[idx])
        if self.data_set_names[idx] == "NIH_PNG":
            x = rearrange(data.T, "H W -> 1 H W")
            y = rearrange(labels.T, "H W -> 1 H W")
            y = (y == 1).astype(np.uint8)  # keep only class 1 as foreground
        elif self.data_set_names[idx] == "Abdment1kPNG":
            x = rearrange(data, "H W -> 1 H W")
            y = rearrange(labels, "H W -> 1 H W")
            y = (y == 4).astype(np.uint8)  # keep only class 4 as foreground
        else:
            raise ValueError("Incorrect dataset name")
        x = torch.tensor(x)
        y = torch.tensor(y)
        x = preprocess(x)
        x, y = self.apply_augmentation(x.numpy(), y.numpy())
        # F.interpolate needs a float tensor and a channel dimension.
        y = F.interpolate(y.unsqueeze(1).float(), size=self.target_image_size)
        x = prepare(x, self.target_image_size)
        return x, y, raw_data

    def collate_fn(self, data):
        images, labels, raw_data = zip(*data)
        images = torch.cat(images, dim=0)
        labels = torch.cat(labels, dim=0)
        # raw_data is left as a tuple of arrays (shapes may differ per image).
        return images, labels, raw_data

    def __len__(self):
        return len(self.images_path)

    def apply_augmentation(self, image, label):
        if self.augmentation:
            # Albumentations works on numpy arrays: pass the (H, W) slice.
            augmented = self.augmentation(image=image[0], mask=label[0])
            # Convert back to tensors and restore the leading slice dimension.
            image = torch.tensor(augmented["image"]).unsqueeze(0)
            label = torch.tensor(augmented["mask"]).unsqueeze(0)
        else:
            image = torch.Tensor(image)
            label = torch.Tensor(label)
        return image, label
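
# Example (hedged): any albumentations pipeline whose transforms accept a
# single-channel (H, W) image and mask works with apply_augmentation, e.g.:
#   aug = A.Compose([A.HorizontalFlip(p=0.5), A.Rotate(limit=10, p=0.5)])
#   dataset = PanDataset(..., augmentation=aug)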

if __name__ == "__main__":
    model_type = "vit_h"
    batch_size = 4
    num_workers = 4
    slice_per_image = 1
    image_size = 1024
    checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
    # Build the SAM model from the checkpoint (not used further in this demo).
    panc_sam_instance = sam_model_registry[model_type](checkpoint=checkpoint)
    augmentation = A.Compose(
        [
            A.Rotate(limit=10, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            A.RandomResizedCrop(1024, 1024, scale=(0.9, 1.0), p=1),
        ]
    )
    # Placeholder directories and dataset name; point these at the real data.
    train_dataset = PanDataset(
        images_dirs=["path/to/images"],
        labels_dirs=["path/to/labels"],
        datasets=["NIH_PNG"],
        target_image_size=image_size,
        slice_per_image=slice_per_image,
        train=True,
        augmentation=None,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=train_dataset.collate_fn,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
    now = time()
    for images, labels, raw_data in train_loader:
        # Visual sanity check: dump one channel of the first image. Note that
        # images are mean/std-normalized, so this is only a rough preview.
        image_numpy = images[0].permute(1, 2, 0).cpu().numpy()
        image_numpy = (image_numpy * 255).astype(np.uint8)
        cv2.imwrite("image2.png", image_numpy[:, :, 1])
        break
    # print((time() - now) / batch_size / slice_per_image)