import pickle
import random

import torch

from options import config


def cl_loss(c):
    """Compute the ``cl`` regularisation loss from the score tensor ``c``.

    With ``a = sigmoid(alpha * (d * c - beta))`` this returns
    ``sum(a * (1 - a) * (1 - 2 * a))``, i.e. the sum of the sigmoid's second
    derivative evaluated at each scaled score. ``alpha``, ``beta`` and ``d``
    are read from ``config``.

    Args:
        c: tensor of scores; size-1 dimensions are squeezed away before use.

    Returns:
        A 0-dim tensor holding the summed loss.
    """
    alpha = config['alpha']
    beta = config['beta']
    d = config['d']
    # torch.sigmoid equals 1 / (1 + exp(-x)) but is stable in backward: the
    # hand-written form overflows exp(-x) to inf for very negative x, which
    # makes the autograd gradient 0 * inf = NaN.
    a = torch.sigmoid(alpha * (d * c.squeeze() - beta))
    b = a * (1 - a) * (1 - 2 * a)
    return torch.sum(b)


def fast_adapt(
        head,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        trainer=None,
        test=False,
        iteration=None):
    """Adapt ``head`` on the adaptation set, then evaluate on the evaluation set.

    On each of ``adaptation_steps`` iterations, ``trainer`` (called with
    ``training=True``) produces the parameters ``g1, b1, g2, b2`` that are fed
    to ``head``. The MSE on the adaptation labels, plus
    ``config['kmeans_loss_weight'] * k_loss``, is passed to ``head.adapt``.
    ``trainer`` is then called once more with ``training=False`` to score
    ``evaluation_data``.

    Args:
        head: model exposing ``head(x, g1, b1, g2, b2)`` and ``adapt(loss)``.
        adaptation_data, adaptation_labels: data used for the inner updates.
        evaluation_data, evaluation_labels: data used for the returned loss.
        adaptation_steps: number of inner ``head.adapt`` calls (may be 0).
        get_predictions: if true, return predictions instead of the loss.
        trainer: callable returning ``(g1, b1, g2, b2, c, k_loss)``; required.
        test, iteration: not used by this function; kept for callers.

    Returns:
        ``(predictions.detach().cpu(), c)`` when ``get_predictions`` is true,
        otherwise ``(total_loss, c, k_loss_value)``, where ``k_loss_value``
        is a Python number.

    Raises:
        TypeError: if ``trainer`` is None.
    """
    if trainer is None:
        # The original failed later with "'NoneType' object is not callable";
        # fail up front, before ``head`` is mutated, with the same type.
        raise TypeError("fast_adapt requires a `trainer` callable")

    for _ in range(adaptation_steps):
        g1, b1, g2, b2, c, k_loss = trainer(
            adaptation_data, adaptation_labels, training=True)
        adapt_pred = head(adaptation_data, g1, b1, g2, b2)
        # assumes labels are already 1-D to match the flattened output — TODO confirm
        train_error = torch.nn.functional.mse_loss(
            adapt_pred.view(-1), adaptation_labels)
        train_error = train_error + config['kmeans_loss_weight'] * k_loss
        head.adapt(train_error)

    # NOTE: an earlier variant also returned a reconstruction loss weighted by
    # config['rec_loss_weight']; the current trainer returns six values.
    g1, b1, g2, b2, c, k_loss = trainer(
        adaptation_data, adaptation_labels, training=False)
    predictions = head(evaluation_data, g1, b1, g2, b2)
    valid_error = torch.nn.functional.mse_loss(
        predictions.view(-1), evaluation_labels)
    total_loss = valid_error + config['kmeans_loss_weight'] * k_loss

    if get_predictions:
        return predictions.detach().cpu(), c
    return total_loss, c, k_loss.item()