Extend the MeLU code to support different meta-learning algorithms and hyperparameters
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

clustering.py 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import torch.nn.init as init
  2. import os
  3. import torch
  4. import pickle
  5. from options import config
  6. import gc
  7. import torch.nn as nn
  8. from torch.nn import functional as F
  9. class ClustringModule(torch.nn.Module):
  10. def __init__(self, config_param):
  11. super(ClustringModule, self).__init__()
  12. self.h1_dim = config_param['cluster_h1_dim']
  13. self.h2_dim = config_param['cluster_h2_dim']
  14. self.final_dim = config_param['cluster_final_dim']
  15. self.dropout_rate = config_param['cluster_dropout_rate']
  16. layers = [
  17. # nn.Linear(config_param['embedding_dim'] * 8 + 1, self.h1_dim),
  18. nn.Linear(config_param['embedding_dim'] * 8, self.h1_dim),
  19. torch.nn.Dropout(self.dropout_rate),
  20. nn.ReLU(inplace=True),
  21. # nn.BatchNorm1d(self.h1_dim),
  22. nn.Linear(self.h1_dim, self.h2_dim),
  23. torch.nn.Dropout(self.dropout_rate),
  24. nn.ReLU(inplace=True),
  25. # nn.BatchNorm1d(self.h2_dim),
  26. nn.Linear(self.h2_dim, self.final_dim)]
  27. self.input_to_hidden = nn.Sequential(*layers)
  28. self.clusters_k = config_param['cluster_k']
  29. self.embed_size = self.final_dim
  30. self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
  31. self.temperature = config_param['temperature']
  32. def aggregate(self, z_i):
  33. return torch.mean(z_i, dim=0)
  34. def forward(self, task_embed, y, training=True):
  35. y = y.view(-1, 1)
  36. high_idx = y > 3
  37. high_idx = high_idx.squeeze()
  38. if high_idx.sum() > 0:
  39. input_pairs = task_embed.detach()[high_idx]
  40. else:
  41. input_pairs = torch.ones(size=(1, 8 * config['embedding_dim'])).cuda()
  42. print("found")
  43. # input_pairs = torch.cat((task_embed, y), dim=1)
  44. task_embed = self.input_to_hidden(input_pairs)
  45. # todo : may be useless
  46. mean_task = self.aggregate(task_embed)
  47. res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
  48. res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
  49. # 1*k
  50. C = torch.transpose(res / res.sum(), 0, 1)
  51. # 1*k, k*d, 1*d
  52. value = torch.mm(C, self.array)
  53. # simple add operation
  54. # new_task_embed = value + mean_task
  55. # new_task_embed = value
  56. new_task_embed = mean_task
  57. # print("injam1:", new_task_embed)
  58. # print("injam2:", self.array)
  59. list_dist = []
  60. # list_dist = torch.norm(new_task_embed - self.array, p=2, dim=1,keepdim=True)
  61. list_dist = torch.sum(torch.pow(new_task_embed - self.array,2),dim=1)
  62. stack_dist = list_dist
  63. # print("injam3:", stack_dist)
  64. ## Second, find the minimum squared distance for softmax normalization
  65. min_dist = min(list_dist)
  66. # print("injam4:", min_dist)
  67. ## Third, compute exponentials shifted with min_dist to avoid underflow (0/0) issues in softmaxes
  68. alpha = config['kmeans_alpha'] # Placeholder tensor for alpha
  69. list_exp = []
  70. for i in range(self.clusters_k):
  71. exp = torch.exp(-alpha * (stack_dist[i] - min_dist))
  72. list_exp.append(exp)
  73. stack_exp = torch.stack(list_exp)
  74. sum_exponentials = torch.sum(stack_exp)
  75. # print("injam5:", stack_exp, sum_exponentials)
  76. ## Fourth, compute softmaxes and the embedding/representative distances weighted by softmax
  77. list_softmax = []
  78. list_weighted_dist = []
  79. for j in range(self.clusters_k):
  80. softmax = stack_exp[j] / sum_exponentials
  81. weighted_dist = stack_dist[j] * softmax
  82. list_softmax.append(softmax)
  83. list_weighted_dist.append(weighted_dist)
  84. stack_weighted_dist = torch.stack(list_weighted_dist)
  85. kmeans_loss = torch.sum(stack_weighted_dist, dim=0)
  86. return C, new_task_embed, kmeans_loss
  87. class Trainer(torch.nn.Module):
  88. def __init__(self, config_param, head=None):
  89. super(Trainer, self).__init__()
  90. fc1_in_dim = config_param['embedding_dim'] * 8
  91. fc2_in_dim = config_param['first_fc_hidden_dim']
  92. fc2_out_dim = config_param['second_fc_hidden_dim']
  93. self.fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
  94. self.fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
  95. self.linear_out = torch.nn.Linear(fc2_out_dim, 1)
  96. # cluster module
  97. self.cluster_module = ClustringModule(config_param)
  98. # self.task_dim = fc1_in_dim
  99. self.task_dim = config_param['cluster_final_dim']
  100. # transform task to weights
  101. self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
  102. self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
  103. self.film_layer_2_beta = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
  104. self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
  105. # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  106. # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  107. # self.dropout_rate = 0
  108. self.dropout_rate = config_param['trainer_dropout_rate']
  109. self.dropout = nn.Dropout(self.dropout_rate)
  110. def aggregate(self, z_i):
  111. return torch.mean(z_i, dim=0)
  112. def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
  113. if training:
  114. C, clustered_task_embed, k_loss = self.cluster_module(task_embed, y)
  115. # hidden layers
  116. # todo : adding activation function or remove it
  117. hidden_1 = self.fc1(task_embed)
  118. beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
  119. gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
  120. hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
  121. hidden_1 = self.dropout(hidden_1)
  122. hidden_2 = F.relu(hidden_1)
  123. hidden_2 = self.fc2(hidden_2)
  124. beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
  125. gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
  126. hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
  127. hidden_2 = self.dropout(hidden_2)
  128. hidden_3 = F.relu(hidden_2)
  129. y_pred = self.linear_out(hidden_3)
  130. else:
  131. C, clustered_task_embed, k_loss = self.cluster_module(adaptation_data, adaptation_labels)
  132. beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
  133. gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
  134. beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
  135. gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
  136. hidden_1 = self.fc1(task_embed)
  137. hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
  138. hidden_1 = self.dropout(hidden_1)
  139. hidden_2 = F.relu(hidden_1)
  140. hidden_2 = self.fc2(hidden_2)
  141. hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
  142. hidden_2 = self.dropout(hidden_2)
  143. hidden_3 = F.relu(hidden_2)
  144. y_pred = self.linear_out(hidden_3)
  145. return y_pred, C, k_loss