- import os
- import torch
- import pickle
- import random
- from MeLU import MeLU
- from options import config, states
- from torch.nn import functional as F
- from torch.nn import L1Loss
- # from pytorchltr.evaluation import ndcg
- import matchzoo as mz
- import numpy as np
-
-
def test(melu, total_dataset, batch_size, num_epoch):
    """Evaluate a trained MeLU model on a meta-test set.

    Loads a saved state dict from disk, then for each task in
    ``total_dataset`` performs ``config['inner']`` local adaptation steps on
    the support set and scores predictions on the query set. Prints the mean
    L1 (MAE) loss and mean nDCG@1 / nDCG@3 across all tasks.

    Args:
        melu: MeLU model instance whose parameters are overwritten by the
            loaded checkpoint.
        total_dataset: list of (support_x, support_y, query_x, query_y)
            tensor tuples, one per test task. Shuffled in place.
        batch_size: unused; kept for call-site compatibility.
        num_epoch: unused; kept for call-site compatibility.
    """
    use_cuda = config['use_cuda']
    if use_cuda:
        melu.cuda()

    test_set_size = len(total_dataset)

    # NOTE(review): hardcoded checkpoint path — consider making this a
    # parameter with this value as the default.
    trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models2.pkl")
    melu.load_state_dict(trained_state_dict)
    melu.eval()

    random.shuffle(total_dataset)
    a, b, c, d = zip(*total_dataset)

    losses_q = []
    ndcgs1 = []
    ndcgs3 = []

    # Loop-invariant: build the loss module and read the config once,
    # not once per task (original re-created them every iteration).
    l1 = L1Loss(reduction='mean')
    num_local_update = config['inner']

    for iterator in range(test_set_size):
        supp_xs = a[iterator]
        supp_ys = b[iterator]
        query_xs = c[iterator]
        query_ys = d[iterator]
        # Bug fix: original called .cuda() unconditionally, crashing on
        # CPU-only runs despite the use_cuda gate at the top.
        if use_cuda:
            supp_xs = supp_xs.cuda()
            supp_ys = supp_ys.cuda()
            query_xs = query_xs.cuda()
            query_ys = query_ys.cuda()

        query_set_y_pred = melu.forward(supp_xs, supp_ys, query_xs, num_local_update)

        loss_q = l1(query_set_y_pred, query_ys)
        print("testing - iterator:{} - l1:{} ".format(iterator, loss_q))
        losses_q.append(loss_q)

        # Ranking quality per task, computed on CPU numpy arrays.
        y_true = query_ys.cpu().detach().numpy()
        y_pred = query_set_y_pred.cpu().detach().numpy()
        ndcgs1.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true, y_pred))
        ndcgs3.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred))

        # Free task tensors eagerly to keep GPU memory flat across tasks.
        del supp_xs, supp_ys, query_xs, query_ys

    # Robustness: torch.stack([]) raises on an empty dataset.
    if not losses_q:
        print("test set is empty; no metrics to report")
        return

    # Aggregate metrics.
    print(losses_q)
    print("======================================")
    losses_q = torch.stack(losses_q).mean(0)
    # Label fixed: the loss above is L1 (MAE), not MSE.
    print("mean of mae: ", losses_q)
    print("======================================")

    n1 = np.array(ndcgs1).mean()
    print("nDCG1: ", n1)
    n3 = np.array(ndcgs3).mean()
    print("nDCG3: ", n3)
|