
mixstyle.py

import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
        Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """
    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
            p (float): probability of applying MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True
    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix='random'):
        self.mix = mix
    # Original forward from the MixStyle reference implementation, which
    # handles 4-D feature maps only but supports both mixing strategies.
    # Kept for comparison with the variant below:
    #
    # def forward(self, x):
    #     if not self.training or not self._activated:
    #         return x
    #     if random.random() > self.p:
    #         return x
    #     B = x.size(0)
    #     mu = x.mean(dim=[2, 3], keepdim=True)
    #     var = x.var(dim=[2, 3], keepdim=True)
    #     sig = (var + self.eps).sqrt()
    #     mu, sig = mu.detach(), sig.detach()
    #     x_normed = (x - mu) / sig
    #     lmda = self.beta.sample((B, 1, 1, 1))
    #     lmda = lmda.to(x.device)
    #     if self.mix == 'random':
    #         # random shuffle
    #         perm = torch.randperm(B)
    #     elif self.mix == 'crossdomain':
    #         # split into two halves and swap the order
    #         perm = torch.arange(B - 1, -1, -1)  # inverse index
    #         perm_b, perm_a = perm.chunk(2)
    #         perm_b = perm_b[torch.randperm(B // 2)]
    #         perm_a = perm_a[torch.randperm(B // 2)]
    #         perm = torch.cat([perm_b, perm_a], 0)
    #     else:
    #         raise NotImplementedError
    #     mu2, sig2 = mu[perm], sig[perm]
    #     mu_mix = mu * lmda + mu2 * (1 - lmda)
    #     sig_mix = sig * lmda + sig2 * (1 - lmda)
    #     return x_normed * sig_mix + mu_mix
    def forward(self, x):
        if not self.training or not self._activated:
            return x
        if random.random() > self.p:
            return x
        B = x.size(0)  # batch size
        # Compute per-sample style statistics for either 4-D CNN feature
        # maps (B x C x H x W) or 2-D embeddings such as DINO features
        # (B x feature_dim), and pick a lambda shape that broadcasts
        # correctly against them.
        if x.dim() == 4:
            mu = x.mean(dim=[2, 3], keepdim=True)
            var = x.var(dim=[2, 3], keepdim=True)
            lmda_shape = (B, 1, 1, 1)
        elif x.dim() == 2:
            mu = x.mean(dim=1, keepdim=True)
            var = x.var(dim=1, keepdim=True)
            lmda_shape = (B, 1)
        else:
            raise ValueError(f'Unexpected input shape {x.shape}')
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()  # no gradients through the statistics
        x_normed = (x - mu) / sig
        lmda = self.beta.sample(lmda_shape).to(x.device)
        # Note: unlike the commented-out original, this variant ignores
        # self.mix and always mixes with a randomly shuffled batch.
        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)
        return x_normed * sig_mix + mu_mix
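
A minimal usage sketch, assuming the MixStyle class above is in scope; the batch size and feature dimensions (8, 256, 14x14, 768) are illustrative placeholders, not values from this repo. It exercises both input layouts and the activation toggle:

    import torch

    mixstyle = MixStyle(p=0.5, alpha=0.1)
    mixstyle.train()  # MixStyle is a no-op outside training mode

    # 4-D CNN feature maps: B x C x H x W
    feats_4d = torch.randn(8, 256, 14, 14)
    out_4d = mixstyle(feats_4d)
    assert out_4d.shape == feats_4d.shape

    # 2-D embeddings, e.g. DINO features: B x feature_dim
    feats_2d = torch.randn(8, 768)
    out_2d = mixstyle(feats_2d)
    assert out_2d.shape == feats_2d.shape

    # Disable style mixing (e.g. for an ablation) without removing the module
    mixstyle.set_activation_status(False)
    assert torch.equal(mixstyle(feats_2d), feats_2d)

Because the style statistics are detached, gradients flow only through x_normed; the module has no learnable parameters and can be inserted between any two layers whose output matches one of the two supported shapes.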