Adapted to the MovieLens dataset
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 1.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. """
  2. Run evaluation with saved models.
  3. """
  4. import random
  5. import argparse
  6. from tqdm import tqdm
  7. import torch
  8. from utils.scorer import *
  9. def testing(trainer, opt, test_dataset):
  10. test_dataset_len = len(test_dataset)
  11. #batch_size = opt["batch_size"]
  12. minibatch_size = 1
  13. a, b, c, d = zip(*test_dataset)
  14. trainer.eval()
  15. all_loss = 0
  16. pre5 = []
  17. ap5 = []
  18. ndcg5 = []
  19. pre7 = []
  20. ap7 = []
  21. ndcg7 = []
  22. pre10 = []
  23. ap10 = []
  24. ndcg10 = []
  25. for i in range(test_dataset_len):
  26. try:
  27. supp_xs = list(a[minibatch_size * i:minibatch_size * (i + 1)])
  28. supp_ys = list(b[minibatch_size * i:minibatch_size * (i + 1)])
  29. query_xs = list(c[minibatch_size * i:minibatch_size * (i + 1)])
  30. query_ys = list(d[minibatch_size * i:minibatch_size * (i + 1)])
  31. except IndexError:
  32. continue
  33. test_loss, recommendation_list = trainer.query_rec(supp_xs, supp_ys, query_xs, query_ys)
  34. all_loss += test_loss
  35. add_metric(recommendation_list, query_ys[0].cpu().detach().numpy(), pre5, ap5, ndcg5, 5)
  36. add_metric(recommendation_list, query_ys[0].cpu().detach().numpy(), pre7, ap7, ndcg7, 7)
  37. add_metric(recommendation_list, query_ys[0].cpu().detach().numpy(), pre10, ap10, ndcg10, 10)
  38. mpre5, mndcg5, map5 = cal_metric(pre5, ap5, ndcg5)
  39. mpre7, mndcg7, map7 = cal_metric(pre7, ap7, ndcg7)
  40. mpre10, mndcg10, map10 = cal_metric(pre10, ap10, ndcg10)
  41. return mpre5, mndcg5, map5, mpre7, mndcg7, map7, mpre10, mndcg10, map10