import gc
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F

from options import config


class ClustringModule(torch.nn.Module):
    """Encode (feature, label) rows into a task embedding and softly assign it to learnable clusters.

    Config keys read: ``embedding_dim``, ``cluster_h1_dim``, ``cluster_h2_dim``,
    ``cluster_final_dim``, ``cluster_dropout_rate``, ``cluster_k`` and ``temperature``.
    """

    def __init__(self, config_param):
        super().__init__()
        self.h1_dim = config_param['cluster_h1_dim']
        self.h2_dim = config_param['cluster_h2_dim']
        self.final_dim = config_param['cluster_final_dim']
        self.dropout_rate = config_param['cluster_dropout_rate']

        # Input width = the embedding_dim * 8 feature columns plus one label column
        # (the label column is concatenated on in ``forward``).
        layers = [
            nn.Linear(config_param['embedding_dim'] * 8 + 1, self.h1_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.h1_dim, self.h2_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.h2_dim, self.final_dim),
        ]
        self.input_to_hidden = nn.Sequential(*layers)

        self.clusters_k = config_param['cluster_k']
        self.embed_size = self.final_dim
        # Learnable cluster centres, one row per cluster: shape (clusters_k, embed_size).
        # torch.empty replaces the legacy torch.FloatTensor constructor; xavier_uniform_
        # overwrites every entry, so the uninitialised memory is never read.
        self.array = nn.Parameter(
            init.xavier_uniform_(
                torch.empty(self.clusters_k, self.embed_size, dtype=torch.float32)
            )
        )
        self.temperature = config_param['temperature']

    def aggregate(self, z_i):
        """Average ``z_i`` over its first dimension (the rows of the task)."""
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        """Compute soft cluster assignments and the cluster-weighted task embedding.

        Args:
            task_embed: 2-D tensor concatenated column-wise with ``y``; it must have
                ``embedding_dim * 8`` columns to match the first linear layer.
            y: one label per row of ``task_embed``; flattened to a single column.
            training: unused; kept for interface compatibility.

        Returns:
            ``(C, new_task_embed)``. ``C`` has shape ``(1, clusters_k)`` and its
            entries sum to 1. ``new_task_embed = C @ self.array`` has shape
            ``(1, embed_size)``.
        """
        # reshape (rather than view) also accepts a non-contiguous label tensor.
        y = y.reshape(-1, 1)
        input_pairs = torch.cat((task_embed, y), dim=1)
        task_embed = self.input_to_hidden(input_pairs)
        # TODO: may be useless
        mean_task = self.aggregate(task_embed)

        # Un-squared L2 distance from the mean task embedding to each centre: (k, 1).
        dist = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        # Kernel (1 + d / T) ** (-(T + 1) / 2); for temperature T > 0 a larger
        # distance gives a smaller weight.
        res = torch.pow((dist / self.temperature) + 1, (self.temperature + 1) / -2)
        # Normalise the weights to sum to 1 and transpose to a row vector: (1, k).
        C = torch.transpose(res / res.sum(), 0, 1)
        # (1, k) @ (k, d) -> (1, d)
        value = torch.mm(C, self.array)
        # The residual variant (value + mean_task) is not applied here.
        new_task_embed = value
        return C, new_task_embed


class Trainer(torch.nn.Module):
    """Single-output MLP whose hidden layers are FiLM-modulated by a clustered task embedding.

    Config keys read here: ``embedding_dim``, ``first_fc_hidden_dim``,
    ``second_fc_hidden_dim``, ``cluster_final_dim`` and ``trainer_dropout_rate``
    (``ClustringModule`` reads its own keys from the same dict).
    """

    def __init__(self, config_param, head=None):
        super().__init__()
        # ``head`` is accepted but not used by this class.
        fc1_in_dim = config_param['embedding_dim'] * 8
        fc2_in_dim = config_param['first_fc_hidden_dim']
        fc2_out_dim = config_param['second_fc_hidden_dim']

        # Submodules are created in the same order as before so parameter
        # initialisation (RNG draws) and state_dict key order are unchanged.
        self.fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
        self.fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
        self.linear_out = torch.nn.Linear(fc2_out_dim, 1)

        # cluster module
        self.cluster_module = ClustringModule(config_param)
        self.task_dim = config_param['cluster_final_dim']

        # Map the clustered task embedding to per-layer scale (gamma) and shift (beta).
        self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)

        self.dropout_rate = config_param['trainer_dropout_rate']
        self.dropout = nn.Dropout(self.dropout_rate)

    def aggregate(self, z_i):
        """Average ``z_i`` over its first dimension (not called by ``forward``)."""
        return torch.mean(z_i, dim=0)

    def _film_forward(self, task_embed, clustered_task_embed):
        """Run fc1 -> FiLM -> dropout -> ReLU -> fc2 -> FiLM -> dropout -> ReLU -> linear_out."""
        # tanh bounds every scale and shift to (-1, 1).
        beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
        gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
        beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
        gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))

        # TODO: adding activation function or remove it
        hidden_1 = self.fc1(task_embed)
        # The (1, dim) FiLM parameters broadcast across every row of the batch.
        hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
        hidden_1 = self.dropout(hidden_1)
        hidden_2 = F.relu(hidden_1)

        hidden_2 = self.fc2(hidden_2)
        hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
        hidden_2 = self.dropout(hidden_2)
        hidden_3 = F.relu(hidden_2)

        return self.linear_out(hidden_3)

    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
        """Predict one output per row of ``task_embed``.

        Args:
            task_embed: 2-D input with ``embedding_dim * 8`` columns.
            y: labels for ``task_embed``; only used when ``training`` is true.
            training: selects the data that conditions the network. If true, the
                cluster module sees ``(task_embed, y)``. Otherwise it sees
                ``(adaptation_data, adaptation_labels)``. Dropout follows the usual
                ``module.train()`` / ``module.eval()`` mode, not this flag.
            adaptation_data: conditioning rows, required when ``training`` is false.
            adaptation_labels: labels for ``adaptation_data``, required when
                ``training`` is false.

        Returns:
            ``(y_pred, C)``. ``y_pred`` has one row per row of ``task_embed`` and a
            single column. ``C`` is the ``(1, clusters_k)`` soft cluster assignment.

        Raises:
            ValueError: if ``training`` is false and either adaptation argument is None.
        """
        if training:
            C, clustered_task_embed = self.cluster_module(task_embed, y)
        else:
            if adaptation_data is None or adaptation_labels is None:
                raise ValueError(
                    "adaptation_data and adaptation_labels are required when training is False"
                )
            C, clustered_task_embed = self.cluster_module(adaptation_data, adaptation_labels)

        y_pred = self._film_forward(task_embed, clustered_task_embed)
        return y_pred, C