|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # import torch
- # import pickle
- # from options import config
- # import random
- #
- #
- # def cl_loss(c):
- # alpha = config['alpha']
- # beta = config['beta']
- # d = config['d']
- # a = torch.div(1,
- # torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
- # # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
- # b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
- # # b = 1 * a * (1 - a) * (1 - 2 * a)
- # loss = torch.sum(b)
- # return loss
- #
- #
- # def fast_adapt(
- # learn,
- # adaptation_data,
- # evaluation_data,
- # adaptation_labels,
- # evaluation_labels,
- # adaptation_steps,
- # get_predictions=False,
- # epoch=None):
- # is_print = random.random() < 0.05
- #
- # for step in range(adaptation_steps):
- # temp, c, k_loss = learn(adaptation_data, adaptation_labels, training=True)
- # train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
- # # cluster_loss = cl_loss(c)
- # # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
- # total_loss = train_error + config['kmeans_loss_weight'] * k_loss
- # learn.adapt(total_loss)
- # if is_print:
- # # print("in support:\t", round(k_loss.item(),4))
- # pass
- #
- # predictions, c, k_loss = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
- # adaptation_labels=adaptation_labels)
- # valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
- # # cluster_loss = cl_loss(c)
- # # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
- # total_loss = valid_error + config['kmeans_loss_weight'] * k_loss
- #
- # if is_print:
- #
- # # print("in query:\t", round(k_loss.item(),4))
- # print(c[0].detach().cpu().numpy(),"\t",round(k_loss.item(),3),"\n")
- #
- # # if random.random() < 0.05:
- # # print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())
- #
- # if get_predictions:
- # return total_loss, predictions
- # return total_loss, c, k_loss.item()
-
- import torch
- import pickle
-
-
def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False):
    """Adapt a learner on the adaptation set, then score it on the evaluation set.

    Runs ``adaptation_steps`` inner-loop updates. Each step calls ``learn`` on
    ``adaptation_data`` in training mode, computes the mean-squared error
    against ``adaptation_labels`` and passes that loss to ``learn.adapt``.
    The adapted learner is then called on ``evaluation_data`` (with the
    adaptation data/labels forwarded as keyword arguments). The MSE against
    ``evaluation_labels`` is returned.

    Args:
        learn: Callable learner that also exposes ``adapt(loss)``. It is
            called as ``learn(x, y, training=...)`` and, for the evaluation
            pass, additionally with ``adaptation_data=`` /
            ``adaptation_labels=`` keywords. Presumably a learn2learn-style
            cloned module; TODO confirm against callers.
        adaptation_data: Inputs used for the inner-loop updates.
        evaluation_data: Inputs the adapted learner is scored on.
        adaptation_labels: Targets for ``adaptation_data``.
        evaluation_labels: Targets for ``evaluation_data``.
        adaptation_steps: Number of inner-loop updates; 0 performs none.
        get_predictions: If True, also return the raw evaluation predictions.

    Returns:
        The evaluation MSE loss tensor, or ``(loss, predictions)`` when
        ``get_predictions`` is True. ``predictions`` is returned exactly as
        produced by ``learn`` (not flattened).
    """
    mse_loss = torch.nn.functional.mse_loss

    for _ in range(adaptation_steps):
        adapt_preds = learn(adaptation_data, adaptation_labels, training=True)
        # Flatten BOTH sides: comparing (N,) predictions with (N, 1) targets
        # would silently broadcast to an (N, N) difference and give a wrong loss.
        train_error = mse_loss(adapt_preds.reshape(-1), adaptation_labels.reshape(-1))
        learn.adapt(train_error)

    predictions = learn(evaluation_data, None, training=False,
                        adaptation_data=adaptation_data,
                        adaptation_labels=adaptation_labels)
    valid_error = mse_loss(predictions.reshape(-1), evaluation_labels.reshape(-1))

    if get_predictions:
        return valid_error, predictions
    return valid_error
|