You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

albumentations_mixup.py 2.3KB

Lines 1–62
import random
import warnings

import cv2
import numpy as np
import torch
import torch.nn as nn
from albumentations.augmentations.utils import read_rgb_image
from albumentations.core.transforms_interface import BasicTransform, to_tuple
  7. class Mixup(BasicTransform):
  8. def __init__(self, mixups, read_fn=read_rgb_image, beta_limit=0.3, **kwargs):
  9. super().__init__(**kwargs)
  10. self.mixups = mixups
  11. self.read_fn = read_fn
  12. self.beta_limit = to_tuple(beta_limit, low=0)
  13. def apply(self, image, mixup_image=None, beta=0.1, **params):
  14. img_type = image.dtype
  15. image = ((1 - beta) * image + beta * mixup_image).astype(img_type)
  16. return image
  17. def apply_to_target(self, target, beta=0.1, mixup_target=-1, **params):
  18. target = {"img": target, "mixup": mixup_target, "beta": beta}
  19. return target
  20. def get_params_dependent_on_targets(self, params):
  21. img = params["image"]
  22. mixup = random.choice(self.mixups)
  23. mixup_image = self.read_fn(mixup[0])
  24. vertical_pad = max(0, (img.shape[0] - mixup_image.shape[0]) // 2)
  25. horizontal_pad = max(0, (img.shape[1] - mixup_image.shape[1]) // 2)
  26. try:
  27. mixup_image = cv2.copyMakeBorder(mixup_image, vertical_pad, vertical_pad, horizontal_pad, horizontal_pad,
  28. cv2.BORDER_REFLECT)
  29. except Exception as e:
  30. print(e)
  31. mixup_image = cv2.resize(mixup_image, dsize=(img.shape[1], img.shape[0]))
  32. return {"mixup_image": mixup_image, "mixup_target": mixup[1]}
  33. def get_params(self):
  34. return {"beta": random.uniform(self.beta_limit[0], self.beta_limit[1])}
  35. @property
  36. def targets(self):
  37. return {
  38. "image": self.apply,
  39. "target": self.apply_to_target,
  40. }
  41. @property
  42. def targets_as_params(self):
  43. return ["image"]
  44. def mixup_loss(output, target):
  45. if type(target) == torch.Tensor:
  46. loss = nn.CrossEntropyLoss()
  47. return loss(output, target)
  48. else: # mixup has been used
  49. loss = nn.CrossEntropyLoss(reduction="none")
  50. return ((1 - target["beta"]) * loss(output, target["img"]) + target["beta"] * loss(output,
  51. target["mixup"])).mean()