import torch
import pickle
from options import config
import random


def cl_loss(c):
    """Return the clustering regulariser for cluster scores ``c``.

    Each element of ``c`` is squashed with a scaled and shifted logistic::

        a = 1 / (1 + exp(-alpha * (d * c - beta)))

    and the penalty ``a * (1 - a) * (1 - 2 * a)`` is summed over all
    elements. ``alpha``, ``beta`` and ``d`` are read from ``config``.

    Args:
        c: tensor of cluster scores (shape as produced by the model; the
            sum runs over every element).

    Returns:
        A 0-dim tensor holding the summed penalty.
    """
    alpha = config['alpha']
    beta = config['beta']
    d = config['d']
    # torch.sigmoid is exactly 1 / (1 + exp(-x)) but is evaluated stably.
    # The hand-written form overflows exp() for large negative inputs, and
    # autograd then produces NaN gradients (0 * inf) for those elements.
    a = torch.sigmoid(alpha * (d * c - beta))
    b = a * (1 - a) * (1 - 2 * a)
    return torch.sum(b)


def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        epoch=None):
    """Adapt ``learn`` on the adaptation set, then evaluate on the query set.

    For ``adaptation_steps`` iterations the model is run on the adaptation
    data with ``training=True``. The loss is MSE plus
    ``config['cluster_loss_weight'] * cl_loss(c)``, and it is passed to
    ``learn.adapt``. The model is then run on ``evaluation_data`` with
    ``training=False`` (the adaptation data and labels are also forwarded),
    and the same combined loss is computed on the evaluation labels.

    NOTE(review): ``learn`` looks like a learn2learn-style cloned learner
    whose ``adapt`` performs an inner-loop update. It must return a
    ``(predictions, c)`` pair. TODO confirm.

    Args:
        learn: callable model with an ``adapt(loss)`` method.
        adaptation_data: inputs used for the inner-loop updates.
        evaluation_data: inputs used for the final evaluation.
        adaptation_labels: targets for ``adaptation_data`` (compared against
            the flattened predictions).
        evaluation_labels: targets for ``evaluation_data`` (compared against
            the flattened predictions).
        adaptation_steps: number of inner-loop ``adapt`` calls.
        get_predictions: if true, also return the evaluation predictions.
        epoch: unused; kept for interface compatibility with callers.

    Returns:
        The evaluation loss tensor, or ``(loss, predictions)`` when
        ``get_predictions`` is true.
    """
    weight = config['cluster_loss_weight']

    for _ in range(adaptation_steps):
        train_pred, c = learn(adaptation_data, adaptation_labels, training=True)
        train_error = torch.nn.functional.mse_loss(train_pred.view(-1), adaptation_labels)
        learn.adapt(train_error + weight * cl_loss(c))

    predictions, c = learn(
        evaluation_data,
        None,
        training=False,
        adaptation_data=adaptation_data,
        adaptation_labels=adaptation_labels,
    )
    valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
    cluster_loss = cl_loss(c)
    total_loss = valid_error + weight * cluster_loss

    # Sampled (~5% of calls) diagnostic output. The loss is rounded to 4
    # decimals because the per-element penalty is small, so rounding to an
    # integer would almost always print 0. detach() replaces the deprecated
    # .data access.
    if random.random() < 0.05:
        print("cl:", round(cluster_loss.item(), 4), "\t c:", c[0].detach().cpu().numpy())

    if get_predictions:
        return total_loss, predictions
    return total_loss