Extend the MeLU code so it can run different meta-learning algorithms and hyperparameter settings.
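For reference, the attached clustering.py reads all of its hyperparameters from options.config. A minimal placeholder dict with exactly those keys is sketched below; the values are illustrative (most are taken from the commented-out defaults inside the file, the embedding and FC sizes are assumptions, not repository defaults):

config = {
    'embedding_dim': 32,           # assumed; the code concatenates 8 embeddings per interaction
    'first_fc_hidden_dim': 64,     # assumed
    'second_fc_hidden_dim': 64,    # assumed
    'cluster_h1_dim': 128,
    'cluster_h2_dim': 64,
    'cluster_final_dim': 64,
    'cluster_dropout_rate': 0.0,
    'cluster_k': 7,
    'temperature': 1.0,
    'trainer_dropout_rate': 0.0,
}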

Attached file: clustering.py (5.5 KB)

import os
import gc
import pickle

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F

from options import config


class ClustringModule(torch.nn.Module):
    def __init__(self, config):
        super(ClustringModule, self).__init__()
        # self.h1_dim = 128
        self.h1_dim = config['cluster_h1_dim']
        # self.h2_dim = 64
        self.h2_dim = config['cluster_h2_dim']
        # self.final_dim = fc1_in_dim
        # self.final_dim = 64
        self.final_dim = config['cluster_final_dim']
        # self.dropout_rate = 0
        self.dropout_rate = config['cluster_dropout_rate']
        layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
                  torch.nn.Dropout(self.dropout_rate),
                  nn.ReLU(inplace=True),
                  # nn.BatchNorm1d(self.h1_dim),
                  nn.Linear(self.h1_dim, self.h2_dim),
                  torch.nn.Dropout(self.dropout_rate),
                  nn.ReLU(inplace=True),
                  # nn.BatchNorm1d(self.h2_dim),
                  nn.Linear(self.h2_dim, self.final_dim)]
        self.input_to_hidden = nn.Sequential(*layers)
        # self.clusters_k = 7
        self.clusters_k = config['cluster_k']
        self.embed_size = self.final_dim
        # learnable cluster centroids, shape (k, d)
        self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        # self.temperature = 1.0
        self.temperature = config['temperature']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        y = y.view(-1, 1)
        input_pairs = torch.cat((task_embed, y), dim=1)
        task_embed = self.input_to_hidden(input_pairs)
        # todo : may be useless
        mean_task = self.aggregate(task_embed)
        # C_distribution, new_task_embed = self.memoryunit(mean_task)
        # distance of the mean task embedding to each centroid, turned into a
        # Student-t-style soft assignment over the k clusters
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        # 1*k
        C = torch.transpose(res / res.sum(), 0, 1)
        # 1*k, k*d, 1*d
        value = torch.mm(C, self.array)
        # simple add operation
        new_task_embed = value + mean_task
        # calculate target distribution
        return C, new_task_embed


class Trainer(torch.nn.Module):
    def __init__(self, config, head=None):
        super(Trainer, self).__init__()
        fc1_in_dim = config['embedding_dim'] * 8
        fc2_in_dim = config['first_fc_hidden_dim']
        fc2_out_dim = config['second_fc_hidden_dim']
        self.fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
        self.fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
        self.linear_out = torch.nn.Linear(fc2_out_dim, 1)
        # cluster module
        self.cluster_module = ClustringModule(config)
        # self.task_dim = fc1_in_dim
        self.task_dim = config['cluster_final_dim']
        # transform task to weights (FiLM-style conditioning of the hidden layers)
        self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
        # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
        # self.dropout_rate = 0
        self.dropout_rate = config['trainer_dropout_rate']
        self.dropout = nn.Dropout(self.dropout_rate)

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
        if training:
            C, clustered_task_embed = self.cluster_module(task_embed, y)
            # hidden layers
            # todo : adding activation function or remove it
            hidden_1 = self.fc1(task_embed)
            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
            hidden_1 = self.dropout(hidden_1)
            hidden_2 = F.relu(hidden_1)
            hidden_2 = self.fc2(hidden_2)
            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
            hidden_2 = self.dropout(hidden_2)
            hidden_3 = F.relu(hidden_2)
            y_pred = self.linear_out(hidden_3)
        else:
            # at evaluation time the task representation comes from the support
            # (adaptation) set, not from the query points being predicted
            C, clustered_task_embed = self.cluster_module(adaptation_data, adaptation_labels)
            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
            hidden_1 = self.fc1(task_embed)
            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
            hidden_1 = self.dropout(hidden_1)
            hidden_2 = F.relu(hidden_1)
            hidden_2 = self.fc2(hidden_2)
            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
            hidden_2 = self.dropout(hidden_2)
            hidden_3 = F.relu(hidden_2)
            y_pred = self.linear_out(hidden_3)
        return y_pred
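As a starting point for "different meta algorithms", here is a minimal sketch of a Reptile-style meta-update driving the Trainer above. This is not the repository's training routine; Reptile is one alternative to the original update, and the names reptile_step, meta_lr, inner_lr, and inner_steps are invented for this sketch. It assumes the Trainer class from clustering.py and a config dict like the one sketched at the top of this message.

import copy
import torch
import torch.nn.functional as F

model = Trainer(config)                  # config: placeholder dict sketched above
meta_lr, inner_lr, inner_steps = 0.1, 1e-3, 5

def reptile_step(model, support_x, support_y, query_x, query_y):
    # adapt a throwaway copy of the model on the support set
    adapted = copy.deepcopy(model)
    opt = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
    for _ in range(inner_steps):
        pred = adapted(support_x, support_y, training=True)
        loss = F.mse_loss(pred.view(-1), support_y.view(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Reptile outer update: move the meta-parameters toward the adapted ones
    with torch.no_grad():
        for p, p_adapted in zip(model.parameters(), adapted.parameters()):
            p.add_(meta_lr * (p_adapted - p))
    # query loss of the adapted copy, for monitoring only; the eval branch of
    # Trainer.forward takes the support set as adaptation_data/adaptation_labels
    with torch.no_grad():
        q_pred = adapted(query_x, query_y, training=False,
                         adaptation_data=support_x, adaptation_labels=support_y)
        return F.mse_loss(q_pred.view(-1), query_y.view(-1)).item()

# toy task: the Trainer expects embedding_dim * 8 features per interaction
feat = config['embedding_dim'] * 8
support_x, support_y = torch.randn(16, feat), torch.rand(16)
query_x, query_y = torch.randn(8, feat), torch.rand(8)
print(reptile_step(model, support_x, support_y, query_x, query_y))

Swapping in a different meta-algorithm (e.g. first-order MAML) would only change the body of reptile_step; the Trainer and ClustringModule interfaces stay the same, and the hyperparameters to sweep (inner/outer learning rates, inner steps, cluster_k, temperature, dropout rates) all live in one place.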