import torch
import torch.nn.functional as F


class Head(torch.nn.Module):
    def __init__(self, config):
        super(Head, self).__init__()
        self.embedding_dim = config['embedding_dim']
        self.fc1_in_dim = config['embedding_dim'] * 8
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        self.use_cuda = True
        self.fc1 = torch.nn.Linear(self.fc1_in_dim, self.fc2_in_dim)
        self.fc2 = torch.nn.Linear(self.fc2_in_dim, self.fc2_out_dim)
        self.linear_out = torch.nn.Linear(self.fc2_out_dim, 1)
        self.dropout_rate = config['head_dropout']
        self.dropout = torch.nn.Dropout(self.dropout_rate)

    def forward(self, task_embed, gamma_1, beta_1, gamma_2, beta_2):
        hidden_1 = self.fc1(task_embed)
        hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
        hidden_1 = self.dropout(hidden_1)
        hidden_2 = F.relu(hidden_1)
        hidden_2 = self.fc2(hidden_2)
        hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
        hidden_2 = self.dropout(hidden_2)
        hidden_3 = F.relu(hidden_2)
        y_pred = self.linear_out(hidden_3)
        return y_pred
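

# Minimal usage sketch (not part of the original file): the config values below are hypothetical,
# and the FiLM parameters gamma_* / beta_* would normally come from the Trainer in clustering.py;
# random tensors stand in for them here.
if __name__ == "__main__":
    _cfg = {'embedding_dim': 16, 'first_fc_hidden_dim': 64,
            'second_fc_hidden_dim': 32, 'head_dropout': 0.1}
    _head = Head(_cfg)
    _x = torch.randn(5, _cfg['embedding_dim'] * 8)       # 5 interactions, 8 concatenated embeddings
    _g1 = torch.randn(_cfg['first_fc_hidden_dim'])
    _b1 = torch.randn(_cfg['first_fc_hidden_dim'])
    _g2 = torch.randn(_cfg['second_fc_hidden_dim'])
    _b2 = torch.randn(_cfg['second_fc_hidden_dim'])
    print(_head(_x, _g1, _b1, _g2, _b2).shape)           # torch.Size([5, 1])
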
import os
import torch
import pickle
# from options import config
import gc
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import numpy as np


class ClustringModule(torch.nn.Module):
    def __init__(self, config):
        super(ClustringModule, self).__init__()
        self.final_dim = config['task_dim']
        self.dropout_rate = config['rnn_dropout']
        self.embedding_dim = config['embedding_dim']
        self.kmeans_alpha = config['kmeans_alpha']
        # LSTM over the item-side embeddings (4 * embedding_dim) concatenated with the rating
        self.hidden_dim = config['rnn_hidden']
        self.l1_dim = config['rnn_l1']
        self.rnn = nn.LSTM(4 * config['embedding_dim'] + 1, self.hidden_dim, batch_first=True)
        # project (LSTM summary, user embedding) down to the task embedding
        layers = [
            nn.Linear(config['embedding_dim'] * 4 + self.hidden_dim, self.l1_dim),
            torch.nn.Dropout(self.dropout_rate),
            nn.ReLU(inplace=True),
            nn.Linear(self.l1_dim, self.final_dim),
        ]
        self.input_to_hidden = nn.Sequential(*layers)
        self.clusters_k = config['cluster_k']
        self.embed_size = self.final_dim
        # learnable cluster centroids
        self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
        self.temperature = config['temperature']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True):
        y = y.view(-1, 1)
        # the first 4 * embedding_dim columns are item features, the rest are user features
        idx = 4 * self.embedding_dim
        items = torch.cat((task_embed[:, 0:idx], y), dim=1).unsqueeze(0)
        output, (hn, cn) = self.rnn(items)
        items_embed = output.squeeze()[-1]
        user_embed = task_embed[0, idx:]
        task_embed = self.input_to_hidden(torch.cat((items_embed, user_embed), dim=0))
        mean_task = task_embed
        # soft assignment of the task embedding to the k centroids, 1*k
        res = torch.norm(mean_task - self.array, p=2, dim=1, keepdim=True)
        res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
        C = torch.transpose(res / res.sum(), 0, 1)
        # 1*k, k*d, 1*d
        value = torch.mm(C, self.array)
        # new_task_embed = value + mean_task
        # new_task_embed = value
        new_task_embed = mean_task
        # compute clustering loss: squared distances from the task embedding to each centroid
        list_dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)
        stack_dist = list_dist
        # Second, find the minimum squared distance for softmax normalization
        min_dist = min(list_dist)
        # Third, compute exponentials shifted with min_dist to avoid underflow (0/0) issues in softmaxes
        alpha = self.kmeans_alpha
        list_exp = []
        for i in range(self.clusters_k):
            exp = torch.exp(-alpha * (stack_dist[i] - min_dist))
            list_exp.append(exp)
        stack_exp = torch.stack(list_exp)
        sum_exponentials = torch.sum(stack_exp)
        # Fourth, compute softmaxes and the embedding/representative distances weighted by softmax
        list_softmax = []
        list_weighted_dist = []
        for j in range(self.clusters_k):
            softmax = stack_exp[j] / sum_exponentials
            weighted_dist = stack_dist[j] * softmax
            list_softmax.append(softmax)
            list_weighted_dist.append(weighted_dist)
        stack_weighted_dist = torch.stack(list_weighted_dist)
        kmeans_loss = torch.sum(stack_weighted_dist, dim=0)
        return C, new_task_embed, kmeans_loss
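

# Standalone sketch of the assignment / k-means math used in ClustringModule.forward above
# (illustration only, hypothetical sizes; not part of the original module). Distances d_k to the
# K centroids become weights w_k = (1 + d_k / T) ** (-(T + 1) / 2), normalized into the soft
# assignment C, and the k-means term is a softmax(-alpha * (d_k^2 - min)) weighted sum of the
# squared distances.
def _soft_assignment_sketch(K=5, d=32, T=1.0, alpha=10.0):
    z = torch.randn(d)                                          # task embedding
    centroids = torch.randn(K, d)                               # stands in for self.array
    dist = torch.norm(z - centroids, p=2, dim=1, keepdim=True)  # K x 1 euclidean distances
    C = torch.pow(dist / T + 1, -(T + 1) / 2)
    C = torch.transpose(C / C.sum(), 0, 1)                      # 1 x K soft assignment
    sq = torch.sum((z - centroids) ** 2, dim=1)                 # K squared distances
    w = torch.softmax(-alpha * (sq - sq.min()), dim=0)          # shifted softmax weights
    kmeans_loss = torch.sum(w * sq)
    return C, kmeans_loss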


class Trainer(torch.nn.Module):
    def __init__(self, config, head=None):
        super(Trainer, self).__init__()
        # cluster module
        self.cluster_module = ClustringModule(config)
        self.task_dim = config['task_dim']
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        # transform the clustered task embedding into FiLM parameters for the head
        self.film_layer_1_beta = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
        self.film_layer_1_gamma = nn.Linear(self.task_dim, self.fc2_in_dim, bias=False)
        self.film_layer_2_beta = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
        self.film_layer_2_gamma = nn.Linear(self.task_dim, self.fc2_out_dim, bias=False)
        self.dropout_rate = config['trainer_dropout']
        self.dropout = nn.Dropout(self.dropout_rate)
        self.label_noise_std = config['label_noise_std']

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def forward(self, task_embed, y, training=True, adaptation_data=None, adaptation_labels=None):
        # perturb the support labels with gaussian noise before clustering
        t = torch.Tensor(np.random.normal(0, self.label_noise_std, size=len(y))).cuda()
        noise_y = t + y
        C, clustered_task_embed, k_loss = self.cluster_module(task_embed, noise_y)
        beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
        gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
        beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
        gamma_2 = torch.tanh(self.film_layer_2_gamma(clustered_task_embed))
        return gamma_1, beta_1, gamma_2, beta_2, C, k_loss
import random

import torch
from options import config  # kmeans_loss_weight is read from the global options config


def fast_adapt(
        head,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        trainer=None,
        test=False,
        iteration=None):
    is_print = random.random() < 0.05
    # inner-loop adaptation of the head on the support (adaptation) set
    for step in range(adaptation_steps):
        g1, b1, g2, b2, c, k_loss = trainer(adaptation_data, adaptation_labels, training=True)
        temp = head(adaptation_data, g1, b1, g2, b2)
        train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
        train_error = train_error + config['kmeans_loss_weight'] * k_loss
        head.adapt(train_error)

    # evaluate the adapted head on the query (evaluation) set
    g1, b1, g2, b2, c, k_loss = trainer(adaptation_data, adaptation_labels, training=False)
    predictions = head(evaluation_data, g1, b1, g2, b2)
    valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
    total_loss = valid_error + config['kmeans_loss_weight'] * k_loss
    if is_print:
        print(c[0].detach().cpu().numpy(), "\t", round(k_loss.item(), 3), "\n")
    if get_predictions:
        return predictions.detach().cpu(), c
    return total_loss, c, k_loss.detach().cpu().item()
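

# Illustrative call patterns (comments only; they mirror how hyper_tunning.py and hyper_testing.py
# invoke fast_adapt, with head_clone = head.clone() from learn2learn):
#   meta-training:  loss, C, k_loss = fast_adapt(head_clone, supp_emb, query_emb, supp_y, query_y,
#                                                steps, trainer=trainer, test=False, iteration=it)
#   evaluation:     preds, C = fast_adapt(head_clone, supp_emb, query_emb, supp_y, query_y,
#                                         steps, get_predictions=True, trainer=trainer, test=True)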
import os
from functools import partial

import numpy as np
import torch
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

from hyper_tunning import train_melu


def main(num_samples, max_num_epochs=20, gpus_per_trial=2):
    data_dir = os.path.abspath("/media/external_10TB/10TB/maheri/new_data_dir3")
    # load_data(data_dir)
    config = {
        # meta learning
        "meta_algo": tune.choice(['metasgd']),
        "transformer": tune.choice(['metasgd']),
        "first_order": tune.choice([True]),
        "adapt_transform": tune.choice([False]),
        "local_lr": tune.loguniform(5e-6, 5e-3),
        "lr": tune.loguniform(5e-5, 5e-3),
        "batch_size": tune.choice([16, 32, 64]),
        "inner": tune.choice([1, 3, 4, 5, 7]),
        "test_state": tune.choice(["user_and_item_cold_state"]),
        # head
        "embedding_dim": tune.choice([16, 32, 64]),
        "first_fc_hidden_dim": tune.choice([32, 64, 128]),
        "second_fc_hidden_dim": tune.choice([32, 64]),
        # clustering module
        'cluster_dropout_rate': tune.choice([0, 0.01, 0.1]),
        'cluster_k': tune.choice([3, 5, 7, 9, 11]),
        'kmeans_alpha': tune.choice([100, 0.1, 10, 20, 50, 200]),
        'rnn_dropout': tune.choice([0, 0.01, 0.1]),
        'rnn_hidden': tune.choice([32, 64, 128]),
        'rnn_l1': tune.choice([32, 64, 128]),
        'kmeans_loss_weight': tune.choice([0, 1, 10, 50, 100, 200]),
        'temperature': tune.choice([0.1, 0.5, 1.0, 2.0, 5.0, 10.0]),
        # cluster-aware batching
        'distribution_power': tune.choice([0.1, 0.8, 1, 3, 5, 7, 8, 9]),
        'data_selection_pow': tune.choice([0.6, 0.65, 0.7, 0.75, 0.8, 0.9, 1, 1.1, 1.2, 1.4]),
        'task_dim': tune.choice([16, 32, 64, 128, 256]),
        'trainer_dropout': tune.choice([0, 0.001, 0.01, 0.05, 0.1]),
        'label_noise_std': tune.choice([0, 0.01, 0.1, 0.2, 0.3, 1, 2]),
        'head_dropout': tune.choice([0, 0.001, 0.01, 0.05, 0.1]),
        'num_epoch': tune.choice([40]),
        'use_cuda': tune.choice([True]),
        # dataset feature cardinalities (used by EmbeddingModule)
        'num_rate': tune.choice([6]),
        'num_genre': tune.choice([25]),
        'num_director': tune.choice([2186]),
        'num_actor': tune.choice([8030]),
        'num_gender': tune.choice([2]),
        'num_age': tune.choice([7]),
        'num_occupation': tune.choice([21]),
        'num_zipcode': tune.choice([3402]),
    }
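    # Illustrative sampled trial (hypothetical values, not part of the original script): Ray Tune
    # resolves each tune.choice / tune.loguniform above to a concrete value and passes the dict to
    # train_melu as its conf argument, e.g.
    #   {'meta_algo': 'metasgd', 'transformer': 'metasgd', 'first_order': True,
    #    'adapt_transform': False, 'local_lr': 5e-4, 'lr': 1e-3, 'batch_size': 32, 'inner': 3,
    #    'embedding_dim': 32, 'first_fc_hidden_dim': 64, 'second_fc_hidden_dim': 32,
    #    'cluster_k': 5, 'task_dim': 64, 'rnn_hidden': 64, 'rnn_l1': 64, ...}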
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=10,
        reduction_factor=2)
    reporter = CLIReporter(
        metric_columns=["loss", "ndcg1", "ndcg3", "training_iteration"])
    result = tune.run(
        partial(train_melu, data_dir=data_dir),
        resources_per_trial={"cpu": 4, "gpu": 0.5},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        log_to_file=True,
        # resume=True,
        local_dir="./hyper_tunning_all_cold3",
        name="rnn_cluster_module",
    )
    best_trial = result.get_best_trial("loss", "min", "last")
    print(result.results_df)
    print("=======================================================\n")


if __name__ == "__main__":
    # You can change the number of GPUs per trial here:
    main(num_samples=150, max_num_epochs=50, gpus_per_trial=1)
import os
import gc
import pickle
import random

import numpy as np
from torch.nn import L1Loss
from sklearn.metrics import ndcg_score

from fast_adapt import fast_adapt


def hyper_test(embedding, head, trainer, batch_size, master_path, test_state, adaptation_step, num_epoch=None):
    test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
    indexes = list(np.arange(test_set_size))
    random.shuffle(indexes)
    losses_q = []
    ndcgs11 = []
    ndcgs33 = []
    head.eval()
    trainer.eval()
    for iterator in range(test_set_size):
        a = pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
        b = pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
        c = pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
        d = pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
        try:
            supp_xs = a.cuda()
            supp_ys = b.cuda()
            query_xs = c.cuda()
            query_ys = d.cuda()
        except IndexError:
            print("index error in test method")
            continue
        # adapt a clone of the meta-learned head on the support set, then predict on the query set
        learner = head.clone()
        temp_sxs = embedding(supp_xs)
        temp_qxs = embedding(query_xs)
        predictions, c = fast_adapt(
            learner,
            temp_sxs,
            temp_qxs,
            supp_ys,
            query_ys,
            adaptation_step,
            get_predictions=True,
            trainer=trainer,
            test=True,
            iteration=num_epoch
        )
        l1 = L1Loss(reduction='mean')
        loss_q = l1(predictions.view(-1), query_ys.cpu())
        losses_q.append(float(loss_q))
        predictions = predictions.view(-1)
        y_true = query_ys.cpu().detach().numpy()
        # assumed reconstruction: nDCG@1 via sklearn's ndcg_score (nDCG@3 left disabled as in the original)
        y_pred = predictions.cpu().detach().numpy()
        ndcg1 = ndcg_score([y_true], [y_pred], k=1)
        ndcg3 = 0
        ndcgs11.append(ndcg1)
        ndcgs33.append(ndcg3)
    head.train()
    trainer.train()
    gc.collect()
    # average over all evaluated tasks so Ray Tune receives scalar metrics
    return np.mean(losses_q), np.mean(ndcgs11), np.mean(ndcgs33)
import os
import gc
import random
import pickle

import numpy as np
import torch
import torch.nn as nn
from ray import tune
import learn2learn as l2l
from learn2learn.optim.transforms import KroneckerTransform

from embedding_module import EmbeddingModule
from fast_adapt import fast_adapt
from hyper_testing import hyper_test
from clustering import Trainer
from Head import Head

# Define paths (for data)
# master_path= "/media/external_10TB/10TB/maheri/melu_data5"


def load_data(data_dir=None, test_state='warm_state'):
    # training tasks are loaded lazily, batch by batch, inside train_melu;
    # only the evaluation split for test_state is materialized here.
    test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    # load the four pickles of each evaluation task (same file layout as warm_state)
    for idx in range(test_set_size):
        supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    random.shuffle(test_dataset)
    val_size = int(test_set_size * 0.3)
    validationset = test_dataset[:val_size]
    # testset = test_dataset[val_size:]
    return None, validationset, None


def data_batching_new(indexes, C_distribs, batch_size, training_set_size, num_clusters, config):
    # turn each task's soft cluster assignment into a hard cluster sample,
    # sharpened (or flattened) by config['distribution_power']
    probs = np.squeeze(C_distribs)
    probs = np.array(probs) ** config['distribution_power'] / np.sum(
        np.array(probs) ** config['distribution_power'], axis=1, keepdims=True)
    cs = [np.random.choice(num_clusters, p=i) for i in probs]
    num_batch = int(training_set_size / batch_size)
    res = [[] for i in range(num_batch)]
    clas = [[] for i in range(num_clusters)]
    clas_temp = [[] for i in range(num_clusters)]
    for idx, c in zip(indexes, cs):
        clas[c].append(idx)
    for i in range(num_clusters):
        random.shuffle(clas[i])
    # per-cluster sampling probabilities, tempered by config['data_selection_pow']
    t = np.array([len(i) ** config['data_selection_pow'] for i in clas])
    t = t / t.sum()
    dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)]))
    cnt = 0
    for i in range(len(res)):
        for j in range(batch_size):
            temp = np.random.choice(num_clusters, p=t)
            if len(clas[temp]) > 0:
                selected = clas[temp].pop(0)
                res[i].append(selected)
                clas_temp[temp].append(selected)
            else:
                # the chosen cluster is exhausted: fall back to leftover tasks
                # or re-use an already selected task from the same cluster
                if len(dif) > 0:
                    if random.random() < 0.5 or len(clas_temp[temp]) == 0:
                        res[i].append(dif.pop(0))
                    else:
                        selected = clas_temp[temp].pop(0)
                        clas_temp[temp].append(selected)
                        res[i].append(selected)
                else:
                    selected = clas_temp[temp].pop(0)
                    res[i].append(selected)
                cnt = cnt + 1
    print("data_batching : ", cnt)
    res = np.random.permutation(res)
    final_result = np.array(res).flatten()
    return final_result
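

# Illustrative use of data_batching_new (hypothetical numbers, not part of the original file):
# C_distribs holds one 1 x num_clusters soft-assignment row per task from the previous epoch;
# the function returns a flat, cluster-balanced ordering of task indices to slice batches from.
def _data_batching_sketch():
    fake_C_distribs = list(np.random.dirichlet(np.ones(3), size=8).reshape(8, 1, 3))
    conf = {'distribution_power': 1, 'data_selection_pow': 0.7}  # hypothetical values
    order = data_batching_new(list(np.arange(8)), fake_C_distribs, batch_size=2,
                              training_set_size=8, num_clusters=3, config=conf)
    return order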


def train_melu(conf, checkpoint_dir=None, data_dir=None):
    config = conf
    master_path = data_dir
    emb = EmbeddingModule(conf).cuda()

    transform = None
    if conf['transformer'] == "kronoker":
        # assumed body (the original line was elided): Kronecker-factored transform for GBML
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(conf)
    trainer.cuda()
    head = Head(config)

    # define meta algorithm
    if conf['meta_algo'] == "maml":
        head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
                                   adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
    head.cuda()

    criterion = nn.MSELoss()
    all_parameters = list(emb.parameters()) + list(trainer.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])

    if checkpoint_dir:
        print("in checkpoint - bug happened")
        # model_state, optimizer_state = torch.load(
        # optimizer.load_state_dict(optimizer_state)

    # loading data
    # _, validation_dataset, _ = load_data(data_dir, test_state=conf['test_state'])
    batch_size = conf['batch_size']
    training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
    C_distribs = []
    indexes = list(np.arange(training_set_size))
    all_test_users = []

    for iteration in range(conf['num_epoch']):  # loop over the dataset multiple times
        print("iteration:", iteration)
        num_batch = int(training_set_size / batch_size)

        if iteration == 0:
            # initialize the cluster centroids with k-means over task embeddings of a task sample
            print("changing cluster centroids started ...")
            indexes = list(np.arange(training_set_size))
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(0, 2500):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)
            user_embeddings = []
            for task in range(batch_sz):
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                temp_sxs = emb(supp_xs[task])
                y = supp_ys[task].view(-1, 1)
                _, mean_task, _ = trainer.cluster_module(temp_sxs, y)
                user_embeddings.append(mean_task.detach().cpu().numpy())
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()
            from sklearn.cluster import KMeans
            user_embeddings = np.array(user_embeddings)
            kmeans_model = KMeans(n_clusters=conf['cluster_k'], init="k-means++").fit(user_embeddings)
            trainer.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()

        if iteration > 0:
            # after the first epoch, build cluster-aware batches from the previous epoch's assignments
            indexes = data_batching_new(indexes, C_distribs, batch_size, training_set_size, conf['cluster_k'], conf)
        else:
            random.shuffle(indexes)
        C_distribs = []

        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0
            meta_cluster_error = 0.0
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(batch_size * i, batch_size * (i + 1)):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            # iterate over all tasks in the batch
            for task in range(batch_sz):
                # Compute meta-training loss
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                query_xs[task] = query_xs[task].cuda()
                query_ys[task] = query_ys[task].cuda()
                learner = head.clone()
                temp_sxs = emb(supp_xs[task])
                temp_qxs = emb(query_xs[task])
                evaluation_error, c, K_LOSS = fast_adapt(learner,
                                                         temp_sxs,
                                                         temp_qxs,
                                                         supp_ys[task],
                                                         query_ys[task],
                                                         conf['inner'],
                                                         trainer=trainer,
                                                         test=False,
                                                         iteration=iteration)
                C_distribs.append(c.detach().cpu().numpy())
                meta_cluster_error += K_LOSS
                evaluation_error.backward(retain_graph=True)
                meta_train_error += evaluation_error.item()
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()
                query_xs[task] = query_xs[task].cpu()
                query_ys[task] = query_ys[task].cpu()

            # Print some metrics
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / batch_sz)
            print('KL Train Error', round(meta_cluster_error / batch_sz, 4), "\t", C_distribs[-1])

            # Average the accumulated gradients and optimize (after each batch we update the params)
            for p in all_parameters:
                # if p.grad!=None:
                p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()
            del (supp_xs, supp_ys, query_xs, query_ys)
            gc.collect()

        # test results on the validation data
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, head, trainer, batch_size, master_path, conf['test_state'],
                                                    adaptation_step=conf['inner'], num_epoch=iteration)
        # with tune.checkpoint_dir(epoch) as checkpoint_dir:
        #     path = os.path.join(checkpoint_dir, "checkpoint")
        #     torch.save((net.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)

    print("Finished Training")