# This code has been taken from MONAI.
import os
import math
import glob
import shutil
import tempfile
from pprint import pprint
from typing import Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import nibabel as nib
from einops import rearrange


###############################
from typing import Any, Callable, Collection, Generator, Hashable, Iterable, Mapping, Sequence, TypeVar, Union
from enum import Enum
from abc import ABC, abstractmethod
from warnings import warn

def _map_luts(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    r"""Assign the required luts to each tile.

    Args:
        interp_tiles (torch.Tensor): set of interpolation tiles. (B, 2GH, 2GW, C, TH/2, TW/2)
        luts (torch.Tensor): luts for each one of the original tiles. (B, GH, GW, C, 256)

    Returns:
        torch.Tensor: mapped luts (B, 2GH, 2GW, 4, C, 256)

    """
    assert interp_tiles.dim() == 6, "interp_tiles tensor must be 6D."
    assert luts.dim() == 5, "luts tensor must be 5D."

    # gh, gw -> 2x the number of tiles used to compute the histograms
    # th, tw -> half the sizes of the tiles used to compute the histograms
    num_imgs, gh, gw, c, th, tw = interp_tiles.shape

    # precompute the indices for the non-corner regions (doing it on cpu seems slightly faster)
    j_idxs = torch.ones(gh - 2, 4, dtype=torch.long) * torch.arange(1, gh - 1).reshape(gh - 2, 1)
    i_idxs = torch.ones(gw - 2, 4, dtype=torch.long) * torch.arange(1, gw - 1).reshape(gw - 2, 1)
    j_idxs = j_idxs // 2 + j_idxs % 2
    j_idxs[:, 0:2] -= 1
    i_idxs = i_idxs // 2 + i_idxs % 2
    # i_idxs[:, [0, 2]] -= 1  # this slicing is not supported by jit
    i_idxs[:, 0] -= 1
    i_idxs[:, 2] -= 1

    # selection of the luts to interpolate each patch
    # create a tensor with dims: interp_patches height and width x 4 x num channels x bins in the histograms
    # the tensor is initialized to -1 to denote non-initialized histograms
    luts_x_interp_tiles: torch.Tensor = -torch.ones(
        num_imgs, gh, gw, 4, c, luts.shape[-1], device=interp_tiles.device)  # B x GH x GW x 4 x C x 256
    # corner regions
    luts_x_interp_tiles[:, 0::gh - 1, 0::gw - 1, 0] = luts[:, 0::max(gh // 2 - 1, 1), 0::max(gw // 2 - 1, 1)]
    # border region (h)
    luts_x_interp_tiles[:, 1:-1, 0::gw - 1, 0] = luts[:, j_idxs[:, 0], 0::max(gw // 2 - 1, 1)]
    luts_x_interp_tiles[:, 1:-1, 0::gw - 1, 1] = luts[:, j_idxs[:, 2], 0::max(gw // 2 - 1, 1)]
    # border region (w)
    luts_x_interp_tiles[:, 0::gh - 1, 1:-1, 0] = luts[:, 0::max(gh // 2 - 1, 1), i_idxs[:, 0]]
    luts_x_interp_tiles[:, 0::gh - 1, 1:-1, 1] = luts[:, 0::max(gh // 2 - 1, 1), i_idxs[:, 1]]
    # internal region
    luts_x_interp_tiles[:, 1:-1, 1:-1, :] = luts[
        :, j_idxs.repeat(max(gh - 2, 1), 1, 1).permute(1, 0, 2), i_idxs.repeat(max(gw - 2, 1), 1, 1)]

    return luts_x_interp_tiles

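# A quick shape sketch for _map_luts (assumed random inputs, not from the original file;
# GH = GW = 2, so the doubled interpolation grid is 4 x 4):
# >>> mapped = _map_luts(torch.rand(1, 4, 4, 1, 4, 4), torch.rand(1, 2, 2, 1, 256))
# >>> mapped.shape
# torch.Size([1, 4, 4, 4, 1, 256])
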
def marginal_pdf(values: torch.Tensor, bins: torch.Tensor, sigma: torch.Tensor,
                 epsilon: float = 1e-10) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function that calculates the marginal probability distribution function of the input tensor
    based on the number of histogram bins.

    Args:
        values (torch.Tensor): shape [BxNx1].
        bins (torch.Tensor): shape [NUM_BINS].
        sigma (torch.Tensor): scalar (0-dimensional) tensor, gaussian smoothing factor.
        epsilon (float): scalar, for numerical stability.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
          - torch.Tensor: shape [BxNUM_BINS].
          - torch.Tensor: shape [BxNxNUM_BINS].

    """

    if not isinstance(values, torch.Tensor):
        raise TypeError("Input values type is not a torch.Tensor. Got {}".format(type(values)))

    if not isinstance(bins, torch.Tensor):
        raise TypeError("Input bins type is not a torch.Tensor. Got {}".format(type(bins)))

    if not isinstance(sigma, torch.Tensor):
        raise TypeError("Input sigma type is not a torch.Tensor. Got {}".format(type(sigma)))

    if not values.dim() == 3:
        raise ValueError("Input values must be of the shape BxNx1. Got {}".format(values.shape))

    if not bins.dim() == 1:
        raise ValueError("Input bins must be of the shape NUM_BINS. Got {}".format(bins.shape))

    if not sigma.dim() == 0:
        raise ValueError("Input sigma must be a 0-dimensional (scalar) tensor. Got {}".format(sigma.shape))

    # evaluate a gaussian kernel centred on each bin for every value (kernel density estimation)
    residuals = values - bins.unsqueeze(0).unsqueeze(0)
    kernel_values = torch.exp(-0.5 * (residuals / sigma).pow(2))

    # average over the N samples and normalize so each pdf sums to one
    pdf = torch.mean(kernel_values, dim=1)
    normalization = torch.sum(pdf, dim=1).unsqueeze(1) + epsilon
    pdf = pdf / normalization

    return (pdf, kernel_values)

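# A minimal usage sketch (assumed random input; 8 bin centres over [0, 1]):
# >>> pdf, kde = marginal_pdf(torch.rand(1, 10, 1), torch.linspace(0, 1, 8), torch.tensor(0.1))
# >>> pdf.shape, kde.shape
# (torch.Size([1, 8]), torch.Size([1, 10, 8]))
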
def histogram(x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor,
              epsilon: float = 1e-10) -> torch.Tensor:
    """Function that estimates the histogram of the input tensor.

    The calculation uses kernel density estimation, which requires a bandwidth (smoothing) parameter.

    Args:
        x (torch.Tensor): Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins (torch.Tensor): The bin centres to use for the histogram, with shape :math:`(N_{bins})`.
        bandwidth (torch.Tensor): Gaussian smoothing factor, a scalar (0-dimensional) tensor.
        epsilon (float): A scalar, for numerical stability. Default: 1e-10.

    Returns:
        torch.Tensor: Computed histogram of shape :math:`(B, N_{bins})`.

    Examples:
        >>> x = torch.rand(1, 10)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([1, 128])
    """

    pdf, _ = marginal_pdf(x.unsqueeze(2), bins, bandwidth, epsilon)

    return pdf

def _compute_tiles(imgs: torch.Tensor, grid_size: Tuple[int, int], even_tile_size: bool = False
                   ) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute tiles on an image according to a grid size.

    Note that padding can be added to the image in order to crop it properly.
    So, the grid_size (GH, GW) x tile_size (TH, TW) >= image_size (H, W)

    Args:
        imgs (torch.Tensor): batch of 2D images with shape (B, C, H, W) or (C, H, W).
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW).
        even_tile_size (bool, optional): Determine if the width and height of the tiles must be even. Default: False.

    Returns:
        torch.Tensor: tensor with tiles (B, GH, GW, C, TH, TW). B = 1 if a single image is provided.
        torch.Tensor: tensor with the padded batch of 2D images with shape (B, C, H', W').

    """

    batch: torch.Tensor = _to_bchw(imgs)  # B x C x H x W

    # compute stride and kernel size
    h, w = batch.shape[-2:]
    kernel_vert: int = math.ceil(h / grid_size[0])
    kernel_horz: int = math.ceil(w / grid_size[1])

    if even_tile_size:
        kernel_vert += 1 if kernel_vert % 2 else 0
        kernel_horz += 1 if kernel_horz % 2 else 0

    # add padding (with that kernel size we could need some extra cols and rows...)
    pad_vert = kernel_vert * grid_size[0] - h
    pad_horz = kernel_horz * grid_size[1] - w
    # add the padding in the last columns and rows
    if pad_vert > 0 or pad_horz > 0:
        batch = F.pad(batch, (0, pad_horz, 0, pad_vert), mode='reflect')  # B x C x H' x W'

    # compute tiles
    c: int = batch.shape[-3]
    tiles: torch.Tensor = (batch.unfold(1, c, c)  # unfold(dimension, size, step)
                           .unfold(2, kernel_vert, kernel_vert)
                           .unfold(3, kernel_horz, kernel_horz)
                           .squeeze(1))  # B x GH x GW x C x TH x TW
    assert tiles.shape[-5] == grid_size[0]  # check the grid size
    assert tiles.shape[-4] == grid_size[1]
    return tiles, batch

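# A quick shape sketch (assumed input; with even_tile_size=True, a 100 x 100 image on an
# 8 x 8 grid is padded to 112 x 112 and split into 14 x 14 tiles):
# >>> tiles, padded = _compute_tiles(torch.rand(2, 1, 100, 100), (8, 8), even_tile_size=True)
# >>> tiles.shape, padded.shape
# (torch.Size([2, 8, 8, 1, 14, 14]), torch.Size([2, 1, 112, 112]))
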
def _to_bchw(tensor: torch.Tensor, color_channel_num: Optional[int] = None) -> torch.Tensor:
    """Converts a PyTorch tensor image to BCHW format.

    Args:
        tensor (torch.Tensor): image of the form :math:`(H, W)`, :math:`(C, H, W)`, :math:`(H, W, C)` or
            :math:`(B, C, H, W)`.
        color_channel_num (Optional[int]): Color channel of the input tensor.
            If None, it will not alter the input channel.

    Returns:
        torch.Tensor: input tensor of the form :math:`(B, C, H, W)`.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(tensor)}")

    if len(tensor.shape) > 4 or len(tensor.shape) < 2:
        raise ValueError(f"Input size must be a two, three or four dimensional tensor. Got {tensor.shape}")

    # (H, W) -> (1, H, W)
    if len(tensor.shape) == 2:
        tensor = tensor.unsqueeze(0)

    # (C, H, W) -> (1, C, H, W)
    if len(tensor.shape) == 3:
        tensor = tensor.unsqueeze(0)

    # TODO(jian): this function is never used. Besides, it is not feasible for torchscript.
    # In addition, the docs must be updated. I don't understand what it is doing.
    # if color_channel_num is not None and color_channel_num != 1:
    #     channel_list = [0, 1, 2, 3]
    #     channel_list.insert(1, channel_list.pop(color_channel_num))
    #     tensor = tensor.permute(*channel_list)
    return tensor

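# A quick sanity sketch (assumed input): a bare (H, W) image gains batch and channel dims.
# >>> _to_bchw(torch.rand(10, 20)).shape
# torch.Size([1, 1, 10, 20])
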
def _compute_interpolation_tiles(padded_imgs: torch.Tensor, tile_size: Tuple[int, int]) -> torch.Tensor:
    r"""Compute interpolation tiles on a properly padded set of images.

    Note that images must be padded. So, the tile_size (TH, TW) * grid_size (GH, GW) = image_size (H, W)

    Args:
        padded_imgs (torch.Tensor): batch of 2D images with shape (B, C, H, W) already padded to extract tiles
            of size (TH, TW).
        tile_size (Tuple[int, int]): shape of the current tiles (TH, TW).

    Returns:
        torch.Tensor: tensor with the interpolation tiles (B, 2GH, 2GW, C, TH/2, TW/2).

    """
    assert padded_imgs.dim() == 4, "Images Tensor must be 4D."
    assert padded_imgs.shape[-2] % tile_size[0] == 0, "Images are not correctly padded."
    assert padded_imgs.shape[-1] % tile_size[1] == 0, "Images are not correctly padded."

    # the tiles to be interpolated are built by dividing each already existing tile into 4
    interp_kernel_vert: int = tile_size[0] // 2
    interp_kernel_horz: int = tile_size[1] // 2

    c: int = padded_imgs.shape[-3]
    interp_tiles: torch.Tensor = (padded_imgs.unfold(1, c, c)
                                  .unfold(2, interp_kernel_vert, interp_kernel_vert)
                                  .unfold(3, interp_kernel_horz, interp_kernel_horz)
                                  .squeeze(1))  # B x 2GH x 2GW x C x TH/2 x TW/2
    assert interp_tiles.shape[-3] == c
    assert interp_tiles.shape[-2] == tile_size[0] // 2
    assert interp_tiles.shape[-1] == tile_size[1] // 2
    return interp_tiles

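# A quick shape sketch (assumed input, consistent with the padded 112 x 112 example above):
# >>> _compute_interpolation_tiles(torch.rand(1, 1, 112, 112), (14, 14)).shape
# torch.Size([1, 16, 16, 1, 7, 7])
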
def _compute_luts(tiles_x_im: torch.Tensor, num_bins: int = 256, clip: float = 40.,
                  diff: bool = False) -> torch.Tensor:
    r"""Compute luts for a batched set of tiles.

    Same approach as in OpenCV (https://github.com/opencv/opencv/blob/master/modules/imgproc/src/clahe.cpp)

    Args:
        tiles_x_im (torch.Tensor): set of tiles per image to apply the lut. (B, GH, GW, C, TH, TW)
        num_bins (int, optional): number of bins. Default: 256.
        clip (float): threshold value for contrast limiting. If it is 0 then the clipping is disabled. Default: 40.
        diff (bool, optional): denote if the differentiable histogram will be used. Default: False.

    Returns:
        torch.Tensor: Lut for each tile (B, GH, GW, C, 256)

    """
    assert tiles_x_im.dim() == 6, "Tensor must be 6D."

    b, gh, gw, c, th, tw = tiles_x_im.shape
    pixels: int = th * tw
    tiles: torch.Tensor = tiles_x_im.reshape(-1, pixels)  # T x (THxTW)
    histos: torch.Tensor = torch.empty((tiles.shape[0], num_bins), device=tiles.device)
    if not diff:
        # hard histogram per tile (not differentiable)
        for i in range(tiles.shape[0]):
            histos[i] = torch.histc(tiles[i], bins=num_bins, min=0, max=1)
    else:
        # soft (kernel density) histogram, differentiable
        bins: torch.Tensor = torch.linspace(0, 1, num_bins, device=tiles.device)
        histos = histogram(tiles, bins, torch.tensor(0.001)).squeeze()
        histos *= pixels

    # clip limit (TODO: optimize the code)
    if clip > 0.:
        clip_limit: torch.Tensor = torch.tensor(
            max(clip * pixels // num_bins, 1), dtype=histos.dtype, device=tiles.device)

        # clip each histogram and redistribute the clipped mass uniformly across the bins
        clip_idxs: torch.Tensor = histos > clip_limit
        for i in range(histos.shape[0]):
            hist: torch.Tensor = histos[i]
            idxs = clip_idxs[i]
            if idxs.any():
                clipped: float = float((hist[idxs] - clip_limit).sum().item())
                hist = torch.where(idxs, clip_limit, hist)

                redist: float = clipped // num_bins
                hist += redist

                # the remainder is added one unit per bin, starting from the first bins
                residual: float = clipped - redist * num_bins
                if residual:
                    hist[0:int(residual)] += 1
                histos[i] = hist

    lut_scale: float = (num_bins - 1) / pixels
    luts: torch.Tensor = torch.cumsum(histos, 1) * lut_scale
    luts = luts.clamp(0, num_bins - 1).floor()  # to get the same values as converting to int, keeping the type
    luts = luts.view((b, gh, gw, c, num_bins))
    return luts

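# A quick shape sketch (assumed random tiles in [0, 1]; a 2 x 2 grid of 8 x 8 single-channel tiles):
# >>> luts = _compute_luts(torch.rand(1, 2, 2, 1, 8, 8))
# >>> luts.shape
# torch.Size([1, 2, 2, 1, 256])
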
def _compute_equalized_tiles(interp_tiles: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    r"""Equalize the tiles.

    Args:
        interp_tiles (torch.Tensor): set of interpolation tiles, values must be in the range [0, 1].
            (B, 2GH, 2GW, C, TH/2, TW/2)
        luts (torch.Tensor): luts for each one of the original tiles. (B, GH, GW, C, 256)

    Returns:
        torch.Tensor: equalized tiles (B, 2GH, 2GW, C, TH/2, TW/2)

    """
    assert interp_tiles.dim() == 6, "interp_tiles tensor must be 6D."
    assert luts.dim() == 5, "luts tensor must be 5D."

    mapped_luts: torch.Tensor = _map_luts(interp_tiles, luts)  # B x 2GH x 2GW x 4 x C x 256

    # gh, gw -> 2x the number of tiles used to compute the histograms
    # th, tw -> half the sizes of the tiles used to compute the histograms
    num_imgs, gh, gw, c, th, tw = interp_tiles.shape

    # equalize tiles: map each pixel (scaled to [0, 255]) through its four candidate luts
    flatten_interp_tiles: torch.Tensor = (interp_tiles * 255).long().flatten(-2, -1)  # B x GH x GW x C x (THxTW)
    flatten_interp_tiles = flatten_interp_tiles.unsqueeze(-3).expand(num_imgs, gh, gw, 4, c, th * tw)
    preinterp_tiles_equalized = torch.gather(
        mapped_luts, 5, flatten_interp_tiles).reshape(num_imgs, gh, gw, 4, c, th, tw)  # B x GH x GW x 4 x C x TH x TW

    # interp tiles
    tiles_equalized: torch.Tensor = torch.zeros_like(interp_tiles, dtype=torch.long)

    # compute the interpolation weights (shapes are 2 x TH x TW because they must be applied to 2 interp tiles)
    ih = torch.arange(2 * th - 1, -1, -1, device=interp_tiles.device).div(
        2. * th - 1)[None].transpose(-2, -1).expand(2 * th, tw)
    ih = ih.unfold(0, th, th).unfold(1, tw, tw)  # 2 x 1 x TH x TW
    iw = torch.arange(2 * tw - 1, -1, -1, device=interp_tiles.device).div(2. * tw - 1).expand(th, 2 * tw)
    iw = iw.unfold(0, th, th).unfold(1, tw, tw)  # 1 x 2 x TH x TW

    # compute row and column interpolation weights
    tiw = iw.expand((gw - 2) // 2, 2, th, tw).reshape(gw - 2, 1, th, tw).unsqueeze(0)  # 1 x GW-2 x 1 x TH x TW
    tih = ih.repeat((gh - 2) // 2, 1, 1, 1).unsqueeze(1)  # GH-2 x 1 x 1 x TH x TW

    # internal regions: bilinear interpolation between the four neighbouring luts
    tl, tr, bl, br = preinterp_tiles_equalized[:, 1:-1, 1:-1].unbind(3)
    t = tiw * (tl - tr) + tr
    b = tiw * (bl - br) + br
    tiles_equalized[:, 1:-1, 1:-1] = tih * (t - b) + b

    # corner regions: no interpolation, a single lut applies
    tiles_equalized[:, 0::gh - 1, 0::gw - 1] = preinterp_tiles_equalized[:, 0::gh - 1, 0::gw - 1, 0]

    # border region (h): linear interpolation along the vertical axis
    t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, 0].unbind(2)
    tiles_equalized[:, 1:-1, 0] = tih.squeeze(1) * (t - b) + b
    t, b, _, _ = preinterp_tiles_equalized[:, 1:-1, gw - 1].unbind(2)  # last column index is gw - 1, not gh - 1
    tiles_equalized[:, 1:-1, gw - 1] = tih.squeeze(1) * (t - b) + b

    # border region (w): linear interpolation along the horizontal axis
    l, r, _, _ = preinterp_tiles_equalized[:, 0, 1:-1].unbind(2)
    tiles_equalized[:, 0, 1:-1] = tiw * (l - r) + r
    l, r, _, _ = preinterp_tiles_equalized[:, gh - 1, 1:-1].unbind(2)  # last row index is gh - 1, not gw - 1
    tiles_equalized[:, gh - 1, 1:-1] = tiw * (l - r) + r

    # same type as the input
    return tiles_equalized.to(interp_tiles).div(255.)

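# A quick end-to-end shape sketch of the tile equalization (assumed random inputs; GH = GW = 2,
# so the interpolation grid is 4 x 4 with 4 x 4 half-tiles):
# >>> eq = _compute_equalized_tiles(torch.rand(1, 4, 4, 1, 4, 4), torch.rand(1, 2, 2, 1, 256))
# >>> eq.shape
# torch.Size([1, 4, 4, 1, 4, 4])
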
def equalize_clahe(input: torch.Tensor, clip_limit: float = 40., grid_size: Tuple[int, int] = (8, 8)) -> torch.Tensor:
    r"""Apply CLAHE equalization on the input tensor.

    NOTE: Lut computation uses the same approach as in OpenCV, in future versions this can change.

    Args:
        input (torch.Tensor): images tensor to equalize with values in the range [0, 1] and shapes like
            :math:`(C, H, W)` or :math:`(B, C, H, W)`.
        clip_limit (float): threshold value for contrast limiting. If 0, clipping is disabled. Default: 40.
        grid_size (Tuple[int, int]): number of tiles to be cropped in each direction (GH, GW). Default: (8, 8).

    Returns:
        torch.Tensor: Equalized image or images with the same shape as the input.

    Examples:
        >>> img = torch.rand(1, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([1, 10, 20])

        >>> img = torch.rand(2, 3, 10, 20)
        >>> res = equalize_clahe(img)
        >>> res.shape
        torch.Size([2, 3, 10, 20])

    """

    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input type is not a torch.Tensor. Got {type(input)}")

    if input.dim() not in [3, 4]:
        raise ValueError(f"Invalid input shape, we expect CxHxW or BxCxHxW. Got: {input.shape}")

    if input.dim() == 3 and input.shape[0] not in [1, 3]:
        raise ValueError(f"For a 3D input the first dimension must be the channel number (1 or 3). Got: {input.shape}")

    if input.numel() == 0:
        raise ValueError("Invalid input tensor, it is empty.")

    if not isinstance(clip_limit, float):
        raise TypeError(f"Input clip_limit type is not float. Got {type(clip_limit)}")

    if not isinstance(grid_size, tuple):
        raise TypeError(f"Input grid_size type is not Tuple. Got {type(grid_size)}")

    if len(grid_size) != 2:
        raise TypeError(f"Input grid_size is not a Tuple with 2 elements. Got {len(grid_size)}")

    if isinstance(grid_size[0], float) or isinstance(grid_size[1], float):
        raise TypeError("Input grid_size type is not valid, must be a Tuple[int, int].")

    if grid_size[0] <= 0 or grid_size[1] <= 0:
        raise ValueError(f"Input grid_size elements must be positive. Got {grid_size}")

    imgs: torch.Tensor = _to_bchw(input)  # B x C x H x W

    # hist_tiles: torch.Tensor  # B x GH x GW x C x TH x TW  # not supported by JIT
    # img_padded: torch.Tensor  # B x C x H' x W'  # not supported by JIT
    # the size of the tiles must be even in order to divide them into 4 tiles for the interpolation
    hist_tiles, img_padded = _compute_tiles(imgs, grid_size, True)
    tile_size: Tuple[int, int] = (hist_tiles.shape[-2], hist_tiles.shape[-1])

    interp_tiles: torch.Tensor = (
        _compute_interpolation_tiles(img_padded, tile_size))  # B x 2GH x 2GW x C x TH/2 x TW/2
    luts: torch.Tensor = _compute_luts(hist_tiles, clip=clip_limit)  # B x GH x GW x C x 256
    equalized_tiles: torch.Tensor = _compute_equalized_tiles(interp_tiles, luts)  # B x 2GH x 2GW x C x TH/2 x TW/2

    # reconstruct the images from the tiles
    eq_imgs: torch.Tensor = torch.cat(equalized_tiles.unbind(2), 4)
    eq_imgs = torch.cat(eq_imgs.unbind(1), 2)
    h, w = imgs.shape[-2:]
    eq_imgs = eq_imgs[..., :h, :w]  # crop imgs if they were padded

    # remove the batch dim if the input was not in batch form
    if input.dim() != eq_imgs.dim():
        eq_imgs = eq_imgs.squeeze(0)
    return eq_imgs

##############


def histo(image):
    # Calculate the histogram (100 bin edges spanning the image intensity range)
    vmin = np.min(image)
    vmax = np.max(image)
    # `hist` avoids shadowing the histogram() function defined above
    hist, bins = np.histogram(image.flatten(), bins=np.linspace(vmin, vmax, 100))

    # Plot the histogram
    plt.figure()
    plt.title('Histogram')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.bar(bins[:-1], hist, width=np.diff(bins))

    # Display the histogram
    # plt.show()

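# A minimal usage sketch (assumed input; call plt.show() to display the figure):
# >>> histo(np.random.rand(64, 64))
# >>> plt.show()
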
## class for data pre-processing
class PreProcessing:

    @staticmethod
    def CLAHE(img, clip_limit=40.0, grid_size=(8, 8)):
        img = equalize_clahe(img, clip_limit, grid_size)
        return img

    @staticmethod
    def INTERPOLATE(img, size=64, mode='linear', align_corners=False):
        # note: mode='linear' expects a 3D (B, C, L) input; use 'bilinear' for 4D images
        img = F.interpolate(img, size=size, mode=mode, align_corners=align_corners)
        return img

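# A minimal end-to-end sketch of the pre-processing steps (assumed usage; the sample tensor,
# sizes and modes below are illustrative, not from the original pipeline):
if __name__ == "__main__":
    sample = torch.rand(1, 1, 100, 120)  # B x C x H x W with values in [0, 1]
    eq = PreProcessing.CLAHE(sample, clip_limit=40.0, grid_size=(8, 8))
    resized = PreProcessing.INTERPOLATE(eq, size=(64, 64), mode='bilinear', align_corners=False)
    print(eq.shape, resized.shape)  # torch.Size([1, 1, 100, 120]) torch.Size([1, 1, 64, 64])
    histo(resized.squeeze().numpy())  # histogram of the equalized, resized image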