- import torch.nn.init as init
- import os
- import torch
- import pickle
- # from options import config
- import gc
- import torch.nn as nn
- from torch.nn import functional as F
- import numpy as np
-
-
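- # Task-clustering module: an LSTM + MLP encode a task's support interactions into an
- # embedding, which is then softly assigned to a set of learned cluster centroids.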
- class ClustringModule(torch.nn.Module):
-     def __init__(self, config):
-         super(ClustringModule, self).__init__()
- 
-         self.final_dim = config['task_dim']
-         self.dropout_rate = config['rnn_dropout']
-         self.embedding_dim = config['embedding_dim']
-         self.kmeans_alpha = config['kmeans_alpha']
-
-         # layers = [
-         #     nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
-         #     # nn.Linear(config['embedding_dim'] * 8, self.h1_dim),
-         #     torch.nn.Dropout(self.dropout_rate),
-         #     nn.ReLU(inplace=True),
- 
-         #     nn.Linear(self.h1_dim, self.h2_dim),
-         #     torch.nn.Dropout(self.dropout_rate),
-         #     nn.ReLU(inplace=True),
- 
-         #     nn.Linear(self.h2_dim, self.final_dim),
-         # ]
- 
-         # layers_out = [
-         #     nn.Linear(self.final_dim, self.h3_dim),
-         #     torch.nn.Dropout(self.dropout_rate),
-         #     nn.ReLU(inplace=True),
- 
-         #     nn.Linear(self.h3_dim, self.h4_dim),
-         #     torch.nn.Dropout(self.dropout_rate),
-         #     nn.ReLU(inplace=True),
- 
-         #     nn.Linear(self.h4_dim, config['embedding_dim'] * 8 + 1),
-         #     # nn.Linear(self.h4_dim, config['embedding_dim'] * 8),
-         #     # torch.nn.Dropout(self.dropout_rate),
-         #     # nn.ReLU(inplace=True),
-         # ]
- 
-         # self.input_to_hidden = nn.Sequential(*layers)
-         # self.hidden_to_output = nn.Sequential(*layers_out)
-         # self.recon_loss = nn.MSELoss()
- 
-         # self.hidden_dim = 64
-         # self.l1_dim = 64
-         self.hidden_dim = config['rnn_hidden']
-         self.l1_dim = config['rnn_l1']
-
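-         # LSTM over the task's (item features, label) sequence; its last hidden state,
-         # concatenated with the user features, is mapped to a task embedding of size task_dim.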
-         self.rnn = nn.LSTM(4 * config['embedding_dim'] + 1, self.hidden_dim, batch_first=True)
-         layers = [
-             nn.Linear(config['embedding_dim'] * 4 + self.hidden_dim, self.l1_dim),
-             torch.nn.Dropout(self.dropout_rate),
-             nn.ReLU(inplace=True),
- 
-             nn.Linear(self.l1_dim, self.final_dim),
-         ]
-         self.input_to_hidden = nn.Sequential(*layers)
-
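-         # learnable cluster centroids, one row per cluster (cluster_k x task_dim), Xavier-initialised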
-         self.clusters_k = config['cluster_k']
-         self.embed_size = self.final_dim
-         self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
-         # self.array = nn.Parameter(init.zeros_(torch.FloatTensor(self.clusters_k, self.embed_size)))
-         self.temperature = config['temperature']
-
-     def aggregate(self, z_i):
-         return torch.mean(z_i, dim=0)
-
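-     # task_embed rows hold 4 * embedding_dim item features followed by 4 * embedding_dim user
-     # features; y holds the labels. Returns the soft cluster assignment C, the task embedding,
-     # and a soft k-means clustering loss.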
-     def forward(self, task_embed, y, training=True):
-         y = y.view(-1, 1)
- 
-         idx = 4 * self.embedding_dim
-         items = torch.cat((task_embed[:, 0:idx], y), dim=1).unsqueeze(0)
- 
-         output, (hn, cn) = self.rnn(items)
-
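-         # keep only the final time step as the summary of the interaction sequence;
-         # the user features are shared across rows, so they are taken from the first one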
-         items_embed = output.squeeze(0)[-1]  # squeeze only the batch dim so a length-1 sequence keeps its shape
-         user_embed = task_embed[0, idx:]
- 
-         task_embed = self.input_to_hidden(torch.cat((items_embed, user_embed), dim=0))
-         mean_task = task_embed
-
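-         # Student-t-style kernel over the distances to each centroid gives the soft assignment C
-         # (a 1 x cluster_k row); `value`, the assignment-weighted mix of centroids, is only used by
-         # the commented-out variants below.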
-         res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
-         res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
-         C = torch.transpose(res / res.sum(), 0, 1)
-         value = torch.mm(C, self.array)
- 
-         # new_task_embed = value + mean_task
-         # new_task_embed = value
-         new_task_embed = mean_task
-
-         # compute clustering loss: a soft k-means objective, i.e. the squared distances to the
-         # centroids weighted by a sharp softmax over those same distances
-         dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)
-         alpha = self.kmeans_alpha  # sharpness of the soft assignment
-         # alpha = alphas[iteration]
-         # subtracting the minimum distance only stabilises the exponentials; the softmax value is unchanged
-         weights = F.softmax(-alpha * (dist - torch.min(dist)), dim=0)
-         kmeans_loss = torch.sum(weights * dist)
-
-         # rec_loss = self.recon_loss(input_pairs, output)
- 
-         # return C, new_task_embed, kmeans_loss, rec_loss
-         return C, new_task_embed, kmeans_loss
-
-
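- # Wraps the clustering module and maps the clustered task embedding to FiLM parameters:
- # per-layer scale (gamma) and shift (beta) vectors for the downstream recommendation model.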
- class Trainer(torch.nn.Module):
-     def __init__(self, config, head=None):
-         super(Trainer, self).__init__()
-         # cluster module
-         self.cluster_module = ClustringModule(config)
-         # self.task_dim = 64
-         self.task_dim = config['task_dim']
-         self.fc2_in_dim = config['first_fc_hidden_dim']
-         self.fc2_out_dim = config['second_fc_hidden_dim']
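-         # one bias-free gamma/beta generator per modulated layer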
-         self.film_layer_1_beta = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
-         self.film_layer_1_gamma = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
-         self.film_layer_2_beta = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
-         self.film_layer_2_gamma = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
-         self.dropout_rate = config['trainer_dropout']
-         self.dropout = nn.Dropout(self.dropout_rate)
-         self.label_noise_std = config['label_noise_std']
-
-     def aggregate(self, z_i):
-         return torch.mean(z_i, dim=0)
-
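-     # Perturb the support labels with Gaussian noise, cluster the task, and emit tanh-squashed
-     # FiLM parameters together with the cluster assignment and the clustering loss.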
-     def forward(self, task_embed, y, training=True, adaptation_data=None, adaptation_labels=None):
-         # if training:
-         # sample the label noise on the same device as y instead of hard-coding .cuda()
-         t = torch.randn(y.shape, device=y.device) * self.label_noise_std
-         noise_y = t + y
-         # C, clustered_task_embed, k_loss, rec_loss = self.cluster_module(task_embed, noise_y)
-         C, clustered_task_embed, k_loss = self.cluster_module(task_embed, noise_y)
-         beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
-         gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
-         beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
-         gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
-         # return gamma_1, beta_1, gamma_2, beta_2, C, k_loss, rec_loss
-         return gamma_1, beta_1, gamma_2, beta_2, C, k_loss
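- 
- 
- # Example usage: a minimal sketch only. The config keys below are exactly the ones read in this
- # file; the concrete values (and input shapes) are illustrative assumptions, not values taken
- # from the original project.
- #
- # config = {
- #     'embedding_dim': 32, 'task_dim': 64, 'rnn_hidden': 64, 'rnn_l1': 64, 'rnn_dropout': 0.2,
- #     'cluster_k': 7, 'kmeans_alpha': 100.0, 'temperature': 1.0,
- #     'first_fc_hidden_dim': 64, 'second_fc_hidden_dim': 64,
- #     'trainer_dropout': 0.2, 'label_noise_std': 0.1,
- # }
- # trainer = Trainer(config)
- # task_embed = torch.rand(20, 8 * config['embedding_dim'])   # 20 support interactions
- # y = torch.rand(20)                                          # their labels/ratings
- # gamma_1, beta_1, gamma_2, beta_2, C, k_loss = trainer(task_embed, y)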