```python
import torch
import torch.nn as nn


class GumbelSwitch(nn.Module):
    """A bank of learnable binary switches trained with the Gumbel-softmax trick."""

    def __init__(self, switch_count):
        super().__init__()
        # One pair of logits ("on" / "off") per switch.
        self.switch_weight = nn.Parameter(torch.ones((switch_count, 2)))

    def forward(self):
        if self.training:
            # Sample hard one-hot decisions; hard=True applies the
            # straight-through estimator, so gradients still reach the logits.
            return_value = nn.functional.gumbel_softmax(self.switch_weight, hard=True, dim=-1)
        else:
            # At inference, pick the most likely state deterministically.
            argmax = torch.argmax(self.switch_weight, dim=-1)
            return_value = nn.functional.one_hot(argmax, num_classes=2).float()
        # Column 0 is the "on" state: return one 0/1 gate per switch.
        return return_value[:, 0]
```
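For context, a minimal usage sketch: the forward pass returns one 0/1 gate per switch, which can be multiplied into activations to switch them on or off. The names `switch`, `features`, and `gated` below are illustrative, not part of the original code.

```python
import torch

# Four independent binary switches.
switch = GumbelSwitch(switch_count=4)

# Training: stochastic hard samples, differentiable via straight-through.
switch.train()
mask = switch()                # shape (4,), entries in {0.0, 1.0}
features = torch.randn(8, 4)
gated = features * mask        # zero out channels whose switch sampled "off"

# Inference: deterministic argmax decision, same shape.
switch.eval()
mask = switch()
```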