|
|
@@ -17,7 +17,7 @@ class ClustringModule(torch.nn.Module):
         self.final_dim = 32
         self.dropout_rate = 0
 
-        layers = [nn.Linear(config['embedding_dim'] * 8, self.h1_dim),
+        layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
                   torch.nn.Dropout(self.dropout_rate),
                   nn.ReLU(inplace=True),
                   nn.Linear(self.h1_dim, self.h2_dim),
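The first Linear's input width grows from config['embedding_dim'] * 8 to config['embedding_dim'] * 8 + 1 because the next hunk concatenates the scalar label y to each task embedding before it enters input_to_hidden. The shape check below is only a minimal sketch; embedding_dim and n are illustrative values, not taken from the repository config.

# Hedged sketch: shape check for the widened encoder input (values are illustrative).
import torch
embedding_dim, n = 32, 5
task_embed = torch.randn(n, embedding_dim * 8)            # per-interaction task features
y = torch.randn(n)                                        # one label per interaction
input_pairs = torch.cat((task_embed, y.view(-1, 1)), dim=1)
assert input_pairs.shape == (n, embedding_dim * 8 + 1)    # matches the new Linear in_features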
|
|
@@ -29,13 +29,15 @@ class ClustringModule(torch.nn.Module):
         self.clusters_k = 7
         self.embed_size = self.final_dim
         self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
-        self.temperature = 1.0
+        self.temperature = 10.0
 
     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)
 
-    def forward(self, task_embed, training=True):
-        task_embed = self.input_to_hidden(task_embed)
+    def forward(self, task_embed, y, training=True):
+        y = y.view(-1, 1)
+        input_pairs = torch.cat((task_embed, y), dim=1)
+        task_embed = self.input_to_hidden(input_pairs)
 
         # todo : may be useless
         mean_task = self.aggregate(task_embed)
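ClustringModule now conditions its task representation on the labels: y is reshaped to a column and concatenated with task_embed before the shared encoder. The learned cluster centers in self.array (clusters_k x final_dim) together with self.temperature suggest a temperature-scaled soft assignment computed from the aggregated task embedding, but that part of forward is outside the hunk, so the snippet below is only a plausible sketch of such an assignment, not the repository's actual code.

# Hedged sketch of a temperature-scaled soft cluster assignment.
# self.array, self.temperature and aggregate() appear in the diff; the softmax form is an assumption.
import torch
import torch.nn.functional as F

def soft_assign(mean_task, centers, temperature):
    # mean_task: (final_dim,) aggregated task embedding
    # centers:   (clusters_k, final_dim) learned cluster centers (self.array)
    logits = torch.mv(centers, mean_task) / temperature   # similarity to each center
    C = F.softmax(logits, dim=0)                          # soft assignment over clusters_k
    clustered_task_embed = C @ centers                    # convex combination of centers
    return C, clustered_task_embed

Under this reading, raising self.temperature from 1.0 to 10.0 would flatten the assignment (logits are divided by a larger value), but whether the real code divides or multiplies by the temperature cannot be confirmed from the hunk.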
|
|
@@ -73,29 +75,46 @@ class Trainer(torch.nn.Module):
         self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
         # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
         # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
-        self.dropout_rate = 0
+        self.dropout_rate = 0.1
         self.dropout = nn.Dropout(self.dropout_rate)
+        self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = None, None, None, None
 
     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)
 
-    def forward(self, task_embed):
-        C, clustered_task_embed = self.cluster_module(task_embed)
-        # hidden layers
-        # todo : adding activation function or remove it
-        hidden_1 = self.fc1(task_embed)
-        beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
-        gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
-        hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
-        hidden_1 = self.dropout(hidden_1)
-        hidden_2 = F.relu(hidden_1)
-
-        hidden_2 = self.fc2(hidden_2)
-        beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
-        gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
-        hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
-        hidden_2 = self.dropout(hidden_2)
-        hidden_3 = F.relu(hidden_2)
-
-        y_pred = self.linear_out(hidden_3)
+    def forward(self, task_embed, y, training):
+        if training:
+            C, clustered_task_embed = self.cluster_module(task_embed, y)
+            # hidden layers
+            # todo : adding activation function or remove it
+            hidden_1 = self.fc1(task_embed)
+            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
+            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
+            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
+            hidden_1 = self.dropout(hidden_1)
+            hidden_2 = F.relu(hidden_1)
+
+            hidden_2 = self.fc2(hidden_2)
+            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
+            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
+            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
+            hidden_2 = self.dropout(hidden_2)
+            hidden_3 = F.relu(hidden_2)
+
+            y_pred = self.linear_out(hidden_3)
+            self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = gamma_1, beta_1, gamma_2, beta_2
+
+        else:
+            hidden_1 = self.fc1(task_embed)
+            hidden_1 = torch.mul(hidden_1, self.gamma_1) + self.beta_1
+            hidden_1 = self.dropout(hidden_1)
+            hidden_2 = F.relu(hidden_1)
+
+            hidden_2 = self.fc2(hidden_2)
+            hidden_2 = torch.mul(hidden_2, self.gamma_2) + self.beta_2
+            hidden_2 = self.dropout(hidden_2)
+            hidden_3 = F.relu(hidden_2)
+
+            y_pred = self.linear_out(hidden_3)
+
         return y_pred
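Trainer.forward now takes the labels and an explicit training flag. In training mode it runs the cluster module, computes fresh FiLM parameters, and caches them in self.gamma_1, self.beta_1, self.gamma_2, self.beta_2; in eval mode it skips the cluster module and reuses the cached parameters, so labels are not needed for the queries. A minimal usage sketch follows; trainer, support_embed, support_y, and query_embed are hypothetical names, and the Trainer constructor is assumed to be set up as elsewhere in the repo.

# Hedged usage sketch of the new call contract; tensor names are illustrative, not from the repo.
support_pred = trainer(support_embed, support_y, training=True)   # computes and caches gamma/beta
query_pred = trainer(query_embed, None, training=False)           # reuses cached FiLM parameters; y is unused here

Note that an eval call issued before any training call would fail, since the cached FiLM parameters are still None.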