
pre_processer.py

# This CLAHE implementation is adapted from kornia (kornia.enhance.equalization)
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import os
import math
import torch.nn.functional as F
import glob
from pprint import pprint
import tempfile
import shutil
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import torch
from einops import rearrange
import nibabel as nib
###############################
from typing import Any, Collection, Hashable, Iterable, Sequence, TypeVar, Union, Mapping, Callable, Generator
from enum import Enum
from abc import ABC, abstractmethod
from warnings import warn

def _map_luts(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    r"""Assign the required LUTs to each tile.

    Args:
        interp_tiles (torch.Tensor): set of interpolation tiles. (B, 2GH, 2GW, C, TH/2, TW/2)
        luts (torch.Tensor): LUTs for each one of the original tiles. (B, GH, GW, C, 256)

    Returns:
        torch.Tensor: mapped LUTs (B, 2GH, 2GW, 4, C, 256)
    """
    assert interp_tiles.dim() == 6, "interp_tiles tensor must be 6D."
    assert luts.dim() == 5, "luts tensor must be 5D."

    # gh, gw -> 2x the number of tiles used to compute the histograms
    # th, tw -> half the sizes of the tiles used to compute the histograms
    num_imgs, gh, gw, c, th, tw = interp_tiles.shape

    # precompute indices for non-corner regions (doing it on cpu seems slightly faster)
    j_idxs = torch.ones(gh - 2, 4, dtype=torch.long) * torch.arange(1, gh - 1).reshape(gh - 2, 1)
    i_idxs = torch.ones(gw - 2, 4, dtype=torch.long) * torch.arange(1, gw - 1).reshape(gw - 2, 1)
    j_idxs = j_idxs // 2 + j_idxs % 2
    j_idxs[:, 0:2] -= 1
    i_idxs = i_idxs // 2 + i_idxs % 2
    # i_idxs[:, [0, 2]] -= 1  # this slicing is not supported by jit
    i_idxs[:, 0] -= 1
    i_idxs[:, 2] -= 1

    # selection of the LUTs to interpolate each patch:
    # a tensor with dims: interp_patches height and width x 4 x num channels x bins in the histograms,
    # initialized to -1 to denote non-initialized histograms
    luts_x_interp_tiles: torch.Tensor = -torch.ones(
        num_imgs, gh, gw, 4, c, luts.shape[-1], device=interp_tiles.device)  # B x GH x GW x 4 x C x 256
    # corner regions
    luts_x_interp_tiles[:, 0::gh - 1, 0::gw - 1, 0] = luts[:, 0::max(gh // 2 - 1, 1), 0::max(gw // 2 - 1, 1)]
    # border region (h)
    luts_x_interp_tiles[:, 1:-1, 0::gw - 1, 0] = luts[:, j_idxs[:, 0], 0::max(gw // 2 - 1, 1)]
    luts_x_interp_tiles[:, 1:-1, 0::gw - 1, 1] = luts[:, j_idxs[:, 2], 0::max(gw // 2 - 1, 1)]
    # border region (w)
    luts_x_interp_tiles[:, 0::gh - 1, 1:-1, 0] = luts[:, 0::max(gh // 2 - 1, 1), i_idxs[:, 0]]
    luts_x_interp_tiles[:, 0::gh - 1, 1:-1, 1] = luts[:, 0::max(gh // 2 - 1, 1), i_idxs[:, 1]]
    # internal region: j_idxs must be tiled across the width and i_idxs across the height
    # so that both index tensors end up with shape (GH-2, GW-2, 4)
    luts_x_interp_tiles[:, 1:-1, 1:-1, :] = luts[
        :, j_idxs.repeat(max(gw - 2, 1), 1, 1).permute(1, 0, 2), i_idxs.repeat(max(gh - 2, 1), 1, 1)]

    return luts_x_interp_tiles
def marginal_pdf(values: torch.Tensor, bins: torch.Tensor, sigma: torch.Tensor,
                 epsilon: float = 1e-10) -> Tuple[torch.Tensor, torch.Tensor]:
    """Calculate the marginal probability distribution function of the input tensor
    based on the number of histogram bins.

    Args:
        values (torch.Tensor): shape [BxNx1].
        bins (torch.Tensor): shape [NUM_BINS].
        sigma (torch.Tensor): scalar tensor, gaussian smoothing factor.
        epsilon (float): scalar, for numerical stability.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
          - torch.Tensor: normalized pdf, shape [BxNUM_BINS].
          - torch.Tensor: kernel values, shape [BxNxNUM_BINS].
    """
    if not isinstance(values, torch.Tensor):
        raise TypeError(f"Input values type is not a torch.Tensor. Got {type(values)}")
    if not isinstance(bins, torch.Tensor):
        raise TypeError(f"Input bins type is not a torch.Tensor. Got {type(bins)}")
    if not isinstance(sigma, torch.Tensor):
        raise TypeError(f"Input sigma type is not a torch.Tensor. Got {type(sigma)}")
    if not values.dim() == 3:
        raise ValueError(f"Input values must be of the shape BxNx1. Got {values.shape}")
    if not bins.dim() == 1:
        raise ValueError(f"Input bins must be of the shape NUM_BINS. Got {bins.shape}")
    if not sigma.dim() == 0:
        raise ValueError(f"Input sigma must be a scalar tensor. Got {sigma.shape}")

    residuals = values - bins.unsqueeze(0).unsqueeze(0)
    kernel_values = torch.exp(-0.5 * (residuals / sigma).pow(2))

    pdf = torch.mean(kernel_values, dim=1)
    normalization = torch.sum(pdf, dim=1).unsqueeze(1) + epsilon
    pdf = pdf / normalization

    return pdf, kernel_values
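
# Minimal illustrative sketch (not part of the original file): `marginal_pdf`
# performs soft, kernel-based binning, so every sample contributes a Gaussian
# weight to every bin instead of a hard count. The helper below is hypothetical
# and only checks the returned shapes and normalization.
def _demo_marginal_pdf() -> None:
    values = torch.rand(2, 50, 1)              # B=2 batches of N=50 scalar samples in [0, 1]
    bins = torch.linspace(0, 1, 16)            # 16 bin centers covering the value range
    sigma = torch.tensor(0.05)                 # kernel bandwidth (scalar tensor)
    pdf, kernel_values = marginal_pdf(values, bins, sigma)
    assert pdf.shape == (2, 16)                # one normalized pdf per batch element
    assert kernel_values.shape == (2, 50, 16)  # per-sample soft bin assignments
    assert torch.allclose(pdf.sum(dim=1), torch.ones(2), atol=1e-4)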
def histogram(x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor,
              epsilon: float = 1e-10) -> torch.Tensor:
    """Estimate the histogram of the input tensor.

    The calculation uses kernel density estimation, which requires a bandwidth (smoothing) parameter.

    Args:
        x (torch.Tensor): Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins (torch.Tensor): The bin centers of the histogram :math:`(N_{bins})`.
        bandwidth (torch.Tensor): Gaussian smoothing factor, a scalar tensor.
        epsilon (float): A scalar, for numerical stability. Default: 1e-10.

    Returns:
        torch.Tensor: Computed histogram of shape :math:`(B, N_{bins})`.

    Examples:
        >>> x = torch.rand(1, 10)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([1, 128])
    """
    pdf, _ = marginal_pdf(x.unsqueeze(2), bins, bandwidth, epsilon)
    return pdf
def _compute_tiles(imgs: torch.Tensor, grid_size: Tuple[int, int], even_tile_size: bool = False
                   ) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute tiles on an image according to a grid size.

    Note that padding can be added to the image in order to crop it properly,
    so grid_size (GH, GW) x tile_size (TH, TW) >= image_size (H, W).

    Args:
        imgs (torch.Tensor): batch of 2D images with shape (B, C, H, W) or (C, H, W).
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW).
        even_tile_size (bool, optional): Determine if the width and height of the tiles must be even. Default: False.

    Returns:
        torch.Tensor: tensor with tiles (B, GH, GW, C, TH, TW). B = 1 in case a single image is provided.
        torch.Tensor: tensor with the padded batch of 2D images with shape (B, C, H', W').
    """
    batch: torch.Tensor = _to_bchw(imgs)  # B x C x H x W

    # compute stride and kernel size
    h, w = batch.shape[-2:]
    kernel_vert: int = math.ceil(h / grid_size[0])
    kernel_horz: int = math.ceil(w / grid_size[1])
    if even_tile_size:
        kernel_vert += 1 if kernel_vert % 2 else 0
        kernel_horz += 1 if kernel_horz % 2 else 0

    # add padding (with that kernel size we could need some extra cols and rows...)
    pad_vert = kernel_vert * grid_size[0] - h
    pad_horz = kernel_horz * grid_size[1] - w
    # add the padding in the last columns and rows
    if pad_vert > 0 or pad_horz > 0:
        batch = F.pad(batch, (0, pad_horz, 0, pad_vert), mode='reflect')  # B x C x H' x W'

    # compute tiles
    c: int = batch.shape[-3]
    tiles: torch.Tensor = (batch.unfold(1, c, c)  # unfold(dimension, size, step)
                           .unfold(2, kernel_vert, kernel_vert)
                           .unfold(3, kernel_horz, kernel_horz)
                           .squeeze(1))  # B x GH x GW x C x TH x TW
    assert tiles.shape[-5] == grid_size[0]  # check the grid size
    assert tiles.shape[-4] == grid_size[1]
    return tiles, batch
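
# Minimal illustrative sketch (not part of the original file): `_compute_tiles`
# pads the image so the grid divides it exactly, then unfolds it into a
# (B, GH, GW, C, TH, TW) grid. The helper below is hypothetical and walks
# through the shapes for an image whose sides are not multiples of the grid.
def _demo_compute_tiles() -> None:
    imgs = torch.rand(1, 30, 50)  # C x H x W, neither side a multiple of 8
    tiles, padded = _compute_tiles(imgs, grid_size=(8, 8), even_tile_size=True)
    # ceil(30 / 8) -> tile height 4 (already even); ceil(50 / 8) -> 7, bumped to 8 to stay even
    assert tiles.shape == (1, 8, 8, 1, 4, 8)
    assert padded.shape[-2:] == (8 * 4, 8 * 8)  # padded up to grid * tile size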
def _to_bchw(tensor: torch.Tensor, color_channel_num: Optional[int] = None) -> torch.Tensor:
    """Convert a PyTorch tensor image to BCHW format.

    Args:
        tensor (torch.Tensor): image of the form :math:`(H, W)`, :math:`(C, H, W)` or
            :math:`(B, C, H, W)`.
        color_channel_num (Optional[int]): Color channel of the input tensor.
            If None, the input channel order is left unchanged.

    Returns:
        torch.Tensor: input tensor of the form :math:`(B, C, H, W)`.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(tensor)}")

    if len(tensor.shape) > 4 or len(tensor.shape) < 2:
        raise ValueError(f"Input size must be a two, three or four dimensional tensor. Got {tensor.shape}")

    if len(tensor.shape) == 2:
        tensor = tensor.unsqueeze(0)

    if len(tensor.shape) == 3:
        tensor = tensor.unsqueeze(0)

    # TODO(jian): the color_channel_num argument is never used, and the permutation
    # below is not feasible for torchscript; the docs must be updated once resolved.
    # if color_channel_num is not None and color_channel_num != 1:
    #     channel_list = [0, 1, 2, 3]
    #     channel_list.insert(1, channel_list.pop(color_channel_num))
    #     tensor = tensor.permute(*channel_list)

    return tensor
def _compute_interpolation_tiles(padded_imgs: torch.Tensor, tile_size: Tuple[int, int]) -> torch.Tensor:
    r"""Compute interpolation tiles on a properly padded set of images.

    Note that images must be padded, so that tile_size (TH, TW) * grid_size (GH, GW) = image_size (H, W).

    Args:
        padded_imgs (torch.Tensor): batch of 2D images with shape (B, C, H, W), already padded to extract
            tiles of size (TH, TW).
        tile_size (Tuple[int, int]): shape of the current tiles (TH, TW).

    Returns:
        torch.Tensor: tensor with the interpolation tiles (B, 2GH, 2GW, C, TH/2, TW/2).
    """
    assert padded_imgs.dim() == 4, "Images Tensor must be 4D."
    assert padded_imgs.shape[-2] % tile_size[0] == 0, "Images are not correctly padded."
    assert padded_imgs.shape[-1] % tile_size[1] == 0, "Images are not correctly padded."

    # tiles to be interpolated are built by dividing each already existing tile into 4
    interp_kernel_vert: int = tile_size[0] // 2
    interp_kernel_horz: int = tile_size[1] // 2

    c: int = padded_imgs.shape[-3]
    interp_tiles: torch.Tensor = (padded_imgs.unfold(1, c, c)
                                  .unfold(2, interp_kernel_vert, interp_kernel_vert)
                                  .unfold(3, interp_kernel_horz, interp_kernel_horz)
                                  .squeeze(1))  # B x 2GH x 2GW x C x TH/2 x TW/2
    assert interp_tiles.shape[-3] == c
    assert interp_tiles.shape[-2] == tile_size[0] / 2
    assert interp_tiles.shape[-1] == tile_size[1] / 2
    return interp_tiles
def _compute_luts(tiles_x_im: torch.Tensor, num_bins: int = 256, clip: float = 40., diff: bool = False) -> torch.Tensor:
    r"""Compute LUTs for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. Default: 256
        clip (float): threshold value for contrast limiting. If it is 0, clipping is disabled. Default: 40.
        diff (bool, optional): denote if the differentiable histogram will be used. Default: False

    Returns:
        torch.Tensor: LUT for each tile (B, GH, GW, C, 256)
    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.reshape(-1, pixels)  # T x (THxTW)
    histos: torch.Tensor = torch.empty((tiles.shape[0], num_bins), device=tiles.device)
    if not diff:
        for i in range(tiles.shape[0]):
            histos[i] = torch.histc(tiles[i], bins=num_bins, min=0, max=1)
    else:
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        histos = histogram(tiles, bins, torch.tensor(0.001)).squeeze()
        histos *= pixels

    # clip limit (TODO: optimize the code)
    if clip > 0.:
        clip_limit: torch.Tensor = torch.tensor(
            max(clip * pixels // num_bins, 1), dtype=histos.dtype, device=tiles.device)
        clip_idxs: torch.Tensor = histos > clip_limit
        for i in range(histos.shape[0]):
            hist: torch.Tensor = histos[i]
            idxs = clip_idxs[i]
            if idxs.any():
                # cap the overfull bins at the clip limit and spread the clipped mass
                clipped: float = float((hist[idxs] - clip_limit).sum().item())
                hist = torch.where(idxs, clip_limit, hist)
                redist: float = clipped // num_bins
                hist += redist
                # any leftover mass goes one count per bin, starting at bin 0
                residual: float = clipped - redist * num_bins
                if residual:
                    hist[0:int(residual)] += 1
                histos[i] = hist

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    luts = luts.clamp(0, num_bins - 1).floor()  # to get the same values as converting to int while keeping the dtype
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts
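
# Minimal illustrative sketch (not part of the original file): the clipping step
# above caps each histogram bin at the clip limit and redistributes the clipped
# mass evenly across all bins, which is what limits contrast amplification in
# CLAHE. The hypothetical helper below replays that logic on a tiny 4-bin case.
def _demo_clip_redistribution() -> None:
    hist = torch.tensor([10., 0., 0., 0.])  # one spiky 4-bin histogram
    clip_limit = torch.tensor(4.)
    idxs = hist > clip_limit
    clipped = float((hist[idxs] - clip_limit).sum().item())  # 6 excess counts
    hist = torch.where(idxs, clip_limit, hist)               # -> [4, 0, 0, 0]
    redist = clipped // 4                                    # 1 count for every bin
    hist += redist                                           # -> [5, 1, 1, 1]
    hist[0:int(clipped - redist * 4)] += 1                   # residual 2 -> first 2 bins
    assert hist.tolist() == [6., 2., 1., 1.]                 # total mass (10) is preserved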
def _compute_equalized_tiles(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    r"""Equalize the tiles.

    Args:
        interp_tiles (torch.Tensor): set of interpolation tiles, values must be in the range [0, 1].
            (B, 2GH, 2GW, C, TH/2, TW/2)
        luts (torch.Tensor): LUTs for each one of the original tiles. (B, GH, GW, C, 256)

    Returns:
        torch.Tensor: equalized tiles (B, 2GH, 2GW, C, TH/2, TW/2)
    """
    assert interp_tiles.dim() == 6, "interp_tiles tensor must be 6D."
    assert luts.dim() == 5, "luts tensor must be 5D."

    mapped_luts: torch.Tensor = _map_luts(interp_tiles, luts)  # B x 2GH x 2GW x 4 x C x 256

    # gh, gw -> 2x the number of tiles used to compute the histograms
    # th, tw -> half the sizes of the tiles used to compute the histograms
    num_imgs, gh, gw, c, th, tw = interp_tiles.shape

    # equalize tiles
    flatten_interp_tiles: torch.Tensor = (interp_tiles * 255).long().flatten(-2, -1)  # B x GH x GW x C x (THxTW)
    flatten_interp_tiles = flatten_interp_tiles.unsqueeze(-3).expand(num_imgs, gh, gw, 4, c, th * tw)
    preinterp_tiles_equalized = torch.gather(
        mapped_luts, 5, flatten_interp_tiles).reshape(num_imgs, gh, gw, 4, c, th, tw)  # B x GH x GW x 4 x C x TH x TW

    # interp tiles
    tiles_equalized: torch.Tensor = torch.zeros_like(interp_tiles, dtype=torch.long)

    # compute the interpolation weights (shapes are 2 x TH x TW because they must be applied to 2 interp tiles)
    ih = torch.arange(2 * th - 1, -1, -1, device=interp_tiles.device).div(
        2. * th - 1)[None].transpose(-2, -1).expand(2 * th, tw)
    ih = ih.unfold(0, th, th).unfold(1, tw, tw)  # 2 x 1 x TH x TW
    iw = torch.arange(2 * tw - 1, -1, -1, device=interp_tiles.device).div(2. * tw - 1).expand(th, 2 * tw)
    iw = iw.unfold(0, th, th).unfold(1, tw, tw)  # 1 x 2 x TH x TW

    # compute row and column interpolation weights
    tiw = iw.expand((gw - 2) // 2, 2, th, tw).reshape(gw - 2, 1, th, tw).unsqueeze(0)  # 1 x GW-2 x 1 x TH x TW
    tih = ih.repeat((gh - 2) // 2, 1, 1, 1).unsqueeze(1)  # GH-2 x 1 x 1 x TH x TW

    # internal regions
    tl, tr, bl, br = preinterp_tiles_equalized[:, 1:-1, 1:-1].unbind(3)
    t = tiw * (tl - tr) + tr
    b = tiw * (bl - br) + br
    tiles_equalized[:, 1:-1, 1:-1] = tih * (t - b) + b

    # corner regions
    tiles_equalized[:, 0::gh - 1, 0::gw - 1] = preinterp_tiles_equalized[:, 0::gh - 1, 0::gw - 1, 0]

    # border region (h): first and last columns (the width index runs up to gw - 1)
    t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, 0].unbind(2)
    tiles_equalized[:, 1:-1, 0] = tih.squeeze(1) * (t - b) + b
    t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, gw - 1].unbind(2)
    tiles_equalized[:, 1:-1, gw - 1] = tih.squeeze(1) * (t - b) + b

    # border region (w): first and last rows (the height index runs up to gh - 1)
    l, r, _, _ = preinterp_tiles_equalized[:, 0, 1:-1].unbind(2)
    tiles_equalized[:, 0, 1:-1] = tiw * (l - r) + r
    l, r, _, _ = preinterp_tiles_equalized[:, gh - 1, 1:-1].unbind(2)
    tiles_equalized[:, gh - 1, 1:-1] = tiw * (l - r) + r

    # return the same type as the input
    return tiles_equalized.to(interp_tiles).div(255.)
def equalize_clahe(input: torch.Tensor, clip_limit: float = 40., grid_size: Tuple[int, int] = (8, 8)) -> torch.Tensor:
    r"""Apply CLAHE equalization on the input tensor.

    NOTE: LUT computation uses the same approach as in OpenCV; in next versions this can change.

    Args:
        input (torch.Tensor): images tensor to equalize with values in the range [0, 1] and shapes like
            :math:`(C, H, W)` or :math:`(B, C, H, W)`.
        clip_limit (float): threshold value for contrast limiting. If 0, clipping is disabled. Default: 40.
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW). Default: (8, 8).

    Returns:
        torch.Tensor: Equalized image or images with the same shape as the input.

    Examples:
        >>> img = torch.rand(1, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([1, 10, 20])

        >>> img = torch.rand(2, 3, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([2, 3, 10, 20])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input type is not a torch.Tensor. Got {type(input)}")

    if input.dim() not in [3, 4]:
        raise ValueError(f"Invalid input shape, we expect CxHxW or BxCxHxW. Got: {input.shape}")

    if input.dim() == 3 and input.shape[0] not in [1, 3]:
        raise ValueError(
            f"Invalid 3D input: the first dimension must be the channel number (1 or 3). Got: {input.shape}")

    if input.numel() == 0:
        raise ValueError("Invalid input tensor, it is empty.")

    if not isinstance(clip_limit, float):
        raise TypeError(f"Input clip_limit type is not float. Got {type(clip_limit)}")

    if not isinstance(grid_size, tuple):
        raise TypeError(f"Input grid_size type is not Tuple. Got {type(grid_size)}")

    if len(grid_size) != 2:
        raise TypeError(f"Input grid_size is not a Tuple with 2 elements. Got {len(grid_size)}")

    if isinstance(grid_size[0], float) or isinstance(grid_size[1], float):
        raise TypeError("Input grid_size type is not valid, must be a Tuple[int, int].")

    if grid_size[0] <= 0 or grid_size[1] <= 0:
        raise ValueError(f"Input grid_size elements must be positive. Got {grid_size}")

    imgs: torch.Tensor = _to_bchw(input)  # B x C x H x W

    # hist_tiles: torch.Tensor  # B x GH x GW x C x TH x TW  # not supported by JIT
    # img_padded: torch.Tensor  # B x C x H' x W'  # not supported by JIT
    # the size of the tiles must be even in order to divide them into 4 tiles for the interpolation
    hist_tiles, img_padded = _compute_tiles(imgs, grid_size, True)
    tile_size: Tuple[int, int] = (hist_tiles.shape[-2], hist_tiles.shape[-1])
    interp_tiles: torch.Tensor = (
        _compute_interpolation_tiles(img_padded, tile_size))  # B x 2GH x 2GW x C x TH/2 x TW/2
    luts: torch.Tensor = _compute_luts(hist_tiles, clip=clip_limit)  # B x GH x GW x C x 256
    equalized_tiles: torch.Tensor = _compute_equalized_tiles(interp_tiles, luts)  # B x 2GH x 2GW x C x TH/2 x TW/2

    # reconstruct the images from the tiles
    eq_imgs: torch.Tensor = torch.cat(equalized_tiles.unbind(2), 4)
    eq_imgs = torch.cat(eq_imgs.unbind(1), 2)
    h, w = imgs.shape[-2:]
    eq_imgs = eq_imgs[..., :h, :w]  # crop imgs if they were padded

    # remove the batch dim if the input was not in batch form
    if input.dim() != eq_imgs.dim():
        eq_imgs = eq_imgs.squeeze(0)

    return eq_imgs
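
# Minimal illustrative sketch (not part of the original file): typical use of
# `equalize_clahe` on a single normalized grayscale slice. Inputs must already
# be scaled to [0, 1]; the random tensor below is a hypothetical stand-in for a
# normalized CT/MR slice.
def _demo_equalize_clahe() -> None:
    slice_2d = torch.rand(1, 128, 128)  # C x H x W, values in [0, 1]
    out = equalize_clahe(slice_2d, clip_limit=2.0, grid_size=(8, 8))
    assert out.shape == slice_2d.shape               # shape is preserved
    assert 0.0 <= float(out.min()) and float(out.max()) <= 1.0  # output stays in [0, 1]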
##############
def histo(image):
    """Plot the intensity histogram of an image array."""
    # calculate the histogram (local names chosen to avoid shadowing the
    # built-ins min/max and the kernel-density `histogram` function above)
    vmin = np.min(image)
    vmax = np.max(image)
    hist, bin_edges = np.histogram(image.flatten(), bins=np.linspace(vmin, vmax, 100))

    # plot the histogram
    plt.figure()
    plt.title('Histogram')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.bar(bin_edges[:-1], hist, width=1)

    # display the histogram
    # plt.show()
## class for data pre-processing
class PreProcessing:

    @staticmethod
    def CLAHE(img, clip_limit=40.0, grid_size=(8, 8)):
        img = equalize_clahe(img, clip_limit, grid_size)
        return img

    @staticmethod
    def INTERPOLATE(img, size=64, mode='linear', align_corners=False):
        # forward align_corners instead of hard-coding False
        img = F.interpolate(img, size=size, mode=mode, align_corners=align_corners)
        return img
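
# Minimal illustrative sketch (not part of the original file): chaining the two
# pre-processing steps on a hypothetical batched slice. Note that mode='linear'
# expects 3D input (B, C, L); for 2D slices a bilinear resize on a 4D
# (B, C, H, W) tensor, as below, is the usual choice.
if __name__ == "__main__":
    vol_slice = torch.rand(1, 1, 96, 96)  # B x C x H x W, values in [0, 1]
    eq = PreProcessing.CLAHE(vol_slice, clip_limit=40.0, grid_size=(8, 8))
    resized = PreProcessing.INTERPOLATE(eq, size=(64, 64), mode='bilinear', align_corners=False)
    print(eq.shape, resized.shape)  # -> (1, 1, 96, 96) (1, 1, 64, 64)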