Extend the MeLU code to support different meta-learning algorithms and hyperparameters.

clustering.py (6.3 KB)

import os
import gc
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
# from options import config


class ClustringModule(torch.nn.Module):
    def __init__(self, config):
        super(ClustringModule, self).__init__()
        self.final_dim = config['task_dim']
        self.dropout_rate = config['rnn_dropout']
        self.embedding_dim = config['embedding_dim']
        self.kmeans_alpha = config['kmeans_alpha']
        # (An earlier, commented-out variant used an MLP encoder/decoder pair
        # (input_to_hidden / hidden_to_output) over the full interaction
        # vector plus an MSE reconstruction loss; this version encodes the
        # support set with an LSTM instead.)
        self.hidden_dim = config['rnn_hidden']
        self.l1_dim = config['rnn_l1']
        # LSTM over the support-set interactions: each step is an item
        # embedding (4 * embedding_dim) concatenated with its rating.
        self.rnn = nn.LSTM(4 * config['embedding_dim'] + 1, self.hidden_dim, batch_first=True)
        layers = [
            nn.Linear(config['embedding_dim'] * 4 + self.hidden_dim, self.l1_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.l1_dim, self.final_dim),
        ]
        self.input_to_hidden = nn.Sequential(*layers)
        self.clusters_k = config['cluster_k']
        self.embed_size = self.final_dim
        # learnable cluster centroids, shape (clusters_k, task_dim)
        self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        # self.array = nn.Parameter(init.zeros_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        self.temperature = config['temperature']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        y = y.view(-1, 1)
        idx = 4 * self.embedding_dim
        # item part of each interaction plus its rating, fed as one sequence
        items = torch.cat((task_embed[:, 0:idx], y), dim=1).unsqueeze(0)
        output, (hn, cn) = self.rnn(items)
        # last LSTM output summarises the support set
        items_embed = output.squeeze()[-1]
        # the user part is identical across rows, so take it from the first row
        user_embed = task_embed[0, idx:]
        task_embed = self.input_to_hidden(torch.cat((items_embed, user_embed), dim=0))
        mean_task = task_embed

        # soft assignment of the task embedding to the centroids
        # (Student's t-style kernel over the L2 distances)
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        C = torch.transpose(res / res.sum(), 0, 1)
        value = torch.mm(C, self.array)
        # new_task_embed = value + mean_task
        # new_task_embed = value
        new_task_embed = mean_task

        # soft k-means clustering loss: squared distances to the centroids,
        # weighted by a softmax over those distances (shifted by the minimum
        # distance for numerical stability)
        stack_dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)
        min_dist = torch.min(stack_dist)
        alpha = self.kmeans_alpha  # fixed sharpness; could be scheduled per iteration (alphas[iteration])
        stack_exp = torch.exp(-alpha * (stack_dist - min_dist))
        softmax = stack_exp / torch.sum(stack_exp)
        kmeans_loss = torch.sum(stack_dist * softmax, dim=0)
        # rec_loss = self.recon_loss(input_pairs, output)
        # return C, new_task_embed, kmeans_loss, rec_loss
        return C, new_task_embed, kmeans_loss
class Trainer(torch.nn.Module):
    def __init__(self, config, head=None):
        super(Trainer, self).__init__()
        # clustering module that produces the task embedding and k-means loss
        self.cluster_module = ClustringModule(config)
        self.task_dim = config['task_dim']
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        # FiLM generators: map the clustered task embedding to per-layer
        # scale (gamma) and shift (beta) parameters for the recommender
        self.film_layer_1_beta = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
        self.dropout_rate = config['trainer_dropout']
        self.dropout = nn.Dropout(self.dropout_rate)
        self.label_noise_std = config['label_noise_std']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True, adaptation_data=None, adaptation_labels=None):
        # additive Gaussian noise on the support labels, created on the same
        # device as y (acts as a regulariser during adaptation)
        t = torch.randn(y.shape, device=y.device) * self.label_noise_std
        noise_y = y + t
        # C, clustered_task_embed, k_loss, rec_loss = self.cluster_module(task_embed, noise_y)
        C, clustered_task_embed, k_loss = self.cluster_module(task_embed, noise_y)
        beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
        gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
        beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
        gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
        # return gamma_1, beta_1, gamma_2, beta_2, C, k_loss, rec_loss
        return gamma_1, beta_1, gamma_2, beta_2, C, k_loss
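
For the request above (trying different hyperparameters or plugging these modules into another meta-learning loop), here is a minimal sketch, not part of the original repository: the config keys are the ones ClustringModule and Trainer actually read, but every value, the tensor shapes, and the variable names (n_support, task_embed, y) are illustrative assumptions made up for the example.

# Minimal usage sketch (assumed config values): one forward pass on a toy
# support set for a single user, on CPU or GPU.
import torch

config = {                       # all values below are illustrative assumptions
    'embedding_dim': 32,         # per-field embedding size
    'task_dim': 64,              # dimensionality of the task embedding
    'rnn_hidden': 64,            # LSTM hidden size
    'rnn_l1': 64,                # width of the encoder's first linear layer
    'rnn_dropout': 0.2,
    'cluster_k': 5,              # number of learnable task clusters
    'kmeans_alpha': 100.0,       # sharpness of the soft k-means weighting
    'temperature': 1.0,          # Student's-t kernel temperature
    'first_fc_hidden_dim': 64,   # width of the layer modulated by gamma_1/beta_1
    'second_fc_hidden_dim': 64,  # width of the layer modulated by gamma_2/beta_2
    'trainer_dropout': 0.2,
    'label_noise_std': 0.1,
}

trainer = Trainer(config)

# Toy support set: 10 interactions, each row laid out as
# [item embedding (4 * embedding_dim) | user embedding (4 * embedding_dim)],
# matching the slicing done inside ClustringModule.forward.
n_support = 10
task_embed = torch.randn(n_support, 8 * config['embedding_dim'])
y = torch.randn(n_support)  # support ratings

gamma_1, beta_1, gamma_2, beta_2, C, k_loss = trainer(task_embed, y)
print(C.shape, k_loss.item())  # soft cluster assignment and k-means loss

The FiLM outputs would then modulate the base recommender's hidden layers (for example h = gamma_1 * h + beta_1), and k_loss would be added to the prediction loss in the outer meta-update; swapping the meta-algorithm or sweeping cluster_k, temperature, and kmeans_alpha happens outside this file.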