
solve gamma and beta problem

define_task
mohamad maheri 2 years ago
parent commit dda2618d06
2 changed files with 19 additions and 14 deletions
  1. clustering.py (+17, -11)
  2. fast_adapt.py (+2, -3)

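Reading the diff, the "gamma and beta problem" in the commit title appears to be that evaluation-mode forward passes reused the FiLM parameters stored in self.gamma_*/self.beta_* from the last training batch (a different task); the fix passes the support (adaptation) set into forward() and recomputes gamma/beta from it. The commit also widens the clustering network (hidden dims 64/32 -> 128/64, task_dim 32 -> 64), lowers the cluster temperature from 10.0 to 1.0, and turns Trainer dropout off (0.1 -> 0). Below is a minimal, self-contained sketch of the FiLM-style scale-and-shift that the film_layer_* modules implement; the dimensions and tensor names are illustrative, not taken from the repo.

import torch
import torch.nn as nn

# Sketch only: a task embedding is mapped to per-layer scale (gamma) and shift (beta)
# vectors that modulate a hidden layer, mirroring the torch.mul(hidden, gamma) + beta
# pattern in the diff. Sizes are illustrative.
task_dim, in_dim, hidden_dim = 64, 16, 32
film_gamma = nn.Linear(task_dim, hidden_dim, bias=False)   # like film_layer_1_gamma
film_beta = nn.Linear(task_dim, hidden_dim, bias=False)    # like film_layer_1_beta
fc = nn.Linear(in_dim, hidden_dim)

clustered_task_embed = torch.randn(task_dim)                # stand-in for the cluster module output
gamma = torch.tanh(film_gamma(clustered_task_embed))        # scale
beta = torch.tanh(film_beta(clustered_task_embed))          # shift

x = torch.randn(8, in_dim)                                  # a batch of query embeddings
hidden = torch.mul(fc(x), gamma) + beta                     # FiLM-conditioned activation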
clustering.py (+17, -11)

@@ -11,25 +11,27 @@ from torch.nn import functional as F
 class ClustringModule(torch.nn.Module):
     def __init__(self, config):
         super(ClustringModule, self).__init__()
-        self.h1_dim = 64
-        self.h2_dim = 32
+        self.h1_dim = 128
+        self.h2_dim = 64
         # self.final_dim = fc1_in_dim
-        self.final_dim = 32
+        self.final_dim = 64
         self.dropout_rate = 0

         layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
                   torch.nn.Dropout(self.dropout_rate),
                   nn.ReLU(inplace=True),
                   # nn.BatchNorm1d(self.h1_dim),
                   nn.Linear(self.h1_dim, self.h2_dim),
                   torch.nn.Dropout(self.dropout_rate),
                   nn.ReLU(inplace=True),
                   # nn.BatchNorm1d(self.h2_dim),
                   nn.Linear(self.h2_dim, self.final_dim)]
         self.input_to_hidden = nn.Sequential(*layers)

         self.clusters_k = 7
         self.embed_size = self.final_dim
         self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
-        self.temperature = 10.0
+        self.temperature = 1.0

     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)
@@ -67,7 +69,7 @@ class Trainer(torch.nn.Module):
         # cluster module
         self.cluster_module = ClustringModule(config)
         # self.task_dim = fc1_in_dim
-        self.task_dim = 32
+        self.task_dim = 64
         # transform task to weights
         self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
         self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
@@ -75,14 +77,13 @@ class Trainer(torch.nn.Module):
         self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
         # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
         # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
-        self.dropout_rate = 0.1
+        self.dropout_rate = 0
         self.dropout = nn.Dropout(self.dropout_rate)
         self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = None, None, None, None

     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)

-    def forward(self, task_embed, y, training):
+    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
         if training:
             C, clustered_task_embed = self.cluster_module(task_embed, y)
             # hidden layers
@@ -102,16 +103,21 @@ class Trainer(torch.nn.Module):
             hidden_3 = F.relu(hidden_2)

             y_pred = self.linear_out(hidden_3)
             self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = gamma_1, beta_1, gamma_2, beta_2

         else:
+            C, clustered_task_embed = self.cluster_module(adaptation_data, adaptation_labels)
+            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
+            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
+            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
+            gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))

             hidden_1 = self.fc1(task_embed)
-            hidden_1 = torch.mul(hidden_1, self.gamma_1) + self.beta_1
+            hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
             hidden_1 = self.dropout(hidden_1)
             hidden_2 = F.relu(hidden_1)

             hidden_2 = self.fc2(hidden_2)
-            hidden_2 = torch.mul(hidden_2, self.gamma_2) + self.beta_2
+            hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
             hidden_2 = self.dropout(hidden_2)
             hidden_3 = F.relu(hidden_2)

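The ClustringModule internals are not shown in this diff, only its hyper-parameters (7 learnable cluster centers in self.array, a temperature, and the new 128/64/64 layer sizes). As a rough illustration of what the temperature change from 10.0 to 1.0 typically does in such a module, here is a generic temperature-scaled soft-assignment sketch; it is not the repo's actual clustering code, and the distance-based similarity is an assumption.

import torch
import torch.nn.functional as F

def soft_assign(task_embed, centroids, temperature):
    # task_embed: (d,), centroids: (k, d); similarity taken as negative distance (assumed)
    logits = -torch.cdist(task_embed.unsqueeze(0), centroids).squeeze(0)
    weights = F.softmax(logits / temperature, dim=0)  # lower temperature -> sharper assignment
    return weights @ centroids                        # convex combination of the k centers

centroids = torch.randn(7, 64)   # clusters_k = 7, embed_size = final_dim = 64 as in the diff
task = torch.randn(64)
clustered = soft_assign(task, centroids, temperature=1.0)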

fast_adapt.py (+2, -3)

@@ -15,9 +15,8 @@ def fast_adapt(
         train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
         learn.adapt(train_error)

-    predictions = learn(evaluation_data, None , training=False)
-    # loss = torch.nn.MSELoss(reduction='mean')
-    # valid_error = loss(predictions, evaluation_labels)
+    predictions = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
+                        adaptation_labels=adaptation_labels)
     valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)

     if get_predictions:

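learn.adapt(train_error) above matches the inner-loop API of learn2learn's MAML wrapper. Assuming that is what `learn` is (the repo's outer training loop is not part of this commit), the adapt-then-evaluate cycle around fast_adapt() would look roughly like the sketch below; `model` and the support/query tensors are placeholders, and the training-mode call signature is inferred from the surrounding lines.

import torch
import learn2learn as l2l

maml = l2l.algorithms.MAML(model, lr=5e-3)   # `model` wraps the Trainer above (assumed)
learn = maml.clone()                         # per-task copy whose fast weights adapt() updates

# inner loop: adapt on the support (adaptation) set
temp = learn(adaptation_data, adaptation_labels, training=True)
train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
learn.adapt(train_error)

# query-set evaluation: the support set is now passed along too, so gamma/beta are
# recomputed per task instead of reusing values cached during training
predictions = learn(evaluation_data, None, training=False,
                    adaptation_data=adaptation_data, adaptation_labels=adaptation_labels)
valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)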