import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F

from options import config

class ClustringModule(torch.nn.Module):
    """Soft clustering over task embeddings.

    Encodes each (task_embed, y) pair, averages the encodings into a single
    task representation, and softly assigns it to one of `cluster_k`
    learnable centroids.
    """

    def __init__(self, config_param):
        super(ClustringModule, self).__init__()
        self.h1_dim = config_param['cluster_h1_dim']
        self.h2_dim = config_param['cluster_h2_dim']
        self.final_dim = config_param['cluster_final_dim']
        self.dropout_rate = config_param['cluster_dropout_rate']

        # Encoder over the concatenated (task_embed, y) input pairs.
        layers = [nn.Linear(config_param['embedding_dim'] * 8 + 1, self.h1_dim),
                  nn.Dropout(self.dropout_rate),
                  nn.ReLU(inplace=True),
                  # nn.BatchNorm1d(self.h1_dim),
                  nn.Linear(self.h1_dim, self.h2_dim),
                  nn.Dropout(self.dropout_rate),
                  nn.ReLU(inplace=True),
                  # nn.BatchNorm1d(self.h2_dim),
                  nn.Linear(self.h2_dim, self.final_dim)]
        self.input_to_hidden = nn.Sequential(*layers)

        self.clusters_k = config_param['cluster_k']
        self.embed_size = self.final_dim
        # K learnable cluster centroids, Xavier-initialized: shape (k, final_dim).
        self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        self.temperature = config_param['temperature']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        y = y.view(-1, 1)
        input_pairs = torch.cat((task_embed, y), dim=1)
        task_embed = self.input_to_hidden(input_pairs)

        # TODO: this aggregation may be redundant.
        mean_task = self.aggregate(task_embed)

        # Student-t style kernel over the distances to the centroids.
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        # Soft assignment over clusters: shape 1 x k.
        C = torch.transpose(res / res.sum(), 0, 1)
        # (1 x k) @ (k x d) -> 1 x d cluster-weighted centroid.
        value = torch.mm(C, self.array)
        # Alternative ways to combine the centroid value with the task mean:
        # new_task_embed = value + mean_task
        # new_task_embed = value
        new_task_embed = mean_task

        # Squared Euclidean distance from the task embedding to every centroid: shape (k,).
        stack_dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)

        # Shift by the minimum distance before exponentiating to avoid underflow in the softmax.
        min_dist = torch.min(stack_dist)

        alpha = config['kmeans_alpha']  # softmax sharpness, read from the global config
        stack_exp = torch.exp(-alpha * (stack_dist - min_dist))
        sum_exponentials = torch.sum(stack_exp)

        # Softmax over clusters, and the k-means loss as the softmax-weighted squared distances.
        softmax = stack_exp / sum_exponentials
        stack_weighted_dist = stack_dist * softmax
        kmeans_loss = torch.sum(stack_weighted_dist, dim=0)

        return C, new_task_embed, kmeans_loss

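# Usage sketch for ClustringModule (illustrative only). It assumes `config_param`
# supplies the keys read in __init__ (e.g. embedding_dim, cluster_h1_dim,
# cluster_h2_dim, cluster_final_dim, cluster_k, temperature, cluster_dropout_rate)
# and that options.config defines 'kmeans_alpha'; the sizes below are made up.
#
#   cluster = ClustringModule(config_param)
#   task_embed = torch.randn(20, config_param['embedding_dim'] * 8)  # 20 support interactions
#   y = torch.randn(20)                                              # their labels/ratings
#   C, task_repr, k_loss = cluster(task_embed, y)
#   # C: 1 x cluster_k soft assignment, task_repr: (cluster_final_dim,), k_loss: scalar
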
class Trainer(torch.nn.Module):
    """Prediction network whose hidden layers are FiLM-modulated by the
    clustered task embedding produced by ClustringModule."""

    def __init__(self, config_param, head=None):
        super(Trainer, self).__init__()
        fc1_in_dim = config_param['embedding_dim'] * 8
        fc2_in_dim = config_param['first_fc_hidden_dim']
        fc2_out_dim = config_param['second_fc_hidden_dim']
        self.fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
        self.fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
        self.linear_out = torch.nn.Linear(fc2_out_dim, 1)
        # Cluster module producing the task representation.
        self.cluster_module = ClustringModule(config_param)
        # self.task_dim = fc1_in_dim
        self.task_dim = config_param['cluster_final_dim']
        # FiLM layers: map the task representation to per-layer scale (gamma) and shift (beta).
        self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
        # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
        self.dropout_rate = config_param['trainer_dropout_rate']
        self.dropout = nn.Dropout(self.dropout_rate)

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
        if training:
            C, clustered_task_embed, k_loss = self.cluster_module(task_embed, y)
            # Hidden layers, each FiLM-modulated by the clustered task embedding.
            # TODO: decide whether an activation is needed before the first FiLM step.
            hidden_1 = self.fc1(task_embed)
            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
            hidden_1 = self.dropout(hidden_1)
            hidden_2 = F.relu(hidden_1)

            hidden_2 = self.fc2(hidden_2)
            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
            hidden_2 = self.dropout(hidden_2)
            hidden_3 = F.relu(hidden_2)

            y_pred = self.linear_out(hidden_3)

        else:
            # At evaluation time the task representation comes from the adaptation (support) set.
            C, clustered_task_embed, k_loss = self.cluster_module(adaptation_data, adaptation_labels)
            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))

            hidden_1 = self.fc1(task_embed)
            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
            hidden_1 = self.dropout(hidden_1)
            hidden_2 = F.relu(hidden_1)

            hidden_2 = self.fc2(hidden_2)
            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
            hidden_2 = self.dropout(hidden_2)
            hidden_3 = F.relu(hidden_2)

            y_pred = self.linear_out(hidden_3)

        return y_pred, C, k_loss
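
# Minimal smoke-test sketch (illustrative only). The config keys below mirror the
# ones this file reads; their values are assumptions, not the project's defaults,
# and `options.config` must still provide 'kmeans_alpha' for the cluster module.
if __name__ == "__main__":
    config_param = {
        'embedding_dim': 32,
        'first_fc_hidden_dim': 64,
        'second_fc_hidden_dim': 64,
        'cluster_h1_dim': 128,
        'cluster_h2_dim': 64,
        'cluster_final_dim': 32,
        'cluster_dropout_rate': 0.1,
        'trainer_dropout_rate': 0.1,
        'cluster_k': 5,
        'temperature': 1.0,
    }
    model = Trainer(config_param)

    # One task with 20 support interactions: features of size embedding_dim * 8 plus a rating each.
    x = torch.randn(20, config_param['embedding_dim'] * 8)
    y = torch.randn(20)

    # Training mode: the task representation is built from the same batch.
    y_pred, C, k_loss = model(x, y, training=True)
    print(y_pred.shape, C.shape, k_loss.item())

    # Evaluation mode: the representation comes from a separate adaptation (support) set.
    x_adapt = torch.randn(10, config_param['embedding_dim'] * 8)
    y_adapt = torch.randn(10)
    y_pred, C, k_loss = model(x, y, training=False,
                              adaptation_data=x_adapt, adaptation_labels=y_adapt)
    print(y_pred.shape, C.shape, k_loss.item())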