Meta-learning approach for solving the cold-start problem.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

model_test.py 2.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import os
  2. import torch
  3. import pickle
  4. import random
  5. from MeLU import MeLU
  6. from options import config, states
  7. from torch.nn import functional as F
  8. from torch.nn import L1Loss
  9. # from pytorchltr.evaluation import ndcg
  10. import matchzoo as mz
  11. import numpy as np
  12. def test(melu, total_dataset, batch_size, num_epoch):
  13. if config['use_cuda']:
  14. melu.cuda()
  15. test_set_size = len(total_dataset)
  16. trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models2.pkl")
  17. melu.load_state_dict(trained_state_dict)
  18. melu.eval()
  19. random.shuffle(total_dataset)
  20. a, b, c, d = zip(*total_dataset)
  21. losses_q = []
  22. predictions = None
  23. predictions_size = None
  24. # y_true = []
  25. # y_pred = []
  26. ndcgs1 = []
  27. ndcgs3 = []
  28. for iterator in range(test_set_size):
  29. # trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models.pkl")
  30. # melu.load_state_dict(trained_state_dict)
  31. # melu.eval()
  32. try:
  33. supp_xs = a[iterator].cuda()
  34. supp_ys = b[iterator].cuda()
  35. query_xs = c[iterator].cuda()
  36. query_ys = d[iterator].cuda()
  37. except IndexError:
  38. print("index error in test method")
  39. continue
  40. num_local_update = config['inner']
  41. query_set_y_pred = melu.forward(supp_xs, supp_ys, query_xs, num_local_update)
  42. l1 = L1Loss(reduction='mean')
  43. loss_q = l1(query_set_y_pred, query_ys)
  44. print("testing - iterator:{} - l1:{} ".format(iterator,loss_q))
  45. losses_q.append(loss_q)
  46. # if predictions is None:
  47. # predictions = query_set_y_pred
  48. # predictions_size = torch.FloatTensor(len(query_set_y_pred))
  49. # else:
  50. # predictions = torch.cat((predictions,query_set_y_pred),0)
  51. # predictions_size = torch.cat((predictions_size,torch.FloatTensor(len(query_set_y_pred))),0)
  52. # y_true.append(query_ys.cpu().detach().numpy())
  53. # y_pred.append(query_set_y_pred.cpu().detach().numpy())
  54. y_true = query_ys.cpu().detach().numpy()
  55. y_pred = query_set_y_pred.cpu().detach().numpy()
  56. ndcgs1.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true,y_pred))
  57. ndcgs3.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred))
  58. del supp_xs, supp_ys, query_xs, query_ys
  59. # calculate metrics
  60. print(losses_q)
  61. print("======================================")
  62. losses_q = torch.stack(losses_q).mean(0)
  63. print("mean of mse: ",losses_q)
  64. print("======================================")
  65. # n1 = ndcg(d, predictions.cuda(), predictions_size.cuda(), k=1)
  66. # n1 = mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(np.array(y_true),np.array(y_pred))
  67. n1 = np.array(ndcgs1).mean()
  68. print("nDCG1: ",n1)
  69. n3 = np.array(ndcgs3).mean()
  70. print("nDCG3: ", n3)