Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

hyper_testing.py 2.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import random
  2. from torch.nn import L1Loss
  3. import numpy as np
  4. from fast_adapt import fast_adapt
  5. from sklearn.metrics import ndcg_score
  6. import gc
  7. import pickle
  8. import os
  9. def hyper_test(embedding, head, trainer, batch_size, master_path, test_state, adaptation_step, num_epoch=None):
  10. test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
  11. indexes = list(np.arange(test_set_size))
  12. random.shuffle(indexes)
  13. # test_set_size = len(total_dataset)
  14. # random.shuffle(total_dataset)
  15. # a, b, c, d = zip(*total_dataset)
  16. # a, b, c, d = list(a), list(b), list(c), list(d)
  17. losses_q = []
  18. ndcgs11 = []
  19. ndcgs33 = []
  20. head.eval()
  21. trainer.eval()
  22. for iterator in range(test_set_size):
  23. a = pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
  24. b = pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
  25. c = pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
  26. d = pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
  27. try:
  28. supp_xs = a.cuda()
  29. supp_ys = b.cuda()
  30. query_xs = c.cuda()
  31. query_ys = d.cuda()
  32. except IndexError:
  33. print("index error in test method")
  34. continue
  35. learner = head.clone()
  36. temp_sxs = embedding(supp_xs)
  37. temp_qxs = embedding(query_xs)
  38. predictions, c = fast_adapt(
  39. learner,
  40. temp_sxs,
  41. temp_qxs,
  42. supp_ys,
  43. query_ys,
  44. adaptation_step,
  45. get_predictions=True,
  46. trainer=trainer,
  47. test=True,
  48. iteration=num_epoch
  49. )
  50. l1 = L1Loss(reduction='mean')
  51. loss_q = l1(predictions.view(-1), query_ys.cpu())
  52. losses_q.append(float(loss_q))
  53. predictions = predictions.view(-1)
  54. y_true = query_ys.cpu().detach().numpy()
  55. y_pred = predictions.cpu().detach().numpy()
  56. ndcgs11.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
  57. ndcgs33.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))
  58. del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred, loss_q
  59. # calculate metrics
  60. try:
  61. losses_q = np.array(losses_q).mean()
  62. except:
  63. losses_q = 100
  64. try:
  65. ndcg1 = np.array(ndcgs11).mean()
  66. ndcg3 = np.array(ndcgs33).mean()
  67. except:
  68. ndcg1 = 0
  69. ndcg3 = 0
  70. head.train()
  71. trainer.train()
  72. gc.collect()
  73. return losses_q, ndcg1, ndcg3