Extend the MeLU code to support different meta-learning algorithms and hyperparameters

Head.py

import torch
import torch.nn.functional as F


class Head(torch.nn.Module):
    """Decision-making head of MeLU: a two-layer MLP whose hidden
    activations are modulated by per-task FiLM-style parameters
    (gamma, beta) before producing a scalar rating prediction."""

    def __init__(self, config):
        super(Head, self).__init__()
        self.embedding_dim = config['embedding_dim']
        # The head expects 8 concatenated embeddings of size embedding_dim.
        self.fc1_in_dim = config['embedding_dim'] * 8
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        self.use_cuda = True  # hardcoded; not used inside this module
        self.fc1 = torch.nn.Linear(self.fc1_in_dim, self.fc2_in_dim)
        self.fc2 = torch.nn.Linear(self.fc2_in_dim, self.fc2_out_dim)
        self.linear_out = torch.nn.Linear(self.fc2_out_dim, 1)  # scalar output
        self.dropout_rate = config['head_dropout']
        self.dropout = torch.nn.Dropout(self.dropout_rate)

    def forward(self, task_embed, gamma_1, beta_1, gamma_2, beta_2):
        # First layer with feature-wise linear modulation: gamma * h + beta.
        hidden_1 = self.fc1(task_embed)
        hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
        hidden_1 = self.dropout(hidden_1)
        hidden_2 = F.relu(hidden_1)
        # Second layer, modulated the same way.
        hidden_2 = self.fc2(hidden_2)
        hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
        hidden_2 = self.dropout(hidden_2)
        hidden_3 = F.relu(hidden_2)
        # Linear output layer producing the predicted rating.
        y_pred = self.linear_out(hidden_3)
        return y_pred
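
Below is a minimal usage sketch. The config keys are the ones Head reads; the concrete values (32, 64, 0.1) and the batch size are assumptions for illustration, not the values used in the actual MeLU experiments. Setting gamma to ones and beta to zeros makes the modulation an identity, which recovers a plain MLP head.

# Hypothetical config values for illustration only.
config = {
    'embedding_dim': 32,
    'first_fc_hidden_dim': 64,
    'second_fc_hidden_dim': 64,
    'head_dropout': 0.1,
}
head = Head(config)

batch_size = 4
task_embed = torch.randn(batch_size, config['embedding_dim'] * 8)

# Identity modulation: gamma = 1, beta = 0 leaves the activations unchanged.
gamma_1 = torch.ones(config['first_fc_hidden_dim'])
beta_1 = torch.zeros(config['first_fc_hidden_dim'])
gamma_2 = torch.ones(config['second_fc_hidden_dim'])
beta_2 = torch.zeros(config['second_fc_hidden_dim'])

y_pred = head(task_embed, gamma_1, beta_1, gamma_2, beta_2)
print(y_pred.shape)  # torch.Size([4, 1])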
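
To plug the head into a different meta algorithm, one option is a first-order inner loop (FOMAML/Reptile-style): clone the head per task and take a few SGD steps on that task's support set. The sketch below reuses head, config, and the gamma/beta tensors from the example above; adapt_head, support_x, support_y, inner_lr, and inner_steps are hypothetical names introduced here, and the support data is random for illustration. The inner learning rate and step count are exactly the kind of hyperparameters the title refers to. A full second-order MAML would instead need a functional forward pass over fast weights rather than copy.deepcopy.

import copy

def adapt_head(head, support_x, support_y, film_params,
               inner_lr=0.01, inner_steps=5):
    """Clone the head and take a few SGD steps on one task's support set."""
    fast_head = copy.deepcopy(head)
    optimizer = torch.optim.SGD(fast_head.parameters(), lr=inner_lr)
    for _ in range(inner_steps):
        optimizer.zero_grad()
        pred = fast_head(support_x, *film_params)
        loss = F.mse_loss(pred, support_y)
        loss.backward()
        optimizer.step()
    return fast_head

# Random task data, for illustration only.
support_x = torch.randn(8, config['embedding_dim'] * 8)
support_y = torch.randn(8, 1)
fast_head = adapt_head(head, support_x, support_y,
                       (gamma_1, beta_1, gamma_2, beta_2))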