Extend the MeLU code to run different meta-learning algorithms and hyperparameter settings.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hyper_testing.py 2.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import random
  2. from torch.nn import L1Loss
  3. import numpy as np
  4. from fast_adapt import fast_adapt
  5. from sklearn.metrics import ndcg_score
  6. def hyper_test(embedding, head, total_dataset, adaptation_step):
  7. test_set_size = len(total_dataset)
  8. random.shuffle(total_dataset)
  9. a, b, c, d = zip(*total_dataset)
  10. losses_q = []
  11. ndcgs11 = []
  12. ndcgs33 = []
  13. head.eval()
  14. for iterator in range(test_set_size):
  15. try:
  16. supp_xs = a[iterator].cuda()
  17. supp_ys = b[iterator].cuda()
  18. query_xs = c[iterator].cuda()
  19. query_ys = d[iterator].cuda()
  20. except IndexError:
  21. print("index error in test method")
  22. continue
  23. learner = head.clone()
  24. temp_sxs = embedding(supp_xs)
  25. temp_qxs = embedding(query_xs)
  26. evaluation_error, predictions = fast_adapt(learner,
  27. temp_sxs,
  28. temp_qxs,
  29. supp_ys,
  30. query_ys,
  31. adaptation_step,
  32. get_predictions=True)
  33. l1 = L1Loss(reduction='mean')
  34. loss_q = l1(predictions.view(-1), query_ys)
  35. losses_q.append(float(loss_q))
  36. predictions = predictions.view(-1)
  37. y_true = query_ys.cpu().detach().numpy()
  38. y_pred = predictions.cpu().detach().numpy()
  39. ndcgs11.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
  40. ndcgs33.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))
  41. del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred, loss_q
  42. # calculate metrics
  43. try:
  44. losses_q = np.array(losses_q).mean()
  45. except:
  46. losses_q = 100
  47. try:
  48. ndcg1 = np.array(ndcgs11).mean()
  49. ndcg3 = np.array(ndcgs33).mean()
  50. except:
  51. ndcg1 = 0
  52. ndcg3 = 0
  53. head.train()
  54. return losses_q, ndcg1, ndcg3