import os
from time import time

import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import DataLoader

from kernel.pre_processer import PreProcessing


def apply_median_filter(input_matrix, kernel_size=5):
    # Apply a median filter; cv2.medianBlur requires uint8 input
    filtered_matrix = cv2.medianBlur(input_matrix.astype(np.uint8), kernel_size)

    return filtered_matrix.astype(np.float32)


def apply_gaussian_filter(input_matrix, kernel_size=(7, 7), sigma=0):
    # Gaussian smoothing; sigma=0 lets OpenCV derive sigma from the kernel size
    smoothed_matrix = cv2.GaussianBlur(input_matrix, kernel_size, sigma)

    return smoothed_matrix.astype(np.float32)


def img_enhance(img2, over_coef=0.8, under_coef=0.7):
    """Unsharp-mask a stack of slices, then clamp their dynamic range."""
    img2 = apply_median_filter(img2)
    img_blur = apply_gaussian_filter(img2)

    # Unsharp masking: subtract a fraction of the blurred image
    img2 = img2 - 0.8 * img_blur

    # keepdims=True so the per-slice statistics broadcast over (S, H, W)
    img_mean = np.mean(img2, axis=(1, 2), keepdims=True)
    img_max = np.amax(img2, axis=(1, 2), keepdims=True)

    # Upper threshold: over_coef of the way from the mean to the max
    val = (img_max - img_mean) * over_coef + img_mean

    # Raise pixels below under_coef * mean to that floor
    img2 = (img2 < img_mean * under_coef).astype(np.float32) * img_mean * under_coef + (
        img2 >= img_mean * under_coef
    ).astype(np.float32) * img2

    # Clip pixels above the upper threshold
    img2 = (img2 <= val).astype(np.float32) * img2 + (img2 > val).astype(
        np.float32
    ) * val

    return img2
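
# Illustration (values are illustrative, not from the source): with
# under_coef=0.7 and over_coef=0.8, a slice whose sharpened values have
# mean 100 and max 200 is clamped to [70, 180]: pixels below 70 are raised
# to 70, and pixels above 100 + 0.8 * (200 - 100) = 180 are clipped to 180.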


def normalize_and_pad(x, img_size):
    """Normalize pixel values and pad to a square input."""

    # SAM's default pixel statistics (ImageNet mean/std scaled to [0, 255])
    pixel_mean = torch.tensor([[[[123.675]], [[116.28]], [[103.53]]]])
    pixel_std = torch.tensor([[[[58.395]], [[57.12]], [[57.375]]]])

    # Normalize colors
    x = (x - pixel_mean) / pixel_std

    # Pad the bottom/right edges so the output is img_size x img_size
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x
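
# Note (illustrative shapes): x is expected as (S, 3, H, W) with values in
# [0, 255] and H, W <= img_size, e.g.
#   x = torch.rand(1, 3, 800, 600) * 255
#   x = normalize_and_pad(x, 1024)  # -> (1, 3, 1024, 1024)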


def preprocess(img_enhanced, img_enhance_times=1, over_coef=0.4, under_coef=0.5):
    # Shift/scale each slice to [0, 1]; guard against all-zero slices
    img_enhanced -= torch.amin(img_enhanced, dim=(1, 2), keepdim=True)
    img_max = torch.amax(img_enhanced, dim=(1, 2), keepdim=True)
    img_max[img_max == 0] = 1
    img_enhanced = img_enhanced / img_max

    # CLAHE expects a channel axis: (S, H, W) -> (S, 1, H, W)
    img_enhanced = img_enhanced.unsqueeze(1)
    img_enhanced = PreProcessing.CLAHE(img_enhanced, clip_limit=9.0, grid_size=(4, 4))
    img_enhanced = img_enhanced[0]

    # Optional extra contrast enhancement (currently disabled):
    # for i in range(img_enhance_times):
    #     img_enhanced = img_enhance(
    #         img_enhanced.astype(np.float32),
    #         over_coef=over_coef,
    #         under_coef=under_coef,
    #     )

    # Rescale to uint8 in [0, 255]
    img_enhanced -= torch.amin(img_enhanced, dim=(1, 2), keepdim=True)
    large_img = (
        img_enhanced / torch.amax(img_enhanced, dim=(1, 2), keepdim=True) * 255
    ).type(torch.uint8)

    return large_img
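
# Example (illustrative): preprocess takes a stack of 2-D slices and returns
# CLAHE-enhanced uint8 slices rescaled to [0, 255], e.g.
#   x = torch.rand(1, 512, 512)
#   out = preprocess(x)  # uint8 tensor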


def prepare(large_img, target_image_size):
    # Add a channel axis and replicate the grayscale slices to 3 channels
    large_img = rearrange(large_img, "S H W -> S 1 H W")
    large_img = torch.cat([large_img, large_img, large_img], dim=1).float()

    # Resize so the longest side matches the SAM input size, then pad square
    transform = ResizeLongestSide(target_image_size)
    large_img = transform.apply_image_torch(large_img)
    large_img = normalize_and_pad(large_img, target_image_size)
    return large_img
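
# Example (illustrative):
#   batch = prepare(slices, 1024)  # (S, 3, 1024, 1024) normalized float tensor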


def process_single_image(image_path, target_image_size):
    # Load the image: grayscale for .png/.jpg, otherwise a saved numpy array
    if image_path.endswith((".png", ".jpg")):
        data = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    else:
        data = np.load(image_path)
    x = rearrange(data, "H W -> 1 H W")
    x = torch.tensor(x)

    # Apply preprocessing
    x = preprocess(x)
    x = prepare(x, target_image_size)

    return x
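
# Example usage (hypothetical path):
#   x = process_single_image("data/slice_042.npy", 1024)
#   # x: (1, 3, 1024, 1024) float tensor, ready for the SAM image encoder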


class PanDataset:
    def __init__(
        self,
        images_dirs,
        labels_dirs,
        datasets,
        target_image_size,
        slice_per_image,
        train=True,
        ratio=0.9,
        augmentation=None,
    ):
        self.data_set_names = []
        self.labels_path = []
        self.images_path = []
        for labels_dir, images_dir, dataset_name in zip(
            labels_dirs, images_dirs, datasets
        ):
            # Sort before slicing: os.listdir returns entries in arbitrary
            # order, so slicing first would make the train/val split
            # non-deterministic
            label_files = sorted(os.listdir(labels_dir))
            image_files = sorted(os.listdir(images_dir))
            split = int(len(label_files) * ratio)
            if train:
                label_files = label_files[:split]
                image_files = image_files[:split]
            else:
                label_files = label_files[split:]
                image_files = image_files[split:]

            # Record the full dataset name per sample; __getitem__ dispatches on it
            self.data_set_names.extend([dataset_name] * len(label_files))
            self.labels_path.extend(
                os.path.join(labels_dir, item) for item in label_files
            )
            self.images_path.extend(
                os.path.join(images_dir, item) for item in image_files
            )

        self.target_image_size = target_image_size
        self.datasets = datasets
        self.slice_per_image = slice_per_image
        self.augmentation = augmentation

    def __getitem__(self, idx):
        data = np.load(self.images_path[idx])
        raw_data = data

        labels = np.load(self.labels_path[idx])

        if self.data_set_names[idx] == "NIH_PNG":
            # NIH slices are stored transposed; keep only the pancreas class (1)
            x = rearrange(data.T, "H W -> 1 H W")
            y = rearrange(labels.T, "H W -> 1 H W")
            y = (y == 1).astype(np.uint8)
        elif self.data_set_names[idx] == "Abdment1kPNG":
            # Keep only the pancreas class (label 4)
            x = rearrange(data, "H W -> 1 H W")
            y = rearrange(labels, "H W -> 1 H W")
            y = (y == 4).astype(np.uint8)
        else:
            raise ValueError("Incorrect dataset name")

        x = torch.tensor(x)
        y = torch.tensor(y)
        x = preprocess(x)

        x, y = self.apply_augmentation(x.numpy(), y.numpy())

        # interpolate needs a float input; nearest keeps the mask binary
        y = F.interpolate(
            y.unsqueeze(1).float(), size=self.target_image_size, mode="nearest"
        )

        x = prepare(x, self.target_image_size)
        return x, y, raw_data

    def collate_fn(self, data):
        images, labels, raw_data = zip(*data)
        images = torch.cat(images, dim=0)
        labels = torch.cat(labels, dim=0)
        # raw_data stays a tuple of per-sample numpy arrays (sizes may differ)

        return images, labels, raw_data

    def __len__(self):
        return len(self.images_path)

    def apply_augmentation(self, image, label):
        if self.augmentation:
            # albumentations operates on numpy arrays: pass the (H, W) slice
            # and its mask
            augmented = self.augmentation(image=image[0], mask=label[0])

            # Convert back to tensors and restore the leading slice axis
            image = torch.tensor(augmented["image"]).unsqueeze(0)
            label = torch.tensor(augmented["mask"]).unsqueeze(0)
        else:
            image = torch.Tensor(image)
            label = torch.Tensor(label)

        return image, label
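
# Note (illustrative): an augmentation pipeline such as
#   A.Compose([A.Rotate(limit=10, p=0.5)])
# receives the (H, W) slice and mask together and returns them jointly
# transformed, so image and label stay spatially aligned.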


if __name__ == "__main__":
    model_type = "vit_h"
    batch_size = 4
    num_workers = 4
    slice_per_image = 1
    image_size = 1024

    checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
    panc_sam_instance = sam_model_registry[model_type](checkpoint=checkpoint)

    augmentation = A.Compose(
        [
            A.Rotate(limit=10, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            A.RandomResizedCrop(1024, 1024, scale=(0.9, 1.0), p=1),
        ]
    )
    # The constructor expects lists of image dirs, label dirs, and dataset
    # names (placeholder paths shown here)
    train_dataset = PanDataset(
        ["path/to/images"],
        ["path/to/labels"],
        ["NIH_PNG"],
        image_size,
        slice_per_image=slice_per_image,
        train=True,
        augmentation=None,  # set to `augmentation` to enable the pipeline above
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=train_dataset.collate_fn,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
    # x, y, raw = train_dataset[7]
    # print(x.shape, y.shape)

    now = time()
    for images, labels, raw_data in train_loader:
        image_numpy = images[0].permute(1, 2, 0).cpu().numpy()

        # The batch is mean/std-normalized, so rescale to [0, 255] before saving
        image_numpy = image_numpy - image_numpy.min()
        image_numpy = (image_numpy / max(image_numpy.max(), 1e-6) * 255).astype(
            np.uint8
        )

        # Save one channel of the first image using OpenCV
        cv2.imwrite("image2.png", image_numpy[:, :, 1])

        break

    # print((time() - now) / batch_size / slice_per_image)