Extend the MeLU code to run different meta-learning algorithms and hyperparameter settings

learnToLearnTest.py 3.0KB

import os
import gc
import pickle
import random

import numpy as np
import torch
from torch.nn import L1Loss
from sklearn.metrics import ndcg_score

from options import config, states
from fast_adapt import fast_adapt


def test(embedding, head, total_dataset, batch_size, num_epoch, test_state=None, args=None):
    losses_q = []
    ndcgs1 = []
    ndcgs3 = []
    master_path = config['master_path']
    # Each test task is stored as four pickles: supp_x, supp_y, query_x, query_y.
    test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
    indexes = list(np.arange(test_set_size))
    random.shuffle(indexes)
    for iterator in indexes:
        a = pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
        b = pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
        c = pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, iterator), "rb"))
        d = pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, iterator), "rb"))
        try:
            supp_xs = a.cuda()
            supp_ys = b.cuda()
            query_xs = c.cuda()
            query_ys = d.cuda()
        except IndexError:
            print("index error in test method")
            continue
        # Clone the meta-trained head so per-task adaptation does not touch the shared parameters.
        learner = head.clone()
        temp_sxs = embedding(supp_xs)
        temp_qxs = embedding(query_xs)
        evaluation_error, predictions = fast_adapt(learner,
                                                   temp_sxs,
                                                   temp_qxs,
                                                   supp_ys,
                                                   query_ys,
                                                   config['inner'],  # or args.inner_eval
                                                   get_predictions=True)
        # Query-set L1 (MAE) loss of the adapted learner.
        l1 = L1Loss(reduction='mean')
        loss_q = l1(predictions.view(-1), query_ys)
        losses_q.append(float(loss_q))
        predictions = predictions.view(-1)
        y_true = query_ys.cpu().detach().numpy()
        y_pred = predictions.cpu().detach().numpy()
        ndcgs1.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
        ndcgs3.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))
        del supp_xs, supp_ys, query_xs, query_ys, y_true, y_pred, loss_q, temp_sxs, temp_qxs, predictions, l1
        torch.cuda.empty_cache()

    # Aggregate the per-task metrics.
    losses_q = np.array(losses_q).mean()
    print("mean L1 loss (MAE): ", losses_q)
    n1 = np.array(ndcgs1).mean()
    print("nDCG@1: ", n1)
    n3 = np.array(ndcgs3).mean()
    print("nDCG@3: ", n3)
    del a, b, c, d, total_dataset
    gc.collect()
    return losses_q, n1, n3
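
fast_adapt is imported from a separate module that is not part of this file. Below is a minimal sketch of what it is assumed to do, following the usual learn2learn MAML pattern in which head is an l2l-wrapped module whose clone exposes .adapt(); the MSE inner-loop loss and the parameter names here are assumptions for illustration, not the repository's actual implementation.

import torch.nn.functional as F

def fast_adapt(learner, support_x, query_x, support_y, query_y,
               adaptation_steps, get_predictions=False):
    # Inner loop: a few gradient steps on the support set.
    # learner.adapt() is the learn2learn update; it keeps the graph
    # differentiable so the outer (meta) loop can use it during training.
    for _ in range(adaptation_steps):
        support_loss = F.mse_loss(learner(support_x).view(-1), support_y)  # assumed inner-loop loss
        learner.adapt(support_loss)
    # Evaluate the adapted learner on the query set.
    predictions = learner(query_x)
    query_loss = F.mse_loss(predictions.view(-1), query_y)
    if get_predictions:
        return query_loss, predictions
    return query_loss

Under that assumption, test() adapts a cloned head on each task's support set for config['inner'] steps, then reports the query-set L1 loss and nDCG@1/nDCG@3 averaged over all test tasks.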