|
- import torch
- import pickle
- from options import config
- import random
-
-
def cl_loss(c, alpha=None, beta=None, d=None):
    """Cluster regularization loss.

    Computes ``a = sigmoid(alpha * (d * c - beta))`` element-wise over the
    squeezed cluster scores ``c`` and returns ``sum(a * (1 - a) * (1 - 2a))``.

    Args:
        c: tensor of cluster assignment scores; squeezed before use.
        alpha: sigmoid sharpness; defaults to ``config['alpha']``.
        beta: sigmoid offset; defaults to ``config['beta']``.
        d: score scaling factor; defaults to ``config['d']``.

    Returns:
        Scalar tensor: the summed penalty.
    """
    # Fall back to the project-wide config for backward compatibility with
    # the original zero-argument call sites.
    if alpha is None:
        alpha = config['alpha']
    if beta is None:
        beta = config['beta']
    if d is None:
        d = config['d']
    # sigmoid(x) == 1 / (1 + exp(-x)); replaces the former
    # torch.div/add/exp/mul/sub chain with the equivalent built-in.
    a = torch.sigmoid(alpha * (d * c.squeeze() - beta))
    # a*(1-a)*(1-2a) vanishes at a in {0, 0.5, 1}.
    b = a * (1 - a) * (1 - 2 * a)
    return torch.sum(b)
-
-
- def fast_adapt(
- learn,
- adaptation_data,
- evaluation_data,
- adaptation_labels,
- evaluation_labels,
- adaptation_steps,
- get_predictions=False,
- epoch=None):
- for step in range(adaptation_steps):
- temp, c = learn(adaptation_data, adaptation_labels, training=True)
- train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
- # cluster_loss = cl_loss(c)
- # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
- total_loss = train_error
- learn.adapt(total_loss)
-
- predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
- adaptation_labels=adaptation_labels)
- valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
- # cluster_loss = cl_loss(c)
- # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
- total_loss = valid_error
-
- # if random.random() < 0.05:
- # print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())
-
- if get_predictions:
- return total_loss, predictions
- return total_loss,c
|