
# data_load_group.py

import os
import random
from time import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import DataLoader
from tqdm import tqdm

from args import get_arguments
from pre_processer import PreProcessing
from utils import create_prompt_simple


def apply_median_filter(input_matrix, kernel_size=5, sigma=0):
    # Apply a median filter; medianBlur operates on a uint8 copy of the input.
    filtered_matrix = cv2.medianBlur(input_matrix.astype(np.uint8), kernel_size)
    return filtered_matrix.astype(np.float32)


def apply_gaussian_filter(input_matrix, kernel_size=(7, 7), sigma=0):
    # Gaussian smoothing; sigma=0 lets OpenCV derive sigma from the kernel size.
    smoothed_matrix = cv2.GaussianBlur(input_matrix, kernel_size, sigma)
    return smoothed_matrix.astype(np.float32)


def img_enhance(img2, over_coef=0.8, under_coef=0.7):
    """Sharpen a stack of slices shaped (S, H, W) and clip outlying intensities."""
    img2 = apply_median_filter(img2)
    img_blur = apply_gaussian_filter(img2)
    # Unsharp masking: subtract a fraction of the blurred image.
    img2 = img2 - 0.8 * img_blur
    # Per-slice statistics; keepdims=True so they broadcast against (S, H, W).
    img_mean = np.mean(img2, axis=(1, 2), keepdims=True)
    img_max = np.amax(img2, axis=(1, 2), keepdims=True)
    val = (img_max - img_mean) * over_coef + img_mean
    # Raise intensities below under_coef * mean up to that floor...
    img2 = (img2 < img_mean * under_coef).astype(np.float32) * img_mean * under_coef + (
        img2 >= img_mean * under_coef
    ).astype(np.float32) * img2
    # ...and clip intensities above val down to that ceiling.
    img2 = (img2 <= val).astype(np.float32) * img2 + (img2 > val).astype(np.float32) * val
    return img2
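
# A minimal sketch of the intended effect (toy values, illustrative only):
# demo = (np.random.rand(2, 64, 64) * 255).astype(np.float32)   # (S, H, W) stack
# out = img_enhance(demo, over_coef=0.8, under_coef=0.7)
# Each slice is median-filtered, unsharp-masked, then clipped into
# [under_coef * mean, (max - mean) * over_coef + mean].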


def normalize_and_pad(x, img_size):
    """Normalize pixel values and pad to a square input."""
    # SAM's default per-channel pixel mean/std (ImageNet statistics).
    pixel_mean = torch.tensor([[[[123.675]], [[116.28]], [[103.53]]]])
    pixel_std = torch.tensor([[[[58.395]], [[57.12]], [[57.375]]]])
    # Normalize colors
    x = (x - pixel_mean) / pixel_std
    # Pad the bottom/right edges out to img_size x img_size.
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
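
# Sketch: the input is expected as (S, 3, h, w) with the longest side already
# resized to img_size (shapes illustrative only):
# x = torch.rand(1, 3, 1024, 768) * 255
# out = normalize_and_pad(x, 1024)   # -> torch.Size([1, 3, 1024, 1024])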


def preprocess(img_enhanced, img_enhance_times=1, over_coef=0.4, under_coef=0.5):
    """Rescale a (S, H, W) slice stack to uint8 in [0, 255] after CLAHE."""
    # Shift to a non-negative range and scale to [0, 1].
    img_enhanced -= torch.min(img_enhanced)
    img_max = torch.max(img_enhanced)
    if img_max > 0:
        img_enhanced = img_enhanced / img_max
    # Add a channel dimension for CLAHE, then drop it afterwards.
    img_enhanced = img_enhanced.unsqueeze(1)
    img_enhanced = PreProcessing.CLAHE(
        img_enhanced, clip_limit=9.0, grid_size=(4, 4)
    )
    img_enhanced = img_enhanced.squeeze(1)
    # Optional extra enhancement passes (currently disabled):
    # for i in range(img_enhance_times):
    #     img_enhanced = img_enhance(img_enhanced.astype(np.float32),
    #                                over_coef=over_coef, under_coef=under_coef)
    # Per-slice min-max rescale to uint8.
    img_enhanced -= torch.amin(img_enhanced, dim=(1, 2), keepdim=True)
    large_img = (
        img_enhanced / torch.amax(img_enhanced, dim=(1, 2), keepdim=True) * 255
    ).type(torch.uint8)
    return large_img
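
# Sketch: preprocess on a single-slice tensor (values illustrative only):
# slices = torch.rand(1, 512, 512)            # (S, H, W)
# uint8_slices = preprocess(slices)           # (S, H, W), torch.uint8, in [0, 255]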


def prepare(large_img, target_image_size):
    """Resize, replicate to three channels, normalize, and pad for the SAM encoder."""
    large_img = rearrange(large_img, "S H W -> S 1 H W")
    # Replicate the grayscale channel into the 3 channels the encoder expects.
    large_img = torch.cat([large_img, large_img, large_img], dim=1).float()
    transform = ResizeLongestSide(target_image_size)
    large_img = transform.apply_image_torch(large_img)
    large_img = normalize_and_pad(large_img, target_image_size)
    return large_img
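
# Sketch: (S, H, W) uint8 slices in, (S, 3, target, target) float batch out
# (shape illustrative for a square source):
# batch = prepare(uint8_slices, target_image_size=1024)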


def process_single_image(image_path, target_image_size):
    """Load one image (PNG/JPG or .npy) and run the full preprocessing pipeline."""
    if image_path.endswith(".png") or image_path.endswith(".jpg"):
        data = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if data is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
    else:
        data = np.load(image_path)
    x = torch.tensor(rearrange(data, "H W -> 1 H W"))
    # Enhance and rescale to uint8, then resize/normalize/pad for the model.
    x = preprocess(x)
    x = prepare(x, target_image_size)
    return x
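
# Sketch: end-to-end single-image path (file name illustrative only):
# x = process_single_image("subject_slice.png", target_image_size=1024)
# For a square source this yields a (1, 3, 1024, 1024) float tensor.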


class PanDataset:
    def __init__(
        self,
        dirs,
        datasets,
        target_image_size,
        slice_per_image,
        split_ratio=0.9,
        train=True,
        val=False,
        augmentation=None,
    ):
        def sort_key(name):
            # Files are named "<prefix>_<slice>_<subject>.<ext>";
            # sort by subject index, then slice index.
            return (int(name.split("_")[2].split(".")[0]), int(name.split("_")[1]))

        self.data_set_names = []
        self.labels_path = []
        self.images_path = []
        self.embedds_path = []
        self.labels_indexes = []
        self.individual_index = []
        for data_dir, dataset_name in zip(dirs, datasets):
            labels_dir = data_dir + "/labels"
            items_label = sorted(os.listdir(labels_dir), key=sort_key)
            images_dir = data_dir + "/images"
            items_image = sorted(os.listdir(images_dir), key=sort_key)
            try:
                embedds_dir = data_dir + "/embeddings"
                items_embedds = sorted(os.listdir(embedds_dir), key=sort_key)
                self.embedds_path.extend(
                    [os.path.join(embedds_dir, item) for item in items_embedds]
                )
            except FileNotFoundError:
                # Precomputed embeddings are optional.
                pass
            # One entry per subject, derived from the file names.
            subject_indexes = set()
            for item in items_label:
                subject_indexes.add(int(item.split("_")[2].split(".")[0]))
            self.labels_indexes.extend(sorted(subject_indexes))
            self.individual_index.extend(
                [int(item.split("_")[2].split(".")[0]) for item in items_label]
            )
            # Element 0 of dataset_name is the tag string compared in __getitem__.
            self.data_set_names.extend([dataset_name[0] for _ in items_label])
            self.labels_path.extend(
                [os.path.join(labels_dir, item) for item in items_label]
            )
            self.images_path.extend(
                [os.path.join(images_dir, item) for item in items_image]
            )
        self.target_image_size = target_image_size
        self.datasets = datasets
        self.slice_per_image = slice_per_image
        self.augmentation = augmentation
        self.individual_index = torch.tensor(self.individual_index)
        # Split subjects (not slices) into train/val partitions.
        split = int(split_ratio * len(self.labels_indexes))
        if val:
            self.labels_indexes = self.labels_indexes[split:]
        elif train:
            self.labels_indexes = self.labels_indexes[:split]
    def __getitem__(self, idx):
        # Gather every slice belonging to the requested subject.
        indexes = (self.individual_index == self.labels_indexes[idx]).nonzero(
            as_tuple=True
        )[0]
        batched_input = []
        for index in indexes.tolist():
            data = np.load(self.images_path[index])
            embedd = np.load(self.embedds_path[index])
            labels = np.load(self.labels_path[index])
            if self.data_set_names[index] == "NIH_PNG":
                x = data.T
                y = rearrange(labels.T, "H W -> 1 H W")
                # NIH masks are binary; class 1 marks the pancreas.
                y = (y == 1).astype(np.uint8)
            elif self.data_set_names[index] == "Abdment1k-npy":
                x = data
                y = rearrange(labels, "H W -> 1 H W")
                # In AbdomenCT-1K annotations, class 4 is the pancreas.
                y = (y == 4).astype(np.uint8)
            else:
                raise ValueError(f"Incorrect dataset name: {self.data_set_names[index]}")
            x = torch.tensor(x)
            embedd = torch.tensor(embedd)
            y = torch.tensor(y)
            # Sample point prompts on a 2x-downsampled mask, then rescale them
            # into the target image's coordinate frame.
            points, point_labels = create_prompt_simple(
                y[:, ::2, ::2].squeeze(1).float()
            )
            points *= self.target_image_size // y[:, ::2, ::2].shape[-1]
            # interpolate requires a float input; the default mode is nearest.
            y = F.interpolate(y.unsqueeze(1).float(), size=self.target_image_size)
            batched_input.append(
                {
                    "image_embedd": embedd,
                    "image": x,
                    "label": y,
                    "point_coords": points[0],
                    "point_labels": point_labels[0],
                    "original_size": (1024, 1024),
                },
            )
        return batched_input
    def collate_fn(self, data):
        # Each dataset item is already a list of per-slice dicts for one
        # subject, so return the batch unchanged instead of stacking tensors.
        return data
    def __len__(self):
        return len(self.labels_indexes)
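

# Minimal usage sketch (assumptions: each root contains images/, labels/ and
# optionally embeddings/ with files named "<prefix>_<slice>_<subject>.npy";
# the dataset tag must match a branch in __getitem__; the path below is
# hypothetical).
if __name__ == "__main__":
    dataset = PanDataset(
        dirs=["/data/AbdomenCT-1K"],       # hypothetical root directory
        datasets=[["Abdment1k-npy"]],      # one-element list, read as dataset_name[0]
        target_image_size=1024,
        slice_per_image=1,
        split_ratio=0.9,
        train=True,
    )
    loader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)
    for batch in loader:
        # batch is a list of subjects; each subject is a list of per-slice dicts.
        print(len(batch), batch[0][0]["image"].shape)
        break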