import pickle
import random

import torch

from options import config


def cl_loss(c):
    # Clustering regularizer on the soft assignment scores `c`.
    # Currently unused in fast_adapt (see the commented-out cluster_loss lines below).
    alpha = config['alpha']
    beta = config['beta']
    d = config['d']
    # a = 1 / (1 + exp(-alpha * (d * c - beta)))
    a = 1 / (1 + torch.exp(-alpha * (d * c.squeeze() - beta)))
    # b = a * (1 - a) * (1 - 2 * a)
    b = a * (1 - a) * (1 - 2 * a)
    loss = torch.sum(b)
    return loss


def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        epoch=None):
    # Occasionally log the query-set cluster assignments and k-means loss.
    is_print = random.random() < 0.05

    # Inner loop: adapt the cloned learner on the support (adaptation) set.
    for step in range(adaptation_steps):
        temp, c, k_loss = learn(adaptation_data, adaptation_labels, training=True)
        train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
        # cluster_loss = cl_loss(c)
        # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
        total_loss = train_error + config['kmeans_loss_weight'] * k_loss
        learn.adapt(total_loss)
        # if is_print:
        #     print("in support:\t", round(k_loss.item(), 4))

    # Evaluate the adapted learner on the query (evaluation) set.
    predictions, c, k_loss = learn(evaluation_data, None, training=False,
                                   adaptation_data=adaptation_data,
                                   adaptation_labels=adaptation_labels)
    valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
    # cluster_loss = cl_loss(c)
    # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
    total_loss = valid_error + config['kmeans_loss_weight'] * k_loss

    if is_print:
        # print("in query:\t", round(k_loss.item(), 4))
        print(c[0].detach().cpu().numpy(), "\t", round(k_loss.item(), 3), "\n")
    # if random.random() < 0.05:
    #     print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

    if get_predictions:
        return total_loss, predictions
    return total_loss, c, k_loss.item()