Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fast_adapt.py 3.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # import torch
  2. # import pickle
  3. # from options import config
  4. # import random
  5. #
  6. #
  7. # def cl_loss(c):
  8. # alpha = config['alpha']
  9. # beta = config['beta']
  10. # d = config['d']
  11. # a = torch.div(1,
  12. # torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
  13. # # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
  14. # b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
  15. # # b = 1 * a * (1 - a) * (1 - 2 * a)
  16. # loss = torch.sum(b)
  17. # return loss
  18. #
  19. #
  20. # def fast_adapt(
  21. # learn,
  22. # adaptation_data,
  23. # evaluation_data,
  24. # adaptation_labels,
  25. # evaluation_labels,
  26. # adaptation_steps,
  27. # get_predictions=False,
  28. # epoch=None):
  29. # is_print = random.random() < 0.05
  30. #
  31. # for step in range(adaptation_steps):
  32. # temp, c, k_loss = learn(adaptation_data, adaptation_labels, training=True)
  33. # train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  34. # # cluster_loss = cl_loss(c)
  35. # # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
  36. # total_loss = train_error + config['kmeans_loss_weight'] * k_loss
  37. # learn.adapt(total_loss)
  38. # if is_print:
  39. # # print("in support:\t", round(k_loss.item(),4))
  40. # pass
  41. #
  42. # predictions, c, k_loss = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
  43. # adaptation_labels=adaptation_labels)
  44. # valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
  45. # # cluster_loss = cl_loss(c)
  46. # # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
  47. # total_loss = valid_error + config['kmeans_loss_weight'] * k_loss
  48. #
  49. # if is_print:
  50. #
  51. # # print("in query:\t", round(k_loss.item(),4))
  52. # print(c[0].detach().cpu().numpy(),"\t",round(k_loss.item(),3),"\n")
  53. #
  54. # # if random.random() < 0.05:
  55. # # print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())
  56. #
  57. # if get_predictions:
  58. # return total_loss, predictions
  59. # return total_loss, c, k_loss.item()
  60. import torch
  61. import pickle
  62. def fast_adapt(
  63. learn,
  64. adaptation_data,
  65. evaluation_data,
  66. adaptation_labels,
  67. evaluation_labels,
  68. adaptation_steps,
  69. get_predictions=False):
  70. for step in range(adaptation_steps):
  71. temp = learn(adaptation_data, adaptation_labels, training=True)
  72. train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
  73. learn.adapt(train_error)
  74. predictions = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
  75. adaptation_labels=adaptation_labels)
  76. valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
  77. if get_predictions:
  78. return valid_error, predictions
  79. return valid_error