import gc
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F

from options import config


class ClustringModule(torch.nn.Module):
    """Encode a task and relate it to a set of learnable cluster centroids.

    The positively rated rows of a task (label ``> 3``) are passed through a
    small MLP and averaged into one task vector. That vector is compared with
    ``clusters_k`` learnable centroids (``self.array``) to produce a
    normalised assignment ``C`` and a softmax-weighted squared-distance loss.

    Keys read from ``config_param``: ``cluster_h1_dim``, ``cluster_h2_dim``,
    ``cluster_final_dim``, ``cluster_dropout_rate``, ``embedding_dim``,
    ``cluster_k`` and ``temperature``.
    """

    def __init__(self, config_param):
        super(ClustringModule, self).__init__()
        self.h1_dim = config_param['cluster_h1_dim']
        self.h2_dim = config_param['cluster_h2_dim']
        self.final_dim = config_param['cluster_final_dim']
        self.dropout_rate = config_param['cluster_dropout_rate']

        # Three-layer MLP: (embedding_dim * 8) -> h1 -> h2 -> final_dim.
        # NOTE: an earlier variant appended the label to the input
        # (in_features + 1) and used BatchNorm1d after each ReLU; both are
        # currently disabled.
        layers = [
            nn.Linear(config_param['embedding_dim'] * 8, self.h1_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.h1_dim, self.h2_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.h2_dim, self.final_dim),
        ]
        self.input_to_hidden = nn.Sequential(*layers)

        self.clusters_k = config_param['cluster_k']
        self.embed_size = self.final_dim
        # Learnable centroids, one row per cluster: (clusters_k, embed_size).
        self.array = nn.Parameter(
            init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        self.temperature = config_param['temperature']

    def aggregate(self, z_i):
        """Return the mean of ``z_i`` over its first dimension."""
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        """Compute the cluster assignment and k-means loss for one task.

        Args:
            task_embed: per-sample input features; indexed along its first
                dimension by the label mask (assumed 2-D with
                ``embedding_dim * 8`` columns -- TODO confirm with caller).
            y: labels, one per row of ``task_embed``; flattened internally.
            training: accepted for interface compatibility; not used here.

        Returns:
            A tuple ``(C, new_task_embed, kmeans_loss)`` where ``C`` is the
            ``(1, clusters_k)`` assignment whose entries sum to 1,
            ``new_task_embed`` is the averaged task vector and
            ``kmeans_loss`` is a scalar tensor.
        """
        # Flatten instead of view(-1, 1) + squeeze(): squeeze() turns a
        # single-sample mask into a 0-dim tensor, and indexing with a 0-dim
        # boolean adds a spurious leading dimension to the selection.
        high_idx = y.view(-1) > 3

        if high_idx.any():
            # Gradients are not propagated into the incoming embeddings.
            input_pairs = task_embed.detach()[high_idx]
        else:
            # No positively rated sample: fall back to a single all-ones row.
            # Size, dtype and device are taken from the layer that consumes
            # it, so this also works on CPU-only hosts.
            first_layer = self.input_to_hidden[0]
            input_pairs = torch.ones(
                size=(1, first_layer.in_features),
                dtype=first_layer.weight.dtype,
                device=first_layer.weight.device,
            )
            print("found")

        # todo : may be useless
        task_embed = self.input_to_hidden(input_pairs)
        mean_task = self.aggregate(task_embed)

        # Kernel on the Euclidean distance to every centroid:
        # (1 + d / T) ** (-(T + 1) / 2), then normalised to sum to 1.
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)  # (k, 1)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        C = torch.transpose(res / res.sum(), 0, 1)  # (1, k)

        # The task vector is returned as-is; mixing in the C-weighted centroid
        # average (value = C @ array) is currently disabled.
        new_task_embed = mean_task

        # Squared distance from the task vector to each centroid: (k,).
        sq_dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)

        # Softmax over -alpha * sq_dist, shifted by the minimum distance to
        # avoid underflow (0/0) in the normalisation.
        # NOTE: alpha is read from the module-level ``config`` on every call,
        # not from the ``config_param`` passed to __init__.
        alpha = config['kmeans_alpha']
        shifted_exp = torch.exp(-alpha * (sq_dist - sq_dist.min()))
        softmax = shifted_exp / shifted_exp.sum()

        # Distances weighted by their softmax responsibility, summed.
        kmeans_loss = torch.sum(sq_dist * softmax)
        return C, new_task_embed, kmeans_loss


class Trainer(torch.nn.Module):
    """Rating predictor whose hidden layers are modulated by a task vector.

    A two-hidden-layer MLP (``fc1`` -> ``fc2`` -> ``linear_out``) is scaled
    and shifted per layer by ``tanh`` projections (gamma / beta) of the task
    vector produced by :class:`ClustringModule`.

    Keys read from ``config_param`` here: ``embedding_dim``,
    ``first_fc_hidden_dim``, ``second_fc_hidden_dim``, ``cluster_final_dim``
    and ``trainer_dropout_rate`` (plus those read by ``ClustringModule``).
    ``head`` is accepted but not used.
    """

    def __init__(self, config_param, head=None):
        super(Trainer, self).__init__()
        fc1_in_dim = config_param['embedding_dim'] * 8
        fc2_in_dim = config_param['first_fc_hidden_dim']
        fc2_out_dim = config_param['second_fc_hidden_dim']
        self.fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
        self.fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
        self.linear_out = torch.nn.Linear(fc2_out_dim, 1)

        # cluster module
        self.cluster_module = ClustringModule(config_param)
        self.task_dim = config_param['cluster_final_dim']

        # transform task to weights
        self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)

        self.dropout_rate = config_param['trainer_dropout_rate']
        self.dropout = nn.Dropout(self.dropout_rate)

    def aggregate(self, z_i):
        """Return the mean of ``z_i`` over its first dimension."""
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
        """Predict ratings for ``task_embed`` conditioned on a task vector.

        Args:
            task_embed: input features to score.
            y: labels for ``task_embed``; used to build the task vector when
                ``training`` is truthy.
            training: when truthy the task vector is derived from
                ``(task_embed, y)``; otherwise from
                ``(adaptation_data, adaptation_labels)``, which must then be
                provided.
            adaptation_data: features used for the task vector when
                ``training`` is falsy.
            adaptation_labels: labels matching ``adaptation_data``.

        Returns:
            A tuple ``(y_pred, C, k_loss)``: the predictions from
            ``linear_out`` plus the cluster assignment and k-means loss
            returned by the cluster module.
        """
        # The two modes differ only in which data describes the task.
        if training:
            cluster_input, cluster_labels = task_embed, y
        else:
            cluster_input, cluster_labels = adaptation_data, adaptation_labels
        C, clustered_task_embed, k_loss = self.cluster_module(cluster_input, cluster_labels)

        # Per-layer shift (beta) and scale (gamma), each squashed by tanh.
        beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
        gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
        beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
        gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))

        # hidden layers
        # todo : adding activation function or remove it
        hidden_1 = self.fc1(task_embed)
        hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
        hidden_1 = self.dropout(hidden_1)

        hidden_2 = self.fc2(F.relu(hidden_1))
        hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
        hidden_2 = self.dropout(hidden_2)

        y_pred = self.linear_out(F.relu(hidden_2))
        return y_pred, C, k_loss