import torch
import torch.nn as nn


class GumbelSwitch(nn.Module):
    """Learnable binary switches: stochastic Gumbel-Softmax samples during
    training, deterministic argmax decisions at evaluation time."""

    def __init__(self, switch_count):
        super().__init__()
        # One ("on", "off") logit pair per switch, initialized uniformly.
        self.switch_weight = nn.Parameter(torch.ones((switch_count, 2)))

    def forward(self):
        if self.training:
            # Hard (one-hot) samples with straight-through gradients, so
            # switch_weight stays trainable despite the discrete output.
            return_value = nn.functional.gumbel_softmax(
                self.switch_weight, hard=True, dim=-1
            )
        else:
            # Deterministic at eval time: pick the larger logit per switch.
            argmax = torch.argmax(self.switch_weight, dim=-1)
            return_value = nn.functional.one_hot(argmax, num_classes=2).float()
        # Column 0 holds the "on" indicator for each switch.
        return return_value[:, 0]
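
# A minimal usage sketch (assumed, not from the original source): gate a batch
# of features with the switches. The feature shape and toy loss below are
# illustrative assumptions, chosen to show that gradients reach switch_weight
# through the hard samples via the straight-through estimator.
if __name__ == "__main__":
    torch.manual_seed(0)
    switch = GumbelSwitch(switch_count=4)
    features = torch.randn(8, 4)  # assumed shape: batch of 8, one feature per switch

    switch.train()
    gates = switch()                  # (4,) stochastic hard {0, 1} samples
    loss = (features * gates).sum()   # toy objective over the gated features
    loss.backward()                   # gradients flow into switch_weight
    print("grad:", switch.switch_weight.grad)

    switch.eval()
    with torch.no_grad():
        print("eval gates:", switch())  # deterministic: 1 where the "on" logit wins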