- import torch
- import pickle
- from options import config
- import random
-
-
def cl_loss(c):
    """Regularization loss on the trainer's code/assignment scores `c`.

    Computes, elementwise over ``c.squeeze()``:

        a = 1 / (1 + exp(-alpha * (d * c - beta)))   # logistic in c
        b = a * (1 - a) * (1 - 2 * a)

    and returns ``sum(b)``.  ``alpha``, ``beta`` and ``d`` come from the
    global ``config`` dict.

    Args:
        c: tensor of scores; size-1 dims are squeezed before use.

    Returns:
        Scalar tensor: the sum of ``a * (1 - a) * (1 - 2 * a)``.
    """
    alpha = config['alpha']
    beta = config['beta']
    d = config['d']
    # Logistic squashing of the (scaled, shifted) scores.
    a = 1 / (1 + torch.exp(-1 * (alpha * (d * c.squeeze() - beta))))
    # Grouping matches the original nested-call order: a * ((1-a) * (1-2a)).
    b = a * ((1 - a) * (1 - 2 * a))
    return torch.sum(b)
-
-
def fast_adapt(
        head,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        trainer=None,
        test=False,
        iteration=None):
    """Adapt `head` on the support set, then score it on the query set.

    For ``adaptation_steps`` iterations the ``trainer`` (in training mode)
    emits modulation parameters ``(g1, b1, g2, b2)``, a code ``c`` and a
    k-means loss; ``head`` is updated in place via ``head.adapt`` on the
    support MSE plus the weighted k-means term.  Afterwards the trainer is
    rerun in inference mode and the head is evaluated on the query data.

    Args:
        head: adaptable model; called as ``head(x, g1, b1, g2, b2)`` and
            updated via ``head.adapt(loss)``.
        adaptation_data, adaptation_labels: support set.
        evaluation_data, evaluation_labels: query set.
        adaptation_steps: number of inner-loop adaptation steps.
        get_predictions: if True, return ``(predictions, c)`` instead of
            the loss tuple.
        trainer: module called as ``trainer(data, labels, training=bool)``
            returning ``(g1, b1, g2, b2, c, k_loss)``.
        test, iteration: accepted for caller compatibility; unused here.

    Returns:
        ``(predictions.detach().cpu(), c)`` when ``get_predictions`` is
        True, otherwise ``(total_loss, c, k_loss_value)`` where
        ``k_loss_value`` is a Python float.
    """
    mse = torch.nn.functional.mse_loss
    kmeans_weight = config['kmeans_loss_weight']

    # Inner loop: adapt the head on the support set.
    for _ in range(adaptation_steps):
        gamma1, beta1, gamma2, beta2, code, kmeans_loss = trainer(
            adaptation_data, adaptation_labels, training=True)
        support_out = head(adaptation_data, gamma1, beta1, gamma2, beta2)
        support_loss = mse(support_out.view(-1), adaptation_labels)
        support_loss = support_loss + kmeans_weight * kmeans_loss
        head.adapt(support_loss)

    # Re-run the trainer in inference mode (still on the support set) to
    # obtain the modulation parameters used for query evaluation.
    gamma1, beta1, gamma2, beta2, code, kmeans_loss = trainer(
        adaptation_data, adaptation_labels, training=False)
    predictions = head(evaluation_data, gamma1, beta1, gamma2, beta2)
    query_loss = mse(predictions.view(-1), evaluation_labels)
    total_loss = query_loss + kmeans_weight * kmeans_loss

    if get_predictions:
        return predictions.detach().cpu(), code
    return total_loss, code, kmeans_loss.detach().cpu().item()