This module can be extended to use other meta-learning algorithms implemented in learn2learn (l2l).
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

learnToLearnTest.py 3.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import torch
  3. import pickle
  4. import random
  5. from options import config, states
  6. from torch.nn import functional as F
  7. from torch.nn import L1Loss
  8. import matchzoo as mz
  9. import numpy as np
  10. from fast_adapt import fast_adapt
  11. from sklearn.metrics import ndcg_score
  12. def test(embedding,head, total_dataset, batch_size, num_epoch,adaptation_step=config['inner']):
  13. test_set_size = len(total_dataset)
  14. random.shuffle(total_dataset)
  15. a, b, c, d = zip(*total_dataset)
  16. losses_q = []
  17. # ndcgs1 = []
  18. ndcgs11 = []
  19. # ndcgs111 = []
  20. # ndcgs3 = []
  21. ndcgs33=[]
  22. # ndcgs333 = []
  23. for iterator in range(test_set_size):
  24. if config['use_cuda']:
  25. try:
  26. supp_xs = a[iterator].cuda()
  27. supp_ys = b[iterator].cuda()
  28. query_xs = c[iterator].cuda()
  29. query_ys = d[iterator].cuda()
  30. except IndexError:
  31. print("index error in test method")
  32. continue
  33. else:
  34. try:
  35. supp_xs = a[iterator]
  36. supp_ys = b[iterator]
  37. query_xs = c[iterator]
  38. query_ys = d[iterator]
  39. except IndexError:
  40. print("index error in test method")
  41. continue
  42. num_local_update = adaptation_step
  43. learner = head.clone()
  44. temp_sxs = embedding(supp_xs)
  45. temp_qxs = embedding(query_xs)
  46. evaluation_error,predictions = fast_adapt(learner,
  47. temp_sxs,
  48. temp_qxs,
  49. supp_ys,
  50. query_ys,
  51. config['inner'],
  52. get_predictions=True)
  53. l1 = L1Loss(reduction='mean')
  54. loss_q = l1(predictions.view(-1), query_ys)
  55. # print("testing - iterator:{} - l1:{} ".format(iterator,loss_q))
  56. losses_q.append(float(loss_q))
  57. predictions = predictions.view(-1)
  58. y_true = query_ys.cpu().detach().numpy()
  59. y_pred = predictions.cpu().detach().numpy()
  60. # ndcgs1.append(float(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true, y_pred)))
  61. # ndcgs3.append(float(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred)))
  62. ndcgs11.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
  63. ndcgs33.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))
  64. del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred, loss_q
  65. # torch.cuda.empty_cache()
  66. # calculate metrics
  67. # losses_q = torch.stack(losses_q).mean(0)
  68. losses_q = np.array(losses_q).mean()
  69. print("mean of mse: ", losses_q)
  70. # n1 = np.array(ndcgs1).mean()
  71. # print("nDCG1: ", n1)
  72. print("nDCG1: ", np.array(ndcgs11).mean())
  73. # print("nDCG1: ", np.array(ndcgs111).mean())
  74. # n3 = np.array(ndcgs3).mean()
  75. # print("nDCG3: ", n3)
  76. print("nDCG3: ", np.array(ndcgs33).mean())
  77. # print("nDCG3: ", np.array(ndcgs333).mean())