You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

random_brightness_augmentation.py 591B

1234567891011121314151617
  1. from random import random
  2. import torch
  3. class ClippedBrightnessAugment(torch.nn.Module):
  4. def __init__(self, min_factor, max_factor, clip_min=-1 * float('inf'), clip_max=float('inf')):
  5. super().__init__()
  6. self._min_factor = min_factor
  7. self._max_factor = max_factor
  8. self._clip_min = clip_min
  9. self._clip_max = clip_max
  10. def forward(self, x):
  11. random_scale = random()
  12. random_scale = random_scale * (self._max_factor - self._min_factor) + self._min_factor
  13. return torch.clip(x * random_scale, self._clip_min, self._clip_max)