1234567891011121314151617 |
- from random import random
-
- import torch
-
-
class ClippedBrightnessAugment(torch.nn.Module):
    """Multiply the input by a random brightness factor, then clamp the result.

    On every forward call a single scalar factor is drawn uniformly between
    ``min_factor`` and ``max_factor`` and applied to the entire input tensor.
    The scaled values are then clamped to ``[clip_min, clip_max]``. The default
    bounds of -inf / +inf mean that, unless bounds are given, no clamping
    actually takes effect.

    NOTE: the factor is drawn from Python's ``random`` module, not from torch's
    RNG, so ``torch.manual_seed`` does not make it reproducible; seed
    ``random`` instead.
    """

    def __init__(
        self,
        min_factor,
        max_factor,
        clip_min=-1 * float('inf'),
        clip_max=float('inf'),
    ):
        """Store the factor range and the clamp bounds.

        Args:
            min_factor: Lower end of the brightness-factor range.
            max_factor: Upper end of the brightness-factor range.
            clip_min: Lower clamp bound applied after scaling (default -inf).
            clip_max: Upper clamp bound applied after scaling (default +inf).
        """
        super().__init__()
        # Plain attributes (not parameters/buffers), so they are not part of
        # the module's state_dict.
        self._min_factor, self._max_factor = min_factor, max_factor
        self._clip_min, self._clip_max = clip_min, clip_max

    def forward(self, x):
        """Return ``x`` scaled by one random factor and clamped to the bounds."""
        low, high = self._min_factor, self._max_factor
        # One draw per call: the same factor is shared by every element of x.
        factor = low + random() * (high - low)
        brightened = x * factor
        return torch.clamp(brightened, min=self._clip_min, max=self._clip_max)
|