        value = torch.mm(C, self.array)
        # simple add operation (currently disabled)
        # new_task_embed = value + mean_task
        # new_task_embed = value
        new_task_embed = mean_task
        # print("here 1:", new_task_embed)
        # print("here 2:", self.array)
        ## First, compute the squared Euclidean distance from the task embedding to every cluster centre
        list_dist = []
        # list_dist = torch.norm(new_task_embed - self.array, p=2, dim=1, keepdim=True)
        list_dist = torch.sum(torch.pow(new_task_embed - self.array, 2), dim=1)
        stack_dist = list_dist
        # print("here 3:", stack_dist)

        ## Second, find the minimum squared distance for softmax normalization
        min_dist = torch.min(list_dist)
        # print("here 4:", min_dist)

        ## Third, compute exponentials shifted by min_dist to avoid underflow (0/0) issues in the softmax
        alpha = config['kmeans_alpha']  # scaling factor (inverse temperature) for the soft assignments
        list_exp = []
        for i in range(self.clusters_k):
            exp = torch.exp(-alpha * (stack_dist[i] - min_dist))
            list_exp.append(exp)
        stack_exp = torch.stack(list_exp)
        sum_exponentials = torch.sum(stack_exp)
        # print("here 5:", stack_exp, sum_exponentials)

        ## Fourth, compute the softmaxes and the embedding/representative distances weighted by the softmax
        list_softmax = []
        list_weighted_dist = []
        for j in range(self.clusters_k):
            softmax = stack_exp[j] / sum_exponentials
            weighted_dist = stack_dist[j] * softmax
            list_softmax.append(softmax)
            list_weighted_dist.append(weighted_dist)
        stack_weighted_dist = torch.stack(list_weighted_dist)

        kmeans_loss = torch.sum(stack_weighted_dist, dim=0)
        return C, new_task_embed, kmeans_loss
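# --- Hedged aside (not part of the original module): a vectorized sketch of the same
# soft-assignment k-means loss that the loops above compute. The helper name and argument
# names are illustrative; it assumes `centers` is the (K, D) matrix stored in self.array
# and `task_embed` is the (1, D) task embedding.
def soft_kmeans_loss_sketch(task_embed, centers, alpha):
    # squared Euclidean distance to every cluster centre, shape (K,)
    dist = torch.sum((task_embed - centers) ** 2, dim=1)
    # subtracting the minimum before exponentiating mirrors the min_dist shift above
    weights = torch.softmax(-alpha * (dist - dist.min()), dim=0)
    # distances weighted by their soft-assignment probabilities
    return torch.sum(weights * dist)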
class Trainer(torch.nn.Module):
    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
        if training:
            C, clustered_task_embed, k_loss = self.cluster_module(task_embed, y)
            # hidden layers
            # TODO: decide whether to add an activation function here or remove it
            hidden_1 = self.fc1(task_embed)
            # ... (FiLM modulation and the remaining hidden layers are elided in this excerpt) ...
            y_pred = self.linear_out(hidden_3)
        else:
            C, clustered_task_embed, k_loss = self.cluster_module(adaptation_data, adaptation_labels)
            beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
            gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
            beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
            # ... (remaining FiLM parameters and hidden layers are elided in this excerpt) ...
            y_pred = self.linear_out(hidden_3)
        return y_pred, C, k_loss
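# --- Hedged sketch (not from the original diff): the hidden_2 / hidden_3 computations are
# elided above, so the exact wiring is an assumption. As the film_layer_*_beta /
# film_layer_*_gamma names suggest, FiLM-style conditioning scales and shifts each hidden
# activation with the task-conditional gamma and beta, roughly:
#
#     hidden_1 = self.fc1(task_embed)
#     hidden_1 = gamma_1 * hidden_1 + beta_1        # FiLM: per-feature scale and shift
#     hidden_2 = self.fc2(hidden_1)                 # self.fc2 is hypothetical here
#     hidden_2 = gamma_2 * hidden_2 + beta_2
#     ...
#     y_pred = self.linear_out(hidden_3)
def film_modulate(hidden, gamma, beta):
    # Feature-wise linear modulation: per-feature scale and shift (illustrative helper only).
    return gamma * hidden + beta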
    alpha = config['alpha']
    beta = config['beta']
    d = config['d']
    a = torch.div(1,
                  torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
    # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
    b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
    # b = 1 * a * (1 - a) * (1 - 2 * a)
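    # Note (added observation, not from the original comments): `a` is the logistic sigmoid
    # sigma(u) with u = alpha * (d * c - beta). Since d(sigma)/du = sigma * (1 - sigma) and
    # d^2(sigma)/du^2 = sigma * (1 - sigma) * (1 - 2 * sigma), `b` is exactly the second
    # derivative of that sigmoid with respect to its pre-activation u.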
               adaptation_steps,
               get_predictions=False,
               epoch=None):
    is_print = random.random() < 0.05

    for step in range(adaptation_steps):
        temp, c, k_loss = learn(adaptation_data, adaptation_labels, training=True)
        train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
        # cluster_loss = cl_loss(c)
        # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
        total_loss = train_error + config['kmeans_loss_weight'] * k_loss
        learn.adapt(total_loss)

    if is_print:
        # print("in support:\t", round(k_loss.item(), 4))
        pass

    predictions, c, k_loss = learn(evaluation_data, None, training=False,
                                   adaptation_data=adaptation_data,
                                   adaptation_labels=adaptation_labels)
    valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
    # cluster_loss = cl_loss(c)
    # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
    total_loss = valid_error + config['kmeans_loss_weight'] * k_loss

    if is_print:
        # print("in query:\t", round(k_loss.item(), 4))
        print(c[0].detach().cpu().numpy(), "\t", round(k_loss.item(), 3), "\n")
    # if random.random() < 0.05:
    #     print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

    if get_predictions:
        return total_loss, predictions
    return total_loss, c, k_loss.item()
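# --- Hedged usage sketch (illustrative only; assumes a learn2learn-style MAML wrapper, and
# 'local_lr' plus the surrounding variable names are placeholders, not taken from this diff):
#
#     maml = l2l.algorithms.MAML(trainer, lr=config['local_lr'], first_order=True)
#     learner = maml.clone()                        # per-task copy whose adapt() runs the inner loop
#     loss, c, k_loss = fast_adapt(learner, support_emb, query_emb,
#                                  support_y, query_y, config['inner'])
#     loss.backward()                               # outer-loop (meta) gradient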
    for i in range(num_batch):
        meta_train_error = 0.0
        meta_cluster_error = 0.0
        optimizer.zero_grad()
        print("EPOCH: ", iteration, " BATCH: ", i)
        supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
        # ... (loading of the batch's tasks and the per-task loop are elided in this excerpt) ...
            temp_sxs = emb(supp_xs[task])
            temp_qxs = emb(query_xs[task])
            evaluation_error, c, k_loss = fast_adapt(learner,
                                                     temp_sxs,
                                                     temp_qxs,
                                                     supp_ys[task],
                                                     query_ys[task],
                                                     config['inner'],
                                                     epoch=iteration)
            # C_distribs.append(c)
            evaluation_error.backward(retain_graph=True)
            meta_train_error += evaluation_error.item()
            meta_cluster_error += k_loss
            # supp_xs[task].cpu()
            # query_xs[task].cpu()

        # Print some metrics
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / batch_sz)
        print('K-means Train Error', meta_cluster_error / batch_sz)
        # clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs)
        # clustering_loss.backward()
        # print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy())
        # if i != (num_batch - 1):
        #     C_distribs = []
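        # --- Hedged aside: the real kl_loss is defined elsewhere and is not shown in this
        # diff. One common form for such a clustering regulariser is the KL divergence between
        # the batch-averaged cluster-assignment distribution and a uniform prior, which
        # discourages all tasks from collapsing into one cluster. A sketch (hypothetical):
        #
        #     def kl_loss(C_distribs, eps=1e-8):
        #         C = torch.stack([c.squeeze() for c in C_distribs])   # (num_tasks, K)
        #         avg = C.mean(dim=0)                                  # (K,)
        #         prior = torch.full_like(avg, 1.0 / avg.numel())      # uniform prior
        #         return torch.sum(avg * torch.log((avg + eps) / prior))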
        gc.collect()
        print("===============================================\n")
    # if iteration % 2 == 0 and iteration != 0:
    #     # testing
    #     print("start of test phase")
    #     trainer.eval()
    #
    #     with open("results2.txt", "a") as f:
    #         f.write("epoch:{}\n".format(iteration))
    #
    #     for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
    #         test_dataset = None
    #         test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
    #         supp_xs_s = []
    #         supp_ys_s = []
    #         query_xs_s = []
    #         query_ys_s = []
    #         gc.collect()
    #
    #         print("===================== " + test_state + " =====================")
    #         mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
    #                                num_epoch=config['num_epoch'], test_state=test_state, args=args)
    #         with open("results2.txt", "a") as f:
    #             f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
    #         print("===================================================")
    #         del test_dataset
    #         gc.collect()
    #
    #     trainer.train()
    #     with open("results2.txt", "a") as f:
    #         f.write("\n")
    #     print("\n\n\n")
    # save model
    # final_model = torch.nn.Sequential(emb, head)
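    # --- Hedged sketch of the saving step (file names and paths here are hypothetical,
    # not taken from the original code):
    #
    #     torch.save(trainer.state_dict(), "{}/trainer.pt".format(master_path))
    #     torch.save(emb.state_dict(), "{}/embedding.pt".format(master_path))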