import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Mixes instance-level feature statistics (mean and standard deviation)
    across samples in a batch to synthesize novel styles.

    Reference:
        Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
            p (float): probability of applying MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix; 'random' or 'crossdomain'.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix='random'):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)  # batch size

        # Compute per-sample statistics for 4D CNN feature maps or
        # 2D embeddings (batch_size x feature_dim, e.g. from DINO).
        if x.dim() == 4:  # (B, C, H, W): stats over the spatial dims
            mu = x.mean(dim=[2, 3], keepdim=True)
            var = x.var(dim=[2, 3], keepdim=True)
        elif x.dim() == 2:  # (B, D): stats over the feature dim
            mu = x.mean(dim=1, keepdim=True)
            var = x.var(dim=1, keepdim=True)
        else:
            raise ValueError(f'Unexpected input shape {x.shape}')

        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()  # stop gradients through the statistics
        x_normed = (x - mu) / sig

        # Sample mixing weights with one singleton dim per non-batch dim,
        # so lmda broadcasts against mu/sig for both 2D and 4D inputs.
        lmda = self.beta.sample((B,) + (1,) * (x.dim() - 1)).to(x.device)

        if self.mix == 'random':
            # Random shuffle across the batch.
            perm = torch.randperm(B)
        elif self.mix == 'crossdomain':
            # Split into two halves and swap their order,
            # shuffling within each half (assumes even B).
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(B // 2)]
            perm_a = perm_a[torch.randperm(B // 2)]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise NotImplementedError(f'Unknown mix method: {self.mix}')

        # Mix each sample's statistics with those of its permuted partner.
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix
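

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of exercising MixStyle on both input shapes the forward
# pass supports. The batch size, channel/embedding dims, and the seed below
# are assumptions for demonstration only; in practice MixStyle is inserted
# between early stages of a backbone and is active only during training.
if __name__ == '__main__':
    torch.manual_seed(0)  # assumed seed, for reproducible demo output

    mixstyle = MixStyle(p=0.5, alpha=0.1, mix='random')
    mixstyle.train()

    # 4D case: CNN feature maps of shape (B, C, H, W).
    feats_4d = torch.randn(8, 64, 32, 32)
    print(mixstyle(feats_4d).shape)  # torch.Size([8, 64, 32, 32])

    # 2D case: embeddings of shape (B, D), e.g. from a DINO encoder.
    feats_2d = torch.randn(8, 384)
    print(mixstyle(feats_2d).shape)  # torch.Size([8, 384])

    # In eval mode (or after set_activation_status(False)),
    # MixStyle is an identity mapping.
    mixstyle.eval()
    assert torch.equal(mixstyle(feats_4d), feats_4d)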