@@ -11,25 +11,27 @@ from torch.nn import functional as F
 class ClustringModule(torch.nn.Module):
     def __init__(self, config):
         super(ClustringModule, self).__init__()
-        self.h1_dim = 64
-        self.h2_dim = 32
+        self.h1_dim = 128
+        self.h2_dim = 64
         # self.final_dim = fc1_in_dim
-        self.final_dim = 32
+        self.final_dim = 64
         self.dropout_rate = 0

         layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
                   torch.nn.Dropout(self.dropout_rate),
                   nn.ReLU(inplace=True),
                   # nn.BatchNorm1d(self.h1_dim),
                   nn.Linear(self.h1_dim, self.h2_dim),
                   torch.nn.Dropout(self.dropout_rate),
                   nn.ReLU(inplace=True),
                   # nn.BatchNorm1d(self.h2_dim),
                   nn.Linear(self.h2_dim, self.final_dim)]
         self.input_to_hidden = nn.Sequential(*layers)

         self.clusters_k = 7
         self.embed_size = self.final_dim
         self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
-        self.temperature = 10.0
+        self.temperature = 1.0

     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)
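
Note: `ClustringModule.forward` is outside the hunks shown here, so the snippet below is only a hypothetical sketch of how a learnable centroid matrix (`self.array`) and `self.temperature` are commonly turned into a soft cluster assignment; it is not the repository's actual implementation. It is included because the temperature change (10.0 -> 1.0) only has an effect through this kind of softmax sharpening.

# Hypothetical sketch only; the real ClustringModule.forward may differ.
import torch
import torch.nn.functional as F

def soft_cluster_assign(task_repr, centroids, temperature):
    # task_repr:  (final_dim,) pooled task embedding, e.g. the output of aggregate()
    # centroids:  (clusters_k, final_dim) learnable cluster centres (self.array)
    # A smaller temperature makes the softmax more peaked, i.e. closer to a
    # hard assignment to one of the clusters_k = 7 centres.
    sq_dist = torch.sum((centroids - task_repr) ** 2, dim=-1)   # (clusters_k,)
    C = F.softmax(-sq_dist / temperature, dim=-1)               # soft assignment weights
    clustered_task_embed = C @ centroids                        # weighted mixture of centres
    return C, clustered_task_embed
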
@@ -67,7 +69,7 @@ class Trainer(torch.nn.Module):
         # cluster module
         self.cluster_module = ClustringModule(config)
         # self.task_dim = fc1_in_dim
-        self.task_dim = 32
+        self.task_dim = 64
         # transform task to weights
         self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
         self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
@@ -75,14 +77,13 @@ class Trainer(torch.nn.Module):
         self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
         # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
         # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
-        self.dropout_rate = 0.1
+        self.dropout_rate = 0
         self.dropout = nn.Dropout(self.dropout_rate)
         self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = None, None, None, None

     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)

-    def forward(self, task_embed, y, training):
+    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
         if training:
             C, clustered_task_embed = self.cluster_module(task_embed, y)
             # hidden layers
@@ -102,16 +103,21 @@ class Trainer(torch.nn.Module):
             hidden_3 = F.relu(hidden_2)

             y_pred = self.linear_out(hidden_3)
             self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = gamma_1, beta_1, gamma_2, beta_2

         else:
+            C, clustered_task_embed = self.cluster_module(adaptation_data, adaptation_labels)
+            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
+            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
+            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
+            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))

             hidden_1 = self.fc1(task_embed)
-            hidden_1 = torch.mul(hidden_1, self.gamma_1) + self.beta_1
+            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
             hidden_1 = self.dropout(hidden_1)
             hidden_2 = F.relu(hidden_1)

             hidden_2 = self.fc2(hidden_2)
-            hidden_2 = torch.mul(hidden_2, self.gamma_2) + self.beta_2
+            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
             hidden_2 = self.dropout(hidden_2)
             hidden_3 = F.relu(hidden_2)
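
Taken together, these changes mean the evaluation branch of `forward` no longer reuses the `self.gamma_*`/`self.beta_*` values stored during training; it derives the FiLM scales and shifts from an explicit adaptation (support) set passed in through the new arguments. A hypothetical call pattern, assuming an already-constructed `trainer` instance; the tensor names (`support_embed`, `support_y`, `query_embed`, `query_y`) are placeholders and only the signature itself comes from this diff:

# Meta-training step: FiLM parameters come from the training task itself.
out = trainer(task_embed, y, training=True)

# Evaluation: FiLM parameters come from the support (adaptation) set and are
# then applied to the query embeddings; the exact return value is defined
# outside the hunks shown above.
out = trainer(query_embed, query_y, training=False,
              adaptation_data=support_embed, adaptation_labels=support_y)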