from random import random

import torch


class ClippedBrightnessAugment(torch.nn.Module):
    """Scale an input tensor by a random brightness factor, then clip it.

    On every forward pass a single scalar factor is drawn uniformly from
    ``[min_factor, max_factor)`` using Python's ``random.random()``. The
    whole input is multiplied by that factor, and the result is clamped to
    ``[clip_min, clip_max]``. Only one scalar is drawn per call, so every
    element of the input shares the same factor. If a batch is passed, all
    samples in it share it too.

    NOTE(review): the factor comes from Python's global ``random`` module,
    not from torch's RNG. ``torch.manual_seed`` therefore does not make this
    augmentation reproducible; seed ``random`` as well.

    Args:
        min_factor: Lower bound of the brightness factor. Assumed to be
            ``<= max_factor``.
        max_factor: Upper bound of the brightness factor. It is never
            reached exactly, because ``random()`` returns values in
            ``[0, 1)``.
        clip_min: Lower clamp bound applied after scaling. The default of
            -inf means no lower clamp.
        clip_max: Upper clamp bound applied after scaling. The default of
            +inf means no upper clamp.

    Raises:
        ValueError: If both clip bounds are plain numbers and
            ``clip_min > clip_max``.
    """

    def __init__(self, min_factor, max_factor, clip_min=-float("inf"), clip_max=float("inf")):
        super().__init__()
        # torch.clip with min > max silently sets *every* element to max,
        # which would turn the augmentation into a constant fill. Reject it
        # up front. Only plain numbers are checked, because tensor or None
        # bounds are also accepted by torch.clip and cannot be compared
        # this way.
        if (
            isinstance(clip_min, (int, float))
            and isinstance(clip_max, (int, float))
            and clip_min > clip_max
        ):
            raise ValueError(
                f"clip_min ({clip_min}) must not exceed clip_max ({clip_max})"
            )
        self._min_factor = min_factor
        self._max_factor = max_factor
        self._clip_min = clip_min
        self._clip_max = clip_max

    def forward(self, x):
        """Return ``x`` scaled by a fresh random factor and clamped to the clip range.

        Args:
            x: Input tensor. Presumably a floating-point image or feature
                tensor (TODO confirm with callers).

        Returns:
            A new tensor with the same shape as ``x``.
        """
        # random() is in [0, 1); map it affinely onto [min_factor, max_factor).
        scale = self._min_factor + random() * (self._max_factor - self._min_factor)
        return torch.clip(x * scale, self._clip_min, self._clip_max)