Extend the MeLU code to support different meta-learning algorithms and hyperparameters.

fast_adapt.py 2.1KB

import torch
import pickle
from options import config
import random


def cl_loss(c):
    """Cluster regularization term on the soft cluster scores c (only used by the
    commented-out cluster_loss variant in fast_adapt below)."""
    alpha = config['alpha']
    beta = config['beta']
    d = config['d']
    a = torch.div(1,
                  torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
    # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
    b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
    # b = 1 * a * (1 - a) * (1 - 2 * a)
    loss = torch.sum(b)
    return loss

def fast_adapt(
        learn,
        adaptation_data,
        evaluation_data,
        adaptation_labels,
        evaluation_labels,
        adaptation_steps,
        get_predictions=False,
        epoch=None):
    """Run `adaptation_steps` inner-loop updates of `learn` on the support set,
    then evaluate on the query set and return the meta-loss (MSE plus the
    weighted k-means loss returned by the model)."""
    is_print = random.random() < 0.05

    # Inner loop: adapt on the support (adaptation) set.
    for step in range(adaptation_steps):
        temp, c, k_loss = learn(adaptation_data, adaptation_labels, training=True)
        train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
        # cluster_loss = cl_loss(c)
        # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
        total_loss = train_error + config['kmeans_loss_weight'] * k_loss
        learn.adapt(total_loss)
        if is_print:
            # print("in support:\t", round(k_loss.item(), 4))
            pass

    # Outer step: evaluate the adapted model on the query (evaluation) set.
    predictions, c, k_loss = learn(evaluation_data, None, training=False,
                                   adaptation_data=adaptation_data,
                                   adaptation_labels=adaptation_labels)
    valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
    # cluster_loss = cl_loss(c)
    # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
    total_loss = valid_error + config['kmeans_loss_weight'] * k_loss
    if is_print:
        # print("in query:\t", round(k_loss.item(), 4))
        print(c[0].detach().cpu().numpy(), "\t", round(k_loss.item(), 3), "\n")
    # if random.random() < 0.05:
    #     print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

    if get_predictions:
        return total_loss, predictions
    return total_loss, c, k_loss.item()
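
fast_adapt.py pulls its hyperparameters from the config dictionary in options.py, which is not shown here. The sketch below lists only the keys this file actually references (alpha, beta, d for cl_loss, plus the two loss weights); the values are illustrative placeholders, not the repository's defaults.

    # options.py (sketch): only the keys referenced by fast_adapt.py,
    # with placeholder values -- not the repository's actual defaults.
    config = {
        'alpha': 1.0,                 # sigmoid slope in cl_loss
        'beta': 0.5,                  # sigmoid offset in cl_loss
        'd': 1.0,                     # scaling of the cluster scores c in cl_loss
        'cluster_loss_weight': 0.1,   # weight of cl_loss (commented-out variant only)
        'kmeans_loss_weight': 0.1,    # weight of the k-means loss returned by the model
    }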
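The learn.adapt(...) call suggests that learn is a learn2learn-style MAML learner obtained from maml.clone(); the repository's training script is not shown, so the outer loop below is only a sketch under that assumption. ToyEstimator, the random support/query tensors, and all numeric values are hypothetical stand-ins for the MeLU model and the real task sampler.

    import torch
    import learn2learn as l2l
    from fast_adapt import fast_adapt

    # Hypothetical stand-in for the MeLU estimator: forward() mirrors the signature
    # fast_adapt expects and returns (predictions, cluster scores c, k-means loss).
    class ToyEstimator(torch.nn.Module):
        def __init__(self, in_dim=8):
            super().__init__()
            self.net = torch.nn.Linear(in_dim, 1)

        def forward(self, x, y=None, training=True,
                    adaptation_data=None, adaptation_labels=None):
            preds = self.net(x)
            c = torch.zeros(x.size(0), 1)   # placeholder cluster scores
            k_loss = torch.tensor(0.0)      # placeholder k-means loss
            return preds, c, k_loss


    maml = l2l.algorithms.MAML(ToyEstimator(), lr=5e-3, first_order=True)
    meta_opt = torch.optim.Adam(maml.parameters(), lr=1e-3)

    for iteration in range(10):             # outer (meta) loop
        meta_opt.zero_grad()
        for _ in range(4):                  # tasks per meta-batch (random toy tasks)
            x_sup, y_sup = torch.randn(16, 8), torch.randn(16)
            x_qry, y_qry = torch.randn(16, 8), torch.randn(16)
            learner = maml.clone()          # task-specific copy; adapt() updates it in place
            loss, c, k_loss = fast_adapt(learner, x_sup, x_qry, y_sup, y_qry,
                                         adaptation_steps=1)
            loss.backward()                 # meta-gradients accumulate in maml.parameters()
        meta_opt.step()

Each task adapts a fresh clone on its support set inside fast_adapt, and the query-set loss returned from it drives the shared meta-parameters through the Adam step.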