|
|
@@ -0,0 +1,288 @@ |
|
|
|
""" |
|
|
|
The transform implementation refers to the following paper: |
|
|
|
"Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation" |
|
|
|
https://github.com/Yuqi-cuhk/Polyp-Seg |
|
|
|
""" |
|
|
|
|
|
|
|
import torch |
|
|
|
import torchvision.transforms.functional as F |
|
|
|
import scipy.ndimage |
|
|
|
import random |
|
|
|
from PIL import Image |
|
|
|
import numpy as np |
|
|
|
import cv2 |
|
|
|
from skimage import transform as tf |
|
|
|
import numbers |
|
|
|
|
|
|
|
class ToTensor(object): |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
return {'image': F.to_tensor(image), 'label': F.to_tensor(label)} |
|
|
|
|
|
|
|
|
|
|
|
class Resize(object): |
|
|
|
|
|
|
|
def __init__(self, size): |
|
|
|
self.size = size |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
|
|
|
|
return {'image': F.resize(image, self.size), 'label': F.resize(label, self.size)} |
|
|
|
|
|
|
|
|
|
|
|
class RandomHorizontalFlip(object): |
|
|
|
def __init__(self, p=0.5): |
|
|
|
self.p = p |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
|
|
|
|
if random.random() < self.p: |
|
|
|
return {'image': F.hflip(image), 'label': F.hflip(label)} |
|
|
|
|
|
|
|
return {'image': image, 'label': label} |
|
|
|
|
|
|
|
|
|
|
|
class RandomVerticalFlip(object): |
|
|
|
def __init__(self, p=0.5): |
|
|
|
self.p = p |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
|
|
|
|
if random.random() < self.p: |
|
|
|
return {'image': F.vflip(image), 'label': F.vflip(label)} |
|
|
|
|
|
|
|
return {'image': image, 'label': label} |
|
|
|
|
|
|
|
|
|
|
|
class RandomRotation(object): |
|
|
|
|
|
|
|
def __init__(self, degrees, resample=False, expand=False, center=None): |
|
|
|
if isinstance(degrees,numbers.Number): |
|
|
|
if degrees < 0: |
|
|
|
raise ValueError("If degrees is a single number, it must be positive.") |
|
|
|
self.degrees = (-degrees, degrees) |
|
|
|
else: |
|
|
|
if len(degrees) != 2: |
|
|
|
raise ValueError("If degrees is a sequence, it must be of len 2.") |
|
|
|
self.degrees = degrees |
|
|
|
self.resample = resample |
|
|
|
self.expand = expand |
|
|
|
self.center = center |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def get_params(degrees): |
|
|
|
"""Get parameters for ``rotate`` for a random rotation. |
|
|
|
|
|
|
|
Returns: |
|
|
|
sequence: params to be passed to ``rotate`` for random rotation. |
|
|
|
""" |
|
|
|
angle = random.uniform(degrees[0], degrees[1]) |
|
|
|
|
|
|
|
return angle |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
|
|
|
|
""" |
|
|
|
img (PIL Image): Image to be rotated. |
|
|
|
|
|
|
|
Returns: |
|
|
|
PIL Image: Rotated image. |
|
|
|
""" |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
|
|
|
|
if random.random() < 0.5: |
|
|
|
angle = self.get_params(self.degrees) |
|
|
|
return {'image': F.rotate(image, angle, self.resample, self.expand, self.center), |
|
|
|
'label': F.rotate(label, angle, self.resample, self.expand, self.center)} |
|
|
|
|
|
|
|
return {'image': image, 'label': label} |
|
|
|
|
|
|
|
|
|
|
|
class RandomZoom(object): |
|
|
|
def __init__(self, zoom=(0.8, 1.2)): |
|
|
|
self.min, self.max = zoom[0], zoom[1] |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
|
|
|
|
if random.random() < 0.5: |
|
|
|
image = np.array(image) |
|
|
|
label = np.array(label) |
|
|
|
|
|
|
|
zoom = random.uniform(self.min, self.max) |
|
|
|
zoom_image = clipped_zoom(image, zoom) |
|
|
|
zoom_label = clipped_zoom(label, zoom) |
|
|
|
|
|
|
|
zoom_image = Image.fromarray(zoom_image.astype('uint8'), 'RGB') |
|
|
|
zoom_label = Image.fromarray(zoom_label.astype('uint8'), 'L') |
|
|
|
return {'image': zoom_image, 'label': zoom_label} |
|
|
|
|
|
|
|
return {'image': image, 'label': label} |
|
|
|
|
|
|
|
|
|
|
|
def clipped_zoom(img, zoom_factor, **kwargs): |
|
|
|
h, w = img.shape[:2] |
|
|
|
|
|
|
|
# For multichannel images we don't want to apply the zoom factor to the RGB |
|
|
|
# dimension, so instead we create a tuple of zoom factors, one per array |
|
|
|
# dimension, with 1's for any trailing dimensions after the width and height. |
|
|
|
zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2) |
|
|
|
|
|
|
|
# Zooming out |
|
|
|
if zoom_factor < 1: |
|
|
|
|
|
|
|
# Bounding box of the zoomed-out image within the output array |
|
|
|
zh = int(np.round(h * zoom_factor)) |
|
|
|
zw = int(np.round(w * zoom_factor)) |
|
|
|
top = (h - zh) // 2 |
|
|
|
left = (w - zw) // 2 |
|
|
|
|
|
|
|
# Zero-padding |
|
|
|
out = np.zeros_like(img) |
|
|
|
out[top:top + zh, left:left + zw] = scipy.ndimage.zoom(img, zoom_tuple, **kwargs) |
|
|
|
|
|
|
|
# Zooming in |
|
|
|
elif zoom_factor > 1: |
|
|
|
|
|
|
|
# Bounding box of the zoomed-in region within the input array |
|
|
|
zh = int(np.round(h / zoom_factor)) |
|
|
|
zw = int(np.round(w / zoom_factor)) |
|
|
|
top = (h - zh) // 2 |
|
|
|
left = (w - zw) // 2 |
|
|
|
|
|
|
|
zoom_in = scipy.ndimage.zoom(img[top:top + zh, left:left + zw], zoom_tuple, **kwargs) |
|
|
|
|
|
|
|
# `zoom_in` might still be slightly different with `img` due to rounding, so |
|
|
|
# trim off any extra pixels at the edges or zero-padding |
|
|
|
|
|
|
|
if zoom_in.shape[0] >= h: |
|
|
|
zoom_top = (zoom_in.shape[0] - h) // 2 |
|
|
|
sh = h |
|
|
|
out_top = 0 |
|
|
|
oh = h |
|
|
|
else: |
|
|
|
zoom_top = 0 |
|
|
|
sh = zoom_in.shape[0] |
|
|
|
out_top = (h - zoom_in.shape[0]) // 2 |
|
|
|
oh = zoom_in.shape[0] |
|
|
|
if zoom_in.shape[1] >= w: |
|
|
|
zoom_left = (zoom_in.shape[1] - w) // 2 |
|
|
|
sw = w |
|
|
|
out_left = 0 |
|
|
|
ow = w |
|
|
|
else: |
|
|
|
zoom_left = 0 |
|
|
|
sw = zoom_in.shape[1] |
|
|
|
out_left = (w - zoom_in.shape[1]) // 2 |
|
|
|
ow = zoom_in.shape[1] |
|
|
|
|
|
|
|
out = np.zeros_like(img) |
|
|
|
out[out_top:out_top + oh, out_left:out_left + ow] = zoom_in[zoom_top:zoom_top + sh, zoom_left:zoom_left + sw] |
|
|
|
|
|
|
|
# If zoom_factor == 1, just return the input array |
|
|
|
else: |
|
|
|
out = img |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class Translation(object): |
|
|
|
def __init__(self, translation): |
|
|
|
self.translation = translation |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
image, label = data['image'], data['label'] |
|
|
|
|
|
|
|
if random.random() < 0.5: |
|
|
|
image = np.array(image) |
|
|
|
label = np.array(label) |
|
|
|
rows, cols, ch = image.shape |
|
|
|
|
|
|
|
translation = random.uniform(0, self.translation) |
|
|
|
tr_x = translation / 2 |
|
|
|
tr_y = translation / 2 |
|
|
|
Trans_M = np.float32([[1, 0, tr_x], [0, 1, tr_y]]) |
|
|
|
|
|
|
|
translate_image = cv2.warpAffine(image, Trans_M, (cols, rows)) |
|
|
|
translate_label = cv2.warpAffine(label, Trans_M, (cols, rows)) |
|
|
|
|
|
|
|
translate_image = Image.fromarray(translate_image.astype('uint8'), 'RGB') |
|
|
|
translate_label = Image.fromarray(translate_label.astype('uint8'), 'L') |
|
|
|
|
|
|
|
return {'image': translate_image, 'label': translate_label} |
|
|
|
|
|
|
|
return {'image': image, 'label': label} |
|
|
|
|
|
|
|
|
|
|
|
class RandomCrop(object): |
|
|
|
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): |
|
|
|
if isinstance(size, numbers.Number): |
|
|
|
self.size = (int(size), int(size)) |
|
|
|
else: |
|
|
|
self.size = size |
|
|
|
self.padding = padding |
|
|
|
self.pad_if_needed = pad_if_needed |
|
|
|
self.fill = fill |
|
|
|
self.padding_mode = padding_mode |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def get_params(img, output_size): |
|
|
|
"""Get parameters for ``crop`` for a random crop. |
|
|
|
Args: |
|
|
|
img (PIL Image): Image to be cropped. |
|
|
|
output_size (tuple): Expected output size of the crop. |
|
|
|
Returns: |
|
|
|
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. |
|
|
|
""" |
|
|
|
w, h = img.size |
|
|
|
th, tw = output_size |
|
|
|
if w == tw and h == th: |
|
|
|
return 0, 0, h, w |
|
|
|
|
|
|
|
#i = torch.randint(0, h - th + 1, size=(1, )).item() |
|
|
|
#j = torch.randint(0, w - tw + 1, size=(1, )).item() |
|
|
|
i = random.randint(0, h - th) |
|
|
|
j = random.randint(0, w - tw) |
|
|
|
return i, j, th, tw |
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
img (PIL Image): Image to be cropped. |
|
|
|
Returns: |
|
|
|
PIL Image: Cropped image. |
|
|
|
""" |
|
|
|
img, label = data['image'], data['label'] |
|
|
|
if self.padding is not None: |
|
|
|
img = F.pad(img, self.padding, self.fill, self.padding_mode) |
|
|
|
label = F.pad(label, self.padding, self.fill, self.padding_mode) |
|
|
|
# pad the width if needed |
|
|
|
if self.pad_if_needed and img.size[0] < self.size[1]: |
|
|
|
img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) |
|
|
|
label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) |
|
|
|
|
|
|
|
# pad the height if needed |
|
|
|
if self.pad_if_needed and img.size[1] < self.size[0]: |
|
|
|
img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) |
|
|
|
label = F.pad(label, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) |
|
|
|
i, j, h, w = self.get_params(img, self.size) |
|
|
|
img = F.crop(img, i, j ,h ,w) |
|
|
|
label = F.crop(label, i, j, h, w) |
|
|
|
return {"image": img, "label": label} |
|
|
|
|
|
|
|
|
|
|
|
class Normalization(object): |
|
|
|
|
|
|
|
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): |
|
|
|
self.mean = mean |
|
|
|
self.std = std |
|
|
|
|
|
|
|
def __call__(self, sample): |
|
|
|
image, label = sample['image'], sample['label'] |
|
|
|
image = F.normalize(image, self.mean, self.std) |
|
|
|
return {'image': image, 'label': label} |
|
|
|
|