
gumbal_switch.py 553B

import torch
import torch.nn as nn


class GumbalSwitch(nn.Module):
    def __init__(self, switch_count):
        super().__init__()
        # One pair of logits per switch: index 0 = "on", index 1 = "off".
        self.switch_weight = nn.parameter.Parameter(torch.ones((switch_count, 2)))

    def forward(self):
        if self.training:
            # Straight-through Gumbel-softmax: hard one-hot samples in the forward
            # pass, with differentiable soft gradients in the backward pass.
            return_value = nn.functional.gumbel_softmax(self.switch_weight, hard=True, dim=-1)
        else:
            # At evaluation time, pick the larger logit deterministically.
            argmax = torch.argmax(self.switch_weight, dim=-1)
            return_value = nn.functional.one_hot(argmax, num_classes=2).float()
        # Return only the "on" column: one value per switch, 1.0 if on, 0.0 if off.
        return return_value[:, 0]
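A minimal usage sketch (not part of the repository file): the switch_count of 4 and the variable names below are arbitrary choices for illustration, assuming the module is used as a learnable gate vector inside a larger model.

# Hypothetical usage example for GumbalSwitch (illustration only).
switch = GumbalSwitch(switch_count=4)

switch.train()
gates = switch()   # stochastic straight-through samples, shape (4,), values in {0., 1.}

switch.eval()
gates = switch()   # deterministic argmax gates, shape (4,), values in {0., 1.}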