You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

transform.py 9.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. """
  2. The transform implementation refers to the following paper:
  3. "Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation"
  4. https://github.com/Yuqi-cuhk/Polyp-Seg
  5. """
  6. import torch
  7. import torchvision.transforms.functional as F
  8. import scipy.ndimage
  9. import random
  10. from PIL import Image
  11. import numpy as np
  12. import cv2
  13. from skimage import transform as tf
  14. import numbers
  15. class ToTensor(object):
  16. def __call__(self, data):
  17. image, label = data['image'], data['label']
  18. return {'image': F.to_tensor(image), 'label': F.to_tensor(label)}
  19. class Resize(object):
  20. def __init__(self, size):
  21. self.size = size
  22. def __call__(self, data):
  23. image, label = data['image'], data['label']
  24. return {'image': F.resize(image, self.size), 'label': F.resize(label, self.size)}
  25. class RandomHorizontalFlip(object):
  26. def __init__(self, p=0.5):
  27. self.p = p
  28. def __call__(self, data):
  29. image, label = data['image'], data['label']
  30. if random.random() < self.p:
  31. return {'image': F.hflip(image), 'label': F.hflip(label)}
  32. return {'image': image, 'label': label}
  33. class RandomVerticalFlip(object):
  34. def __init__(self, p=0.5):
  35. self.p = p
  36. def __call__(self, data):
  37. image, label = data['image'], data['label']
  38. if random.random() < self.p:
  39. return {'image': F.vflip(image), 'label': F.vflip(label)}
  40. return {'image': image, 'label': label}
  41. class RandomRotation(object):
  42. def __init__(self, degrees, resample=False, expand=False, center=None):
  43. if isinstance(degrees,numbers.Number):
  44. if degrees < 0:
  45. raise ValueError("If degrees is a single number, it must be positive.")
  46. self.degrees = (-degrees, degrees)
  47. else:
  48. if len(degrees) != 2:
  49. raise ValueError("If degrees is a sequence, it must be of len 2.")
  50. self.degrees = degrees
  51. self.resample = resample
  52. self.expand = expand
  53. self.center = center
  54. @staticmethod
  55. def get_params(degrees):
  56. """Get parameters for ``rotate`` for a random rotation.
  57. Returns:
  58. sequence: params to be passed to ``rotate`` for random rotation.
  59. """
  60. angle = random.uniform(degrees[0], degrees[1])
  61. return angle
  62. def __call__(self, data):
  63. """
  64. img (PIL Image): Image to be rotated.
  65. Returns:
  66. PIL Image: Rotated image.
  67. """
  68. image, label = data['image'], data['label']
  69. if random.random() < 0.5:
  70. angle = self.get_params(self.degrees)
  71. return {'image': F.rotate(image, angle, self.resample, self.expand, self.center),
  72. 'label': F.rotate(label, angle, self.resample, self.expand, self.center)}
  73. return {'image': image, 'label': label}
  74. class RandomZoom(object):
  75. def __init__(self, zoom=(0.8, 1.2)):
  76. self.min, self.max = zoom[0], zoom[1]
  77. def __call__(self, data):
  78. image, label = data['image'], data['label']
  79. if random.random() < 0.5:
  80. image = np.array(image)
  81. label = np.array(label)
  82. zoom = random.uniform(self.min, self.max)
  83. zoom_image = clipped_zoom(image, zoom)
  84. zoom_label = clipped_zoom(label, zoom)
  85. zoom_image = Image.fromarray(zoom_image.astype('uint8'), 'RGB')
  86. zoom_label = Image.fromarray(zoom_label.astype('uint8'), 'L')
  87. return {'image': zoom_image, 'label': zoom_label}
  88. return {'image': image, 'label': label}
  89. def clipped_zoom(img, zoom_factor, **kwargs):
  90. h, w = img.shape[:2]
  91. # For multichannel images we don't want to apply the zoom factor to the RGB
  92. # dimension, so instead we create a tuple of zoom factors, one per array
  93. # dimension, with 1's for any trailing dimensions after the width and height.
  94. zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2)
  95. # Zooming out
  96. if zoom_factor < 1:
  97. # Bounding box of the zoomed-out image within the output array
  98. zh = int(np.round(h * zoom_factor))
  99. zw = int(np.round(w * zoom_factor))
  100. top = (h - zh) // 2
  101. left = (w - zw) // 2
  102. # Zero-padding
  103. out = np.zeros_like(img)
  104. out[top:top + zh, left:left + zw] = scipy.ndimage.zoom(img, zoom_tuple, **kwargs)
  105. # Zooming in
  106. elif zoom_factor > 1:
  107. # Bounding box of the zoomed-in region within the input array
  108. zh = int(np.round(h / zoom_factor))
  109. zw = int(np.round(w / zoom_factor))
  110. top = (h - zh) // 2
  111. left = (w - zw) // 2
  112. zoom_in = scipy.ndimage.zoom(img[top:top + zh, left:left + zw], zoom_tuple, **kwargs)
  113. # `zoom_in` might still be slightly different with `img` due to rounding, so
  114. # trim off any extra pixels at the edges or zero-padding
  115. if zoom_in.shape[0] >= h:
  116. zoom_top = (zoom_in.shape[0] - h) // 2
  117. sh = h
  118. out_top = 0
  119. oh = h
  120. else:
  121. zoom_top = 0
  122. sh = zoom_in.shape[0]
  123. out_top = (h - zoom_in.shape[0]) // 2
  124. oh = zoom_in.shape[0]
  125. if zoom_in.shape[1] >= w:
  126. zoom_left = (zoom_in.shape[1] - w) // 2
  127. sw = w
  128. out_left = 0
  129. ow = w
  130. else:
  131. zoom_left = 0
  132. sw = zoom_in.shape[1]
  133. out_left = (w - zoom_in.shape[1]) // 2
  134. ow = zoom_in.shape[1]
  135. out = np.zeros_like(img)
  136. out[out_top:out_top + oh, out_left:out_left + ow] = zoom_in[zoom_top:zoom_top + sh, zoom_left:zoom_left + sw]
  137. # If zoom_factor == 1, just return the input array
  138. else:
  139. out = img
  140. return out
  141. class Translation(object):
  142. def __init__(self, translation):
  143. self.translation = translation
  144. def __call__(self, data):
  145. image, label = data['image'], data['label']
  146. if random.random() < 0.5:
  147. image = np.array(image)
  148. label = np.array(label)
  149. rows, cols, ch = image.shape
  150. translation = random.uniform(0, self.translation)
  151. tr_x = translation / 2
  152. tr_y = translation / 2
  153. Trans_M = np.float32([[1, 0, tr_x], [0, 1, tr_y]])
  154. translate_image = cv2.warpAffine(image, Trans_M, (cols, rows))
  155. translate_label = cv2.warpAffine(label, Trans_M, (cols, rows))
  156. translate_image = Image.fromarray(translate_image.astype('uint8'), 'RGB')
  157. translate_label = Image.fromarray(translate_label.astype('uint8'), 'L')
  158. return {'image': translate_image, 'label': translate_label}
  159. return {'image': image, 'label': label}
  160. class RandomCrop(object):
  161. def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
  162. if isinstance(size, numbers.Number):
  163. self.size = (int(size), int(size))
  164. else:
  165. self.size = size
  166. self.padding = padding
  167. self.pad_if_needed = pad_if_needed
  168. self.fill = fill
  169. self.padding_mode = padding_mode
  170. @staticmethod
  171. def get_params(img, output_size):
  172. """Get parameters for ``crop`` for a random crop.
  173. Args:
  174. img (PIL Image): Image to be cropped.
  175. output_size (tuple): Expected output size of the crop.
  176. Returns:
  177. tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
  178. """
  179. w, h = img.size
  180. th, tw = output_size
  181. if w == tw and h == th:
  182. return 0, 0, h, w
  183. #i = torch.randint(0, h - th + 1, size=(1, )).item()
  184. #j = torch.randint(0, w - tw + 1, size=(1, )).item()
  185. i = random.randint(0, h - th)
  186. j = random.randint(0, w - tw)
  187. return i, j, th, tw
  188. def __call__(self, data):
  189. """
  190. Args:
  191. img (PIL Image): Image to be cropped.
  192. Returns:
  193. PIL Image: Cropped image.
  194. """
  195. img, label = data['image'], data['label']
  196. if self.padding is not None:
  197. img = F.pad(img, self.padding, self.fill, self.padding_mode)
  198. label = F.pad(label, self.padding, self.fill, self.padding_mode)
  199. # pad the width if needed
  200. if self.pad_if_needed and img.size[0] < self.size[1]:
  201. img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
  202. label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
  203. # pad the height if needed
  204. if self.pad_if_needed and img.size[1] < self.size[0]:
  205. img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
  206. label = F.pad(label, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
  207. i, j, h, w = self.get_params(img, self.size)
  208. img = F.crop(img, i, j ,h ,w)
  209. label = F.crop(label, i, j, h, w)
  210. return {"image": img, "label": label}
  211. class Normalization(object):
  212. def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
  213. self.mean = mean
  214. self.std = std
  215. def __call__(self, sample):
  216. image, label = sample['image'], sample['label']
  217. image = F.normalize(image, self.mean, self.std)
  218. return {'image': image, 'label': label}