import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
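
# A task-adaptive rating predictor. ClustringModule soft-assigns an
# aggregated task embedding to learnable cluster centers; Trainer uses the
# resulting task representation to generate FiLM (feature-wise linear
# modulation) parameters for its prediction MLP.
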
class ClustringModule(torch.nn.Module):
    def __init__(self, config):
        super(ClustringModule, self).__init__()
        self.h1_dim = 64
        self.h2_dim = 32
        self.final_dim = 32
        self.dropout_rate = 0

        # Encoder: maps each interaction embedding, concatenated with its
        # rating, to a final_dim-dimensional representation.
        layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
                  nn.Dropout(self.dropout_rate),
                  nn.ReLU(inplace=True),
                  nn.Linear(self.h1_dim, self.h2_dim),
                  nn.Dropout(self.dropout_rate),
                  nn.ReLU(inplace=True),
                  nn.Linear(self.h2_dim, self.final_dim)]
        self.input_to_hidden = nn.Sequential(*layers)

        # Learnable cluster centers, one row per cluster.
        self.clusters_k = 7
        self.embed_size = self.final_dim
        self.array = nn.Parameter(init.xavier_uniform_(torch.empty(self.clusters_k, self.embed_size)))
        self.temperature = 10.0

    def aggregate(self, z_i):
        # Mean-pool the per-interaction embeddings into a single task vector.
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        # Append the rating to each interaction embedding, then encode.
        y = y.view(-1, 1)
        input_pairs = torch.cat((task_embed, y), dim=1)
        task_embed = self.input_to_hidden(input_pairs)

        # Aggregate the support interactions into one task-level vector.
        mean_task = self.aggregate(task_embed)

        # Soft assignment: a Student's t-style kernel over the L2 distance
        # between the task vector and each cluster center (k x 1).
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        # Normalize into the cluster-assignment distribution C (1 x k).
        C = torch.transpose(res / res.sum(), 0, 1)
        # Convex combination of the centers: (1 x k) @ (k x d) -> (1 x d).
        value = torch.mm(C, self.array)
        # Residual connection with the raw task vector.
        new_task_embed = value + mean_task
        return C, new_task_embed

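# A minimal shape sketch of ClustringModule.forward, assuming an illustrative
# config['embedding_dim'] = 32 (this value is an assumption, not taken from
# this file):
#   task_embed: (n, 256), y: (n,)  ->  input_pairs: (n, 257)
#   encoded interactions: (n, 32)  ->  mean_task: (32,)
#   C: (1, 7) soft assignment      ->  new_task_embed: (1, 32)
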
class Trainer(torch.nn.Module):
    def __init__(self, config, head=None):
        super(Trainer, self).__init__()
        fc1_in_dim = config['embedding_dim'] * 8
        fc2_in_dim = config['first_fc_hidden_dim']
        fc2_out_dim = config['second_fc_hidden_dim']
        # Base prediction MLP.
        self.fc1 = nn.Linear(fc1_in_dim, fc2_in_dim)
        self.fc2 = nn.Linear(fc2_in_dim, fc2_out_dim)
        self.linear_out = nn.Linear(fc2_out_dim, 1)
        # Task-level clustering module.
        self.cluster_module = ClustringModule(config)
        # Must match ClustringModule.final_dim.
        self.task_dim = 32
        # FiLM generators: map the clustered task embedding to per-layer
        # scale (gamma) and shift (beta) vectors.
        self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
        self.dropout_rate = 0.1
        self.dropout = nn.Dropout(self.dropout_rate)
        # FiLM parameters cached from the last training pass, reused at eval time.
        self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = None, None, None, None

    def aggregate(self, z_i):
        # Mean-pool the per-interaction embeddings into a single task vector.
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training):
        if training:
            C, clustered_task_embed = self.cluster_module(task_embed, y)
            # First hidden layer with FiLM modulation: h <- gamma * h + beta.
            hidden_1 = self.fc1(task_embed)
            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
            hidden_1 = self.dropout(hidden_1)
            hidden_2 = F.relu(hidden_1)

            # Second hidden layer, modulated the same way.
            hidden_2 = self.fc2(hidden_2)
            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
            hidden_2 = self.dropout(hidden_2)
            hidden_3 = F.relu(hidden_2)

            y_pred = self.linear_out(hidden_3)
            # Cache the FiLM parameters so the eval branch can reuse them.
            self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = gamma_1, beta_1, gamma_2, beta_2

        else:
            # Reuse the FiLM parameters cached during the preceding
            # training pass (requires at least one training=True call).
            hidden_1 = self.fc1(task_embed)
            hidden_1 = torch.mul(hidden_1, self.gamma_1) + self.beta_1
            hidden_1 = self.dropout(hidden_1)
            hidden_2 = F.relu(hidden_1)

            hidden_2 = self.fc2(hidden_2)
            hidden_2 = torch.mul(hidden_2, self.gamma_2) + self.beta_2
            hidden_2 = self.dropout(hidden_2)
            hidden_3 = F.relu(hidden_2)

            y_pred = self.linear_out(hidden_3)

        return y_pred
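

# A minimal smoke test under assumed config values (in this repo the real
# values come from options.config; the numbers below are illustrative):
if __name__ == '__main__':
    config = {'embedding_dim': 32, 'first_fc_hidden_dim': 64, 'second_fc_hidden_dim': 64}
    model = Trainer(config)
    task_embed = torch.randn(10, config['embedding_dim'] * 8)  # 10 support interactions
    y = torch.randint(1, 6, (10,)).float()                     # ratings in [1, 5]
    y_pred = model(task_embed, y, training=True)               # also caches FiLM params
    print(y_pred.shape)                                        # torch.Size([10, 1])
    y_pred_eval = model(task_embed, y, training=False)         # reuses cached FiLM params
    print(y_pred_eval.shape)                                   # torch.Size([10, 1])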