Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fast_adapt.py 2.0KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import torch
  2. import pickle
  3. from options import config
  4. import random
  5. def cl_loss(c):
  6. alpha = config['alpha']
  7. beta = config['beta']
  8. d = config['d']
  9. a = torch.div(1,
  10. torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
  11. # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
  12. b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
  13. # b = 1 * a * (1 - a) * (1 - 2 * a)
  14. loss = torch.sum(b)
  15. return loss
  16. def fast_adapt(
  17. head,
  18. adaptation_data,
  19. evaluation_data,
  20. adaptation_labels,
  21. evaluation_labels,
  22. adaptation_steps,
  23. get_predictions=False,
  24. trainer=None,
  25. test=False,
  26. iteration=None):
  27. for step in range(adaptation_steps):
  28. # g1,b1,g2,b2,c,k_loss,rec_loss = trainer(adaptation_data,adaptation_labels,training=True)
  29. g1, b1, g2, b2, c, k_loss = trainer(adaptation_data, adaptation_labels, training=True)
  30. temp = head(adaptation_data, g1, b1, g2, b2)
  31. train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  32. # train_error = train_error + config['kmeans_loss_weight'] * k_loss + config['rec_loss_weight']*rec_loss
  33. train_error = train_error + config['kmeans_loss_weight'] * k_loss
  34. head.adapt(train_error)
  35. # g1,b1,g2,b2,c,k_loss,rec_loss = trainer(adaptation_data,adaptation_labels,training=False)
  36. g1, b1, g2, b2, c, k_loss = trainer(adaptation_data, adaptation_labels, training=False)
  37. predictions = head(evaluation_data, g1, b1, g2, b2)
  38. valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
  39. # total_loss = valid_error + config['kmeans_loss_weight'] * k_loss + config['rec_loss_weight']*rec_loss
  40. total_loss = valid_error + config['kmeans_loss_weight'] * k_loss
  41. if get_predictions:
  42. return predictions.detach().cpu(), c
  43. # return total_loss,c,k_loss.detach().cpu().item(),rec_loss.detach().cpu().item()
  44. return total_loss, c, k_loss.detach().cpu().item()