Extends the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fast_adapt.py 1.6KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import torch
  2. import pickle
  3. from options import config
  4. import random
  5. def cl_loss(c):
  6. alpha = config['alpha']
  7. beta = config['beta']
  8. d = config['d']
  9. a = torch.div(1, torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c), beta))))))
  10. # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
  11. b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
  12. # b = 1 * a * (1 - a) * (1 - 2 * a)
  13. loss = torch.sum(b)
  14. return loss
  15. def fast_adapt(
  16. learn,
  17. adaptation_data,
  18. evaluation_data,
  19. adaptation_labels,
  20. evaluation_labels,
  21. adaptation_steps,
  22. get_predictions=False,
  23. epoch=None):
  24. for step in range(adaptation_steps):
  25. temp, c = learn(adaptation_data, adaptation_labels, training=True)
  26. train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  27. cluster_loss = cl_loss(c)
  28. total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
  29. learn.adapt(total_loss)
  30. predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
  31. adaptation_labels=adaptation_labels)
  32. valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
  33. cluster_loss = cl_loss(c)
  34. total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
  35. if random.random() < 0.05:
  36. print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())
  37. if get_predictions:
  38. return total_loss, predictions
  39. return total_loss