import torch.nn.init as init
import os
import torch
import pickle
# from options import config
import gc
import torch.nn as nn
from torch.nn import functional as F
import numpy as np


class ClustringModule(torch.nn.Module):
    def __init__(self, config):
        super(ClustringModule, self).__init__()
        self.final_dim = config['task_dim']
        self.dropout_rate = config['rnn_dropout']
        self.embedding_dim = config['embedding_dim']
        self.kmeans_alpha = config['kmeans_alpha']
        # layers = [
        #     nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
        #     # nn.Linear(config['embedding_dim'] * 8, self.h1_dim),
        #     torch.nn.Dropout(self.dropout_rate),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(self.h1_dim, self.h2_dim),
        #     torch.nn.Dropout(self.dropout_rate),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(self.h2_dim, self.final_dim),
        # ]
        # layers_out = [
        #     nn.Linear(self.final_dim, self.h3_dim),
        #     torch.nn.Dropout(self.dropout_rate),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(self.h3_dim, self.h4_dim),
        #     torch.nn.Dropout(self.dropout_rate),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(self.h4_dim, config['embedding_dim'] * 8 + 1),
        #     # nn.Linear(self.h4_dim, config['embedding_dim'] * 8),
        #     # torch.nn.Dropout(self.dropout_rate),
        #     # nn.ReLU(inplace=True),
        # ]
        # self.input_to_hidden = nn.Sequential(*layers)
        # self.hidden_to_output = nn.Sequential(*layers_out)
        # self.recon_loss = nn.MSELoss()
        # self.hidden_dim = 64
        # self.l1_dim = 64
        self.hidden_dim = config['rnn_hidden']
        self.l1_dim = config['rnn_l1']
        # LSTM over the support interactions: each step is an item embedding
        # (4 * embedding_dim) concatenated with its (noisy) rating.
        self.rnn = nn.LSTM(4 * config['embedding_dim'] + 1, self.hidden_dim, batch_first=True)
        layers = [
            nn.Linear(config['embedding_dim'] * 4 + self.hidden_dim, self.l1_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.l1_dim, self.final_dim),
        ]
        self.input_to_hidden = nn.Sequential(*layers)
        self.clusters_k = config['cluster_k']
        self.embed_size = self.final_dim
        # Learnable cluster centers, one row per cluster.
        self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        # self.array = nn.Parameter(init.zeros_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        self.temperature = config['temperature']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        y = y.view(-1, 1)
        idx = 4 * self.embedding_dim
        # Item part of the task: item embeddings concatenated with ratings,
        # fed through the LSTM as one sequence (batch size 1).
        items = torch.cat((task_embed[:, 0:idx], y), dim=1).unsqueeze(0)
        output, (hn, cn) = self.rnn(items)
        items_embed = output.squeeze(0)[-1]
        user_embed = task_embed[0, idx:]
        task_embed = self.input_to_hidden(torch.cat((items_embed, user_embed), dim=0))

        mean_task = task_embed
        # Soft assignment of the task embedding to the cluster centers
        # (Student's t-style kernel over the Euclidean distances).
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        C = torch.transpose(res / res.sum(), 0, 1)
        value = torch.mm(C, self.array)
        # new_task_embed = value + mean_task
        # new_task_embed = value
        new_task_embed = mean_task

        # compute clustering loss
        list_dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)
        stack_dist = list_dist
        min_dist = min(list_dist)
        alpha = self.kmeans_alpha  # Placeholder tensor for alpha
        # alpha = alphas[iteration]
        list_exp = []
        for i in range(self.clusters_k):
            exp = torch.exp(-alpha * (stack_dist[i] - min_dist))
            list_exp.append(exp)
        stack_exp = torch.stack(list_exp)
        sum_exponentials = torch.sum(stack_exp)
        list_softmax = []
        list_weighted_dist = []
        for j in range(self.clusters_k):
            softmax = stack_exp[j] / sum_exponentials
            weighted_dist = stack_dist[j] * softmax
            list_softmax.append(softmax)
            list_weighted_dist.append(weighted_dist)
        stack_weighted_dist = torch.stack(list_weighted_dist)
        kmeans_loss = torch.sum(stack_weighted_dist, dim=0)
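        # Note (added for clarity): the loss above is a soft-min over the squared
        # cluster distances d_j,
        #     kmeans_loss = sum_j softmax_j(-alpha * (d_j - d_min)) * d_j,
        # which approaches the hard k-means objective min_j d_j as alpha grows.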
        # rec_loss = self.recon_loss(input_pairs, output)
        # return C, new_task_embed, kmeans_loss, rec_loss
        return C, new_task_embed, kmeans_loss


class Trainer(torch.nn.Module):
    def __init__(self, config, head=None):
        super(Trainer, self).__init__()
        # cluster module
        self.cluster_module = ClustringModule(config)
        # self.task_dim = 64
        self.task_dim = config['task_dim']
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        # FiLM layers: map the clustered task embedding to per-layer
        # scale (gamma) and shift (beta) modulation parameters.
        self.film_layer_1_beta = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
        self.dropout_rate = config['trainer_dropout']
        self.dropout = nn.Dropout(self.dropout_rate)
        self.label_noise_std = config['label_noise_std']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True, adaptation_data=None, adaptation_labels=None):
        # if training:
        # Add Gaussian label noise before clustering (assumes CUDA tensors).
        t = torch.Tensor(np.random.normal(0, self.label_noise_std, size=len(y))).cuda()
        noise_y = t + y
        # C, clustered_task_embed, k_loss, rec_loss = self.cluster_module(task_embed, noise_y)
        C, clustered_task_embed, k_loss = self.cluster_module(task_embed, noise_y)
        beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
        gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
        beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
        gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
        # return gamma_1, beta_1, gamma_2, beta_2, C, k_loss, rec_loss
        return gamma_1, beta_1, gamma_2, beta_2, C, k_loss
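

# ---------------------------------------------------------------------------
# Minimal usage sketch (added; not part of the original module). The config
# values and tensor sizes below are illustrative assumptions; the real values
# come from the project's options/config. Runs the clustering module on CPU
# with random inputs; Trainer.forward as written assumes CUDA tensors.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    _config = {
        'task_dim': 64, 'rnn_dropout': 0.2, 'embedding_dim': 32,
        'kmeans_alpha': 100, 'rnn_hidden': 64, 'rnn_l1': 64,
        'cluster_k': 7, 'temperature': 1.0,
        'first_fc_hidden_dim': 64, 'second_fc_hidden_dim': 64,
        'trainer_dropout': 0.2, 'label_noise_std': 0.1,
    }
    n_items = 10  # assumed support-set size for one task
    # task_embed rows: item part (4 * embedding_dim) followed by user part (4 * embedding_dim)
    task_embed = torch.randn(n_items, 8 * _config['embedding_dim'])
    y = torch.randn(n_items)

    cluster = ClustringModule(_config)
    C, new_task_embed, k_loss = cluster(task_embed, y)
    print(C.shape, new_task_embed.shape, k_loss.item())

    trainer = Trainer(_config)
    # On a CUDA machine:
    #   trainer.cuda()
    #   gamma_1, beta_1, gamma_2, beta_2, C, k_loss = trainer(task_embed.cuda(), y.cuda())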