Browse Source

solve gamma and beta problem

define_task
mohamad maheri 2 years ago
parent
commit
dda2618d06
2 changed files with 19 additions and 14 deletions
  1. 17
    11
      clustering.py
  2. 2
    3
      fast_adapt.py

+ 17
- 11
clustering.py View File

class ClustringModule(torch.nn.Module): class ClustringModule(torch.nn.Module):
def __init__(self, config): def __init__(self, config):
super(ClustringModule, self).__init__() super(ClustringModule, self).__init__()
self.h1_dim = 64
self.h2_dim = 32
self.h1_dim = 128
self.h2_dim = 64
# self.final_dim = fc1_in_dim # self.final_dim = fc1_in_dim
self.final_dim = 32
self.final_dim = 64
self.dropout_rate = 0 self.dropout_rate = 0


layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim), layers = [nn.Linear(config['embedding_dim'] * 8 + 1, self.h1_dim),
torch.nn.Dropout(self.dropout_rate), torch.nn.Dropout(self.dropout_rate),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# nn.BatchNorm1d(self.h1_dim),
nn.Linear(self.h1_dim, self.h2_dim), nn.Linear(self.h1_dim, self.h2_dim),
torch.nn.Dropout(self.dropout_rate), torch.nn.Dropout(self.dropout_rate),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# nn.BatchNorm1d(self.h2_dim),
nn.Linear(self.h2_dim, self.final_dim)] nn.Linear(self.h2_dim, self.final_dim)]
self.input_to_hidden = nn.Sequential(*layers) self.input_to_hidden = nn.Sequential(*layers)


self.clusters_k = 7 self.clusters_k = 7
self.embed_size = self.final_dim self.embed_size = self.final_dim
self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size))) self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
self.temperature = 10.0
self.temperature = 1.0


def aggregate(self, z_i): def aggregate(self, z_i):
return torch.mean(z_i, dim=0) return torch.mean(z_i, dim=0)
# cluster module # cluster module
self.cluster_module = ClustringModule(config) self.cluster_module = ClustringModule(config)
# self.task_dim = fc1_in_dim # self.task_dim = fc1_in_dim
self.task_dim = 32
self.task_dim = 64
# transform task to weights # transform task to weights
self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False) self.film_layer_1_beta = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False) self.film_layer_1_gamma = nn.Linear(self.task_dim, fc2_in_dim, bias=False)
self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False) self.film_layer_2_gamma = nn.Linear(self.task_dim, fc2_out_dim, bias=False)
# self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False) # self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
# self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False) # self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
self.dropout_rate = 0.1
self.dropout_rate = 0
self.dropout = nn.Dropout(self.dropout_rate) self.dropout = nn.Dropout(self.dropout_rate)
self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = None, None, None, None


def aggregate(self, z_i): def aggregate(self, z_i):
return torch.mean(z_i, dim=0) return torch.mean(z_i, dim=0)


def forward(self, task_embed, y, training):
def forward(self, task_embed, y, training,adaptation_data=None,adaptation_labels=None):
if training: if training:
C, clustered_task_embed = self.cluster_module(task_embed, y) C, clustered_task_embed = self.cluster_module(task_embed, y)
# hidden layers # hidden layers
hidden_3 = F.relu(hidden_2) hidden_3 = F.relu(hidden_2)


y_pred = self.linear_out(hidden_3) y_pred = self.linear_out(hidden_3)
self.gamma_1, self.beta_1, self.gamma_2, self.beta_2 = gamma_1, beta_1, gamma_2, beta_2


else: else:
C, clustered_task_embed = self.cluster_module(adaptation_data, adaptation_labels)
beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))

hidden_1 = self.fc1(task_embed) hidden_1 = self.fc1(task_embed)
hidden_1 = torch.mul(hidden_1, self.gamma_1) + self.beta_1
hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
hidden_1 = self.dropout(hidden_1) hidden_1 = self.dropout(hidden_1)
hidden_2 = F.relu(hidden_1) hidden_2 = F.relu(hidden_1)


hidden_2 = self.fc2(hidden_2) hidden_2 = self.fc2(hidden_2)
hidden_2 = torch.mul(hidden_2, self.gamma_2) + self.beta_2
hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
hidden_2 = self.dropout(hidden_2) hidden_2 = self.dropout(hidden_2)
hidden_3 = F.relu(hidden_2) hidden_3 = F.relu(hidden_2)



+ 2
- 3
fast_adapt.py View File

train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels) train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
learn.adapt(train_error) learn.adapt(train_error)


predictions = learn(evaluation_data, None , training=False)
# loss = torch.nn.MSELoss(reduction='mean')
# valid_error = loss(predictions, evaluation_labels)
predictions = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
adaptation_labels=adaptation_labels)
valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels) valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)


if get_predictions: if get_predictions:

Loading…
Cancel
Save