Browse Source

Add files via upload

main
rucv 3 years ago
parent
commit
19356b87ad
No account linked to committer's email address
1 changed files with 288 additions and 0 deletions
  1. 288
    0
      utils/transform.py

+ 288
- 0
utils/transform.py View File

"""
The transform implementation refers to the following paper:
"Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation"
https://github.com/Yuqi-cuhk/Polyp-Seg
"""

import torch
import torchvision.transforms.functional as F
import scipy.ndimage
import random
from PIL import Image
import numpy as np
import cv2
from skimage import transform as tf
import numbers

class ToTensor(object):

def __call__(self, data):
image, label = data['image'], data['label']
return {'image': F.to_tensor(image), 'label': F.to_tensor(label)}


class Resize(object):

def __init__(self, size):
self.size = size

def __call__(self, data):
image, label = data['image'], data['label']

return {'image': F.resize(image, self.size), 'label': F.resize(label, self.size)}


class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p

def __call__(self, data):
image, label = data['image'], data['label']

if random.random() < self.p:
return {'image': F.hflip(image), 'label': F.hflip(label)}

return {'image': image, 'label': label}


class RandomVerticalFlip(object):
def __init__(self, p=0.5):
self.p = p

def __call__(self, data):
image, label = data['image'], data['label']

if random.random() < self.p:
return {'image': F.vflip(image), 'label': F.vflip(label)}

return {'image': image, 'label': label}


class RandomRotation(object):

def __init__(self, degrees, resample=False, expand=False, center=None):
if isinstance(degrees,numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError("If degrees is a sequence, it must be of len 2.")
self.degrees = degrees
self.resample = resample
self.expand = expand
self.center = center

@staticmethod
def get_params(degrees):
"""Get parameters for ``rotate`` for a random rotation.

Returns:
sequence: params to be passed to ``rotate`` for random rotation.
"""
angle = random.uniform(degrees[0], degrees[1])

return angle

def __call__(self, data):

"""
img (PIL Image): Image to be rotated.

Returns:
PIL Image: Rotated image.
"""
image, label = data['image'], data['label']

if random.random() < 0.5:
angle = self.get_params(self.degrees)
return {'image': F.rotate(image, angle, self.resample, self.expand, self.center),
'label': F.rotate(label, angle, self.resample, self.expand, self.center)}

return {'image': image, 'label': label}


class RandomZoom(object):
def __init__(self, zoom=(0.8, 1.2)):
self.min, self.max = zoom[0], zoom[1]

def __call__(self, data):
image, label = data['image'], data['label']

if random.random() < 0.5:
image = np.array(image)
label = np.array(label)

zoom = random.uniform(self.min, self.max)
zoom_image = clipped_zoom(image, zoom)
zoom_label = clipped_zoom(label, zoom)

zoom_image = Image.fromarray(zoom_image.astype('uint8'), 'RGB')
zoom_label = Image.fromarray(zoom_label.astype('uint8'), 'L')
return {'image': zoom_image, 'label': zoom_label}

return {'image': image, 'label': label}


def clipped_zoom(img, zoom_factor, **kwargs):
h, w = img.shape[:2]

# For multichannel images we don't want to apply the zoom factor to the RGB
# dimension, so instead we create a tuple of zoom factors, one per array
# dimension, with 1's for any trailing dimensions after the width and height.
zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2)

# Zooming out
if zoom_factor < 1:

# Bounding box of the zoomed-out image within the output array
zh = int(np.round(h * zoom_factor))
zw = int(np.round(w * zoom_factor))
top = (h - zh) // 2
left = (w - zw) // 2

# Zero-padding
out = np.zeros_like(img)
out[top:top + zh, left:left + zw] = scipy.ndimage.zoom(img, zoom_tuple, **kwargs)

# Zooming in
elif zoom_factor > 1:

# Bounding box of the zoomed-in region within the input array
zh = int(np.round(h / zoom_factor))
zw = int(np.round(w / zoom_factor))
top = (h - zh) // 2
left = (w - zw) // 2

zoom_in = scipy.ndimage.zoom(img[top:top + zh, left:left + zw], zoom_tuple, **kwargs)

# `zoom_in` might still be slightly different with `img` due to rounding, so
# trim off any extra pixels at the edges or zero-padding

if zoom_in.shape[0] >= h:
zoom_top = (zoom_in.shape[0] - h) // 2
sh = h
out_top = 0
oh = h
else:
zoom_top = 0
sh = zoom_in.shape[0]
out_top = (h - zoom_in.shape[0]) // 2
oh = zoom_in.shape[0]
if zoom_in.shape[1] >= w:
zoom_left = (zoom_in.shape[1] - w) // 2
sw = w
out_left = 0
ow = w
else:
zoom_left = 0
sw = zoom_in.shape[1]
out_left = (w - zoom_in.shape[1]) // 2
ow = zoom_in.shape[1]

out = np.zeros_like(img)
out[out_top:out_top + oh, out_left:out_left + ow] = zoom_in[zoom_top:zoom_top + sh, zoom_left:zoom_left + sw]

# If zoom_factor == 1, just return the input array
else:
out = img
return out


class Translation(object):
def __init__(self, translation):
self.translation = translation

def __call__(self, data):
image, label = data['image'], data['label']

if random.random() < 0.5:
image = np.array(image)
label = np.array(label)
rows, cols, ch = image.shape

translation = random.uniform(0, self.translation)
tr_x = translation / 2
tr_y = translation / 2
Trans_M = np.float32([[1, 0, tr_x], [0, 1, tr_y]])

translate_image = cv2.warpAffine(image, Trans_M, (cols, rows))
translate_label = cv2.warpAffine(label, Trans_M, (cols, rows))

translate_image = Image.fromarray(translate_image.astype('uint8'), 'RGB')
translate_label = Image.fromarray(translate_label.astype('uint8'), 'L')

return {'image': translate_image, 'label': translate_label}

return {'image': image, 'label': label}


class RandomCrop(object):
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.padding_mode = padding_mode

@staticmethod
def get_params(img, output_size):
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
w, h = img.size
th, tw = output_size
if w == tw and h == th:
return 0, 0, h, w

#i = torch.randint(0, h - th + 1, size=(1, )).item()
#j = torch.randint(0, w - tw + 1, size=(1, )).item()
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw

def __call__(self, data):
"""
Args:
img (PIL Image): Image to be cropped.
Returns:
PIL Image: Cropped image.
"""
img, label = data['image'], data['label']
if self.padding is not None:
img = F.pad(img, self.padding, self.fill, self.padding_mode)
label = F.pad(label, self.padding, self.fill, self.padding_mode)
# pad the width if needed
if self.pad_if_needed and img.size[0] < self.size[1]:
img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)

# pad the height if needed
if self.pad_if_needed and img.size[1] < self.size[0]:
img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
label = F.pad(label, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
i, j, h, w = self.get_params(img, self.size)
img = F.crop(img, i, j ,h ,w)
label = F.crop(label, i, j, h, w)
return {"image": img, "label": label}


class Normalization(object):

def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
self.mean = mean
self.std = std

def __call__(self, sample):
image, label = sample['image'], sample['label']
image = F.normalize(image, self.mean, self.std)
return {'image': image, 'label': label}


Loading…
Cancel
Save