You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

save.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import matplotlib.pyplot as plt
  2. import os
  3. import numpy as np
  4. import random
  5. from segment_anything.utils.transforms import ResizeLongestSide
  6. from einops import rearrange
  7. import torch
  8. import os
  9. from segment_anything import SamPredictor, sam_model_registry
  10. from torch.utils.data import DataLoader
  11. from time import time
  12. import torch.nn.functional as F
  13. import cv2
  14. # def preprocess(image_paths, label_paths):
  15. # preprocessed_images = []
  16. # preprocessed_labels = []
  17. # for image_path, label_path in zip(image_paths, label_paths):
  18. # # Load image and label from paths
  19. # image = plt.imread(image_path)
  20. # label = plt.imread(label_path)
  21. # # Perform preprocessing steps here
  22. # # ...
  23. # preprocessed_images.append(image)
  24. # preprocessed_labels.append(label)
  25. # return preprocessed_images, preprocessed_labels
  26. class PanDataset:
  27. def __init__(self, images_dir, labels_dir, slice_per_image, train=True,**kwargs):
  28. #for Abdonomial
  29. self.images_path = sorted([os.path.join(images_dir, item[:item.rindex('.')] + '_0000.npz') for item in os.listdir(labels_dir) if item.endswith('.npz') and not item.startswith('.')])
  30. self.labels_path = sorted([os.path.join(labels_dir, item) for item in os.listdir(labels_dir) if item.endswith('.npz') and not item.startswith('.')])
  31. #for NIH
  32. # self.images_path = sorted([os.path.join(images_dir, item) for item in os.listdir(labels_dir) if item.endswith('.npy')])
  33. # self.labels_path = sorted([os.path.join(labels_dir, item) for item in os.listdir(labels_dir) if item.endswith('.npy')])
  34. N = len(self.images_path)
  35. n = int(N * 0.8)
  36. self.train = train
  37. self.slice_per_image = slice_per_image
  38. if train:
  39. self.labels_path = self.labels_path[:n]
  40. self.images_path = self.images_path[:n]
  41. else:
  42. self.labels_path = self.labels_path[n:]
  43. self.images_path = self.images_path[n:]
  44. self.args=kwargs['args']
  45. def __getitem__(self, idx):
  46. now = time()
  47. # for abdoment
  48. data = np.load(self.images_path[idx])['arr_0']
  49. labels = np.load(self.labels_path[idx])['arr_0']
  50. #for nih
  51. # data = np.load(self.images_path[idx])
  52. # labels = np.load(self.labels_path[idx])
  53. H, W, C = data.shape
  54. positive_slices = np.any(labels == 1, axis=(0, 1))
  55. # print("Load from file time = ", time() - now)
  56. now = time()
  57. # Find the first and last positive slices
  58. first_positive_slice = np.argmax(positive_slices)
  59. last_positive_slice = labels.shape[2] - np.argmax(positive_slices[::-1]) - 1
  60. dist=last_positive_slice-first_positive_slice
  61. if self.train:
  62. save_dir = self.args.images_dir # data address here
  63. labels_save_dir = self.args.labels_dir # label address here
  64. else :
  65. save_dir = self.args.test_images_dir # data address here
  66. labels_save_dir = self.args.test_labels_dir # label address here
  67. j=0
  68. for j in range(1):
  69. slice = range(len(labels[0,0,:]))
  70. # raise ValueError(labels.shape)
  71. image_paths = []
  72. label_paths = []
  73. for i, slc_idx in enumerate(slice):
  74. # Saving Image Slices
  75. image_array = data[:, :, slc_idx]
  76. # Resize the array to 512x512
  77. resized_image_array = cv2.resize(image_array, (512, 512))
  78. min_val = resized_image_array.min()
  79. max_val = resized_image_array.max()
  80. normalized_image_array = ((resized_image_array - min_val) / (max_val - min_val) * 255).astype(np.uint8)
  81. image_paths.append(f"slice_{i}_{idx}.npy")
  82. if normalized_image_array.max()>0:
  83. np.save(os.path.join(save_dir, image_paths[-1]), normalized_image_array)
  84. # Saving Corresponding Label Slices
  85. label_array = labels[:, :, slc_idx]
  86. # Resize the array to 512x512
  87. resized_label_array = cv2.resize(label_array, (512, 512))
  88. min_val = resized_label_array.min()
  89. max_val = resized_label_array.max()
  90. # raise ValueError(np.unique(resized_label_array))
  91. # normalized_label_array = ((resized_label_array - min_val) / (max_val - min_val) * 255).astype(np.uint8)
  92. label_paths.append(f"label_{i}_{idx}.npy")
  93. np.save(os.path.join(labels_save_dir, label_paths[-1]), resized_label_array)
  94. return data
  95. def collate_fn(self, data):
  96. return data
  97. def __len__(self):
  98. return len(self.images_path)
  99. if __name__ == '__main__':
  100. model_type = 'vit_b'
  101. batch_size = 4
  102. num_workers = 4
  103. slice_per_image = 1
  104. dataset = PanDataset('../../Data/AbdomenCT-1K/numpy/images', '../../Data/AbdomenCT-1K/numpy/labels',
  105. slice_per_image=slice_per_image)
  106. # x, y = dataset[7]
  107. # # print(x.shape, y.shape)
  108. dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, drop_last=False, num_workers=num_workers)
  109. now = time()
  110. for data in dataloader:
  111. # pass
  112. # print(images.shape, labels.shape)
  113. continue
  114. dataset = PanDataset(f'{args.train_dir}/numpy/images', f'{args.train_dir}/numpy/labels',
  115. train = False , slice_per_image=slice_per_image)
  116. # x, y = dataset[7]
  117. # # print(x.shape, y.shape)
  118. dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True, drop_last=False, num_workers=num_workers)
  119. now = time()
  120. for data in dataloader:
  121. # pass
  122. # print(images.shape, labels.shape)
  123. continue