Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fast_adapt.py 1.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import torch
  2. import pickle
  3. from options import config
  4. import random
  5. def cl_loss(c):
  6. alpha = config['alpha']
  7. beta = config['beta']
  8. d = config['d']
  9. a = torch.div(1, torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
  10. # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
  11. b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
  12. # b = 1 * a * (1 - a) * (1 - 2 * a)
  13. loss = torch.sum(b)
  14. return loss
  15. def fast_adapt(
  16. learn,
  17. adaptation_data,
  18. evaluation_data,
  19. adaptation_labels,
  20. evaluation_labels,
  21. adaptation_steps,
  22. get_predictions=False,
  23. epoch=None):
  24. for step in range(adaptation_steps):
  25. temp, c = learn(adaptation_data, adaptation_labels, training=True)
  26. train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  27. # cluster_loss = cl_loss(c)
  28. # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
  29. total_loss = train_error
  30. learn.adapt(total_loss)
  31. predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
  32. adaptation_labels=adaptation_labels)
  33. valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
  34. # cluster_loss = cl_loss(c)
  35. # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
  36. total_loss = valid_error
  37. # if random.random() < 0.05:
  38. # print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())
  39. if get_predictions:
  40. return total_loss, predictions
  41. return total_loss,c