Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

learnToLearnTest.py 3.0KB

  1. import os
  2. import torch
  3. import pickle
  4. import random
  5. from options import config, states
  6. from torch.nn import functional as F
  7. from torch.nn import L1Loss
  8. # import matchzoo as mz
  9. import numpy as np
  10. from fast_adapt import fast_adapt
  11. from sklearn.metrics import ndcg_score
  12. import gc
  13. def test(embedding, head, total_dataset, batch_size, num_epoch, test_state=None,args=None):
  14. losses_q = []
  15. ndcgs1 = []
  16. ndcgs3 = []
  17. master_path = config['master_path']
  18. test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
  19. indexes = list(np.arange(test_set_size))
  20. random.shuffle(indexes)
  21. for iterator in indexes:
  22. a = pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
  23. b = pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
  24. c = pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
  25. d = pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
  26. try:
  27. supp_xs = a.cuda()
  28. supp_ys = b.cuda()
  29. query_xs = c.cuda()
  30. query_ys = d.cuda()
  31. except IndexError:
  32. print("index error in test method")
  33. continue
  34. learner = head.clone()
  35. temp_sxs = embedding(supp_xs)
  36. temp_qxs = embedding(query_xs)
  37. evaluation_error, predictions = fast_adapt(learner,
  38. temp_sxs,
  39. temp_qxs,
  40. supp_ys,
  41. query_ys,
  42. config['inner'],
  43. # args.inner_eval,
  44. get_predictions=True,
  45. )
  46. l1 = L1Loss(reduction='mean')
  47. loss_q = l1(predictions.view(-1), query_ys)
  48. # print("testing - iterator:{} - l1:{} ".format(iterator,loss_q))
  49. losses_q.append(float(loss_q))
  50. predictions = predictions.view(-1)
  51. y_true = query_ys.cpu().detach().numpy()
  52. y_pred = predictions.cpu().detach().numpy()
  53. ndcgs1.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
  54. ndcgs3.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))
  55. del supp_xs, supp_ys, query_xs, query_ys, y_true, y_pred, loss_q, temp_sxs, temp_qxs, predictions, l1
  56. torch.cuda.empty_cache()
  57. # calculate metrics
  58. losses_q = np.array(losses_q).mean()
  59. print("mean of mse: ", losses_q)
  60. # print("======================================")
  61. n1 = np.array(ndcgs1).mean()
  62. print("nDCG1: ", n1)
  63. n3 = np.array(ndcgs3).mean()
  64. print("nDCG3: ", n3)
  65. del a, b, c, d, total_dataset
  66. gc.collect()
  67. return losses_q, n1, n3