make other meta-learning algorithms implemented in l2l.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

learnToLearnTest.py 2.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import torch
  3. import pickle
  4. import random
  5. from options import config, states
  6. from torch.nn import functional as F
  7. from torch.nn import L1Loss
  8. import matchzoo as mz
  9. import numpy as np
  10. from fast_adapt import fast_adapt
  11. def test(embedding,head, total_dataset, batch_size, num_epoch):
  12. test_set_size = len(total_dataset)
  13. random.shuffle(total_dataset)
  14. a, b, c, d = zip(*total_dataset)
  15. losses_q = []
  16. ndcgs1 = []
  17. ndcgs3 = []
  18. for iterator in range(test_set_size):
  19. if config['use_cuda']:
  20. try:
  21. supp_xs = a[iterator].cuda()
  22. supp_ys = b[iterator].cuda()
  23. query_xs = c[iterator].cuda()
  24. query_ys = d[iterator].cuda()
  25. except IndexError:
  26. print("index error in test method")
  27. continue
  28. else:
  29. try:
  30. supp_xs = a[iterator]
  31. supp_ys = b[iterator]
  32. query_xs = c[iterator]
  33. query_ys = d[iterator]
  34. except IndexError:
  35. print("index error in test method")
  36. continue
  37. num_local_update = config['inner']
  38. learner = head.clone()
  39. temp_sxs = embedding(supp_xs)
  40. temp_qxs = embedding(query_xs)
  41. evaluation_error,predictions = fast_adapt(learner,
  42. temp_sxs,
  43. temp_qxs,
  44. supp_ys,
  45. query_ys,
  46. config['inner'],
  47. get_predictions=True
  48. )
  49. l1 = L1Loss(reduction='mean')
  50. loss_q = l1(predictions.view(-1), query_ys)
  51. # print("testing - iterator:{} - l1:{} ".format(iterator,loss_q))
  52. losses_q.append(float(loss_q))
  53. y_true = query_ys.cpu().detach().numpy()
  54. y_pred = predictions.cpu().detach().numpy()
  55. ndcgs1.append(float(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true, y_pred)))
  56. ndcgs3.append(float(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred)))
  57. del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred, loss_q
  58. # torch.cuda.empty_cache()
  59. # calculate metrics
  60. # losses_q = torch.stack(losses_q).mean(0)
  61. losses_q = np.array(losses_q).mean()
  62. print("mean of mse: ", losses_q)
  63. n1 = np.array(ndcgs1).mean()
  64. print("nDCG1: ", n1)
  65. n3 = np.array(ndcgs3).mean()
  66. print("nDCG3: ", n3)