Extends the MeLU code to support different meta-learning algorithms and hyperparameter settings.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

learnToLearnTest.py 2.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import torch
  3. import pickle
  4. import random
  5. from options import config, states
  6. from torch.nn import functional as F
  7. from torch.nn import L1Loss
  8. # import matchzoo as mz
  9. import numpy as np
  10. from fast_adapt import fast_adapt
  11. from sklearn.metrics import ndcg_score
  12. import gc
  13. def test(embedding, head, total_dataset, batch_size, num_epoch, test_state=None,args=None):
  14. losses_q = []
  15. ndcgs1 = []
  16. ndcgs3 = []
  17. master_path = config['master_path']
  18. test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
  19. indexes = list(np.arange(test_set_size))
  20. random.shuffle(indexes)
  21. for iterator in indexes:
  22. a = pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
  23. b = pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
  24. c = pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
  25. d = pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
  26. try:
  27. supp_xs = a.cuda()
  28. supp_ys = b.cuda()
  29. query_xs = c.cuda()
  30. query_ys = d.cuda()
  31. except IndexError:
  32. print("index error in test method")
  33. continue
  34. learner = head.clone()
  35. temp_sxs = embedding(supp_xs)
  36. temp_qxs = embedding(query_xs)
  37. evaluation_error, predictions = fast_adapt(learner,
  38. temp_sxs,
  39. temp_qxs,
  40. supp_ys,
  41. query_ys,
  42. # config['inner'],
  43. args.inner_eval,
  44. get_predictions=True)
  45. l1 = L1Loss(reduction='mean')
  46. loss_q = l1(predictions.view(-1), query_ys)
  47. # print("testing - iterator:{} - l1:{} ".format(iterator,loss_q))
  48. losses_q.append(float(loss_q))
  49. predictions = predictions.view(-1)
  50. y_true = query_ys.cpu().detach().numpy()
  51. y_pred = predictions.cpu().detach().numpy()
  52. ndcgs1.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
  53. ndcgs3.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))
  54. del supp_xs, supp_ys, query_xs, query_ys, y_true, y_pred, loss_q, temp_sxs, temp_qxs, predictions, l1
  55. # torch.cuda.empty_cache()
  56. # calculate metrics
  57. losses_q = np.array(losses_q).mean()
  58. print("mean of mse: ", losses_q)
  59. # print("======================================")
  60. n1 = np.array(ndcgs1).mean()
  61. print("nDCG1: ", n1)
  62. n3 = np.array(ndcgs3).mean()
  63. print("nDCG3: ", n3)
  64. del a, b, c, d, total_dataset
  65. gc.collect()
  66. return losses_q, n1, n3