1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import torch
- import pickle
- from options import config
- import random
-
-
def cl_loss(c, alpha=None, beta=None, d=None):
    """Return a summed sigmoid-shaped penalty over the entries of ``c``.

    With ``a = sigmoid(alpha * (d * c - beta))`` the per-element term is
    ``a * (1 - a) * (1 - 2 * a)``; the result is the sum of that term over
    every element of ``c.squeeze()``.

    Args:
        c: tensor of scores; singleton dimensions are squeezed away first.
        alpha: slope of the sigmoid. Defaults to ``config['alpha']``.
        beta: offset subtracted before the slope. Defaults to ``config['beta']``.
        d: scale applied to ``c``. Defaults to ``config['d']``.

    Returns:
        A 0-dim tensor holding the summed penalty.
    """
    # Fall back to the project config so existing ``cl_loss(c)`` calls behave
    # exactly as before; explicit values allow use without that config.
    alpha = config['alpha'] if alpha is None else alpha
    beta = config['beta'] if beta is None else beta
    d = config['d'] if d is None else d
    # torch.sigmoid matches 1 / (1 + exp(-x)) in value, but its backward is
    # computed from the output, so a large |x| no longer overflows exp() and
    # turns the gradient into inf * 0 = NaN.
    a = torch.sigmoid(alpha * (d * c.squeeze() - beta))
    return torch.sum(a * (1 - a) * (1 - 2 * a))
-
-
def _regularised_mse(predictions, labels, k_loss):
    """Return MSE(predictions, labels) plus the config-weighted ``k_loss`` term.

    Both tensors are flattened before comparison so that an ``(N, 1)`` label
    tensor is matched element-wise with ``(N,)`` predictions instead of
    silently broadcasting to an ``(N, N)`` loss. ``reshape`` is used rather
    than ``view`` so non-contiguous outputs are also accepted.
    (The sigmoid-based ``cl_loss`` regulariser was an alternative to
    ``k_loss`` here and is currently not used.)
    """
    error = torch.nn.functional.mse_loss(predictions.reshape(-1), labels.reshape(-1))
    return error + config['kmeans_loss_weight'] * k_loss


def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        epoch=None):
    """Adapt ``learn`` on the support set, then score it on the query set.

    ``learn`` is called as ``learn(data, labels, training=..., ...)`` and
    must return a 3-tuple ``(predictions, c, k_loss)``; it must also expose
    ``adapt(loss)``. (Presumably a cloned meta-learner, learn2learn-style —
    TODO confirm against the caller.)

    Args:
        learn: model to adapt in place via ``learn.adapt``.
        adaptation_data, adaptation_labels: support-set inputs and targets.
        evaluation_data, evaluation_labels: query-set inputs and targets.
        adaptation_steps: number of inner-loop ``adapt`` calls (0 skips them).
        get_predictions: if true, also return the query predictions.
        epoch: currently unused; kept for caller compatibility.

    Returns:
        ``(total_loss, predictions)`` when ``get_predictions`` is true,
        otherwise ``(total_loss, c, k_loss_value)`` where ``k_loss_value`` is
        the query ``k_loss`` as a Python number (via ``.item()``).
        ``total_loss`` is the query MSE plus
        ``config['kmeans_loss_weight'] * k_loss``.
    """
    # One draw from the global RNG per call (same as before), so roughly 5%
    # of calls emit a debug line after evaluation.
    is_print = random.random() < 0.05

    for _ in range(adaptation_steps):
        support_pred, _c, support_k_loss = learn(adaptation_data, adaptation_labels, training=True)
        learn.adapt(_regularised_mse(support_pred, adaptation_labels, support_k_loss))

    # The query pass also receives the support set, so the model can condition on it.
    predictions, c, k_loss = learn(evaluation_data, None, training=False,
                                   adaptation_data=adaptation_data,
                                   adaptation_labels=adaptation_labels)
    total_loss = _regularised_mse(predictions, evaluation_labels, k_loss)

    if is_print:
        print(c[0].detach().cpu().numpy(), "\t", round(k_loss.item(), 3), "\n")

    if get_predictions:
        return total_loss, predictions
    return total_loss, c, k_loss.item()
|