You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

transformation.py 3.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import albumentations as A
  2. from albumentations.pytorch import ToTensorV2
  3. from albumentations_mixup import Mixup
  4. def get_transformation(augmentation, crop_size=299, base_dataset=None):
  5. scaled_center_crop_size = int(crop_size * 1.25)
  6. def random_crop_transformation(x):
  7. return A.RandomCrop(x, x, always_apply=True)
  8. def get_flip_rotate__custom__noise_transform(transform_list, random_scale=True):
  9. return A.Compose([
  10. A.Flip(p=0.25),
  11. A.Rotate(p=0.25),
  12. A.RandomScale(scale_limit=0.5, p=0.5 if random_scale else 0),
  13. A.PadIfNeeded(min_height=scaled_center_crop_size, min_width=scaled_center_crop_size,
  14. always_apply=True),
  15. A.CenterCrop(scaled_center_crop_size, scaled_center_crop_size),
  16. random_crop_transformation(crop_size),
  17. ] + transform_list + [
  18. A.Blur(p=0.25, blur_limit=2),
  19. A.GaussNoise(p=0.25, var_limit=10),
  20. ToTensorV2()
  21. ])
  22. if augmentation == "min":
  23. trans = A.Compose([
  24. A.PadIfNeeded(min_height=scaled_center_crop_size, min_width=scaled_center_crop_size, always_apply=True),
  25. A.CenterCrop(scaled_center_crop_size, scaled_center_crop_size),
  26. random_crop_transformation(crop_size),
  27. ToTensorV2()
  28. ])
  29. elif augmentation == "std":
  30. trans = get_flip_rotate__custom__noise_transform([])
  31. elif augmentation == "jit-nrs":
  32. trans = get_flip_rotate__custom__noise_transform([
  33. A.ColorJitter(p=0.5, hue=.5)
  34. ], random_scale=False)
  35. elif augmentation == "jit":
  36. trans = get_flip_rotate__custom__noise_transform([
  37. A.ColorJitter(p=0.5, hue=.5)
  38. ])
  39. elif augmentation == "fda":
  40. fda_image_paths = [sample[0] for sample in base_dataset.samples]
  41. trans = get_flip_rotate__custom__noise_transform([
  42. A.domain_adaptation.FDA(fda_image_paths, beta_limit=0.1, p=0.5)
  43. ])
  44. elif augmentation == "mixup":
  45. mixups = [sample[0:2] for sample in base_dataset.samples]
  46. trans = get_flip_rotate__custom__noise_transform([
  47. Mixup(mixups=mixups, p=0.5, beta_limit=(0.1)),
  48. ])
  49. elif augmentation == "jit-fda-mixup":
  50. p = 0.16
  51. fda_image_paths = [sample[0] for sample in base_dataset.samples]
  52. mixups = [sample[0:2] for sample in base_dataset.samples]
  53. trans = get_flip_rotate__custom__noise_transform([
  54. A.domain_adaptation.FDA(fda_image_paths, beta_limit=0.1, p=p),
  55. Mixup(mixups=mixups, p=p, beta_limit=(0.1)),
  56. A.ColorJitter(p=p, hue=.5)
  57. ])
  58. elif augmentation == "jit-fda-mixup-nrs":
  59. p = 0.16
  60. fda_image_paths = [sample[0] for sample in base_dataset.samples]
  61. mixups = [sample[0:2] for sample in base_dataset.samples]
  62. trans = get_flip_rotate__custom__noise_transform([
  63. A.domain_adaptation.FDA(fda_image_paths, beta_limit=0.1, p=p),
  64. Mixup(mixups=mixups, p=p, beta_limit=(0.1)),
  65. A.ColorJitter(p=p, hue=.5)
  66. ], random_scale=False)
  67. elif augmentation == "shear":
  68. trans = get_flip_rotate__custom__noise_transform([
  69. A.Affine(shear={"x": (-10, 10), "y": (-10, 10)}, p=0.5)
  70. ], random_scale=False)
  71. else:
  72. raise ValueError(f"Augmentation unknown: {augmentation}")
  73. return trans