import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
        Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
            p (float): probability of applying MixStyle.
            alpha (float): concentration parameter of the Beta distribution.
            eps (float): small constant added to the variance for numerical stability.
            mix (str): mixing strategy, 'random' (shuffle the whole batch) or
                'crossdomain' (swap and shuffle batch halves).
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix='random'):
        self.mix = mix
    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)  # batch size

        # Instance-level statistics: per-channel over the spatial dims for 4D
        # CNN feature maps, per-sample over the feature dim for 2D embeddings
        # (e.g. DINO features of shape batch_size x feature_dim).
        if x.dim() == 4:
            mu = x.mean(dim=[2, 3], keepdim=True)
            var = x.var(dim=[2, 3], keepdim=True)
            lmda_shape = (B, 1, 1, 1)
        elif x.dim() == 2:
            mu = x.mean(dim=1, keepdim=True)
            var = x.var(dim=1, keepdim=True)
            lmda_shape = (B, 1)
        else:
            raise ValueError(f'Unexpected input shape {x.shape}')

        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()  # no gradients through the style statistics
        x_normed = (x - mu) / sig

        # Sample per-sample mixing weights; the shape must broadcast
        # against mu and sig, so it depends on the input dimensionality.
        lmda = self.beta.sample(lmda_shape).to(x.device)

        if self.mix == 'random':
            # Shuffle the whole batch.
            perm = torch.randperm(B)
        elif self.mix == 'crossdomain':
            # Split the batch into two halves (one per domain), swap the
            # halves, and shuffle within each half.
            perm = torch.arange(B - 1, -1, -1)  # reversed indices
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.size(0))]
            perm_a = perm_a[torch.randperm(perm_a.size(0))]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise NotImplementedError(f'Unknown mix method: {self.mix}')

        # Mix each sample's style statistics with its shuffled partner's:
        # output = sig_mix * (x - mu) / sig + mu_mix.
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
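

# --- Usage sketch ---
# A minimal example of applying MixStyle during training; the batch size,
# channel count, and embedding width below are illustrative assumptions,
# not values prescribed by the module.
if __name__ == '__main__':
    torch.manual_seed(0)

    mixstyle = MixStyle(p=1.0, alpha=0.1, mix='random')  # p=1.0: always mix
    mixstyle.train()  # MixStyle is a no-op in eval mode

    feats_4d = torch.randn(8, 64, 32, 32)  # assumed CNN feature maps (B, C, H, W)
    embeds_2d = torch.randn(8, 768)        # assumed DINO-style embeddings (B, D)

    out_4d = mixstyle(feats_4d)
    out_2d = mixstyle(embeds_2d)
    print(out_4d.shape, out_2d.shape)  # input shapes are preserved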