import os
import torch
import pickle
import random
from MeLU import MeLU
from options import config, states
from torch.nn import functional as F
from torch.nn import L1Loss
# from pytorchltr.evaluation import ndcg
import matchzoo as mz
import numpy as np

# Checkpoint evaluated by ``test`` when the caller does not pass ``model_path``.
DEFAULT_MODEL_PATH = "/media/external_10TB/10TB/maheri/melu_data/models2.pkl"


def test(melu, total_dataset, batch_size, num_epoch, model_path=DEFAULT_MODEL_PATH):
    """Evaluate a trained MeLU checkpoint on a set of test tasks.

    For every task, ``melu.forward`` is given the support set, the query
    inputs and ``config['inner']`` as the number of local updates; its query
    predictions are scored with the mean L1 loss and matchzoo's nDCG@1 and
    nDCG@3. Per-task results and the averaged metrics are printed.

    Args:
        melu: model exposing ``forward(supp_xs, supp_ys, query_xs,
            num_local_update)``, ``load_state_dict``, ``eval`` and ``cuda``.
        total_dataset: iterable of ``(supp_xs, supp_ys, query_xs, query_ys)``
            tensor tuples, one per task. A shuffled copy is evaluated; the
            caller's sequence is left in its original order.
        batch_size: unused; kept for interface compatibility.
        num_epoch: unused; kept for interface compatibility.
        model_path: path of the saved state dict to load into ``melu``.

    Returns:
        ``(mean_l1, mean_ndcg1, mean_ndcg3)`` as Python floats.

    Raises:
        ValueError: if ``total_dataset`` is empty or a task does not unpack
            into exactly four items.
    """
    use_cuda = config['use_cuda']
    device = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        melu.cuda()

    # map_location lets a checkpoint saved on GPU be evaluated on a CPU host.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    trained_state_dict = torch.load(model_path, map_location=device)
    melu.load_state_dict(trained_state_dict)
    melu.eval()

    # Shuffle a copy so the caller's dataset is not reordered in place.
    tasks = list(total_dataset)
    if not tasks:
        raise ValueError("total_dataset is empty; nothing to evaluate")
    random.shuffle(tasks)

    # Loop-invariant objects, built once.
    num_local_update = config['inner']
    l1 = L1Loss(reduction='mean')
    ndcg_at_1 = mz.metrics.NormalizedDiscountedCumulativeGain(k=1)
    ndcg_at_3 = mz.metrics.NormalizedDiscountedCumulativeGain(k=3)

    losses_q = []
    ndcgs1 = []
    ndcgs3 = []
    for iterator, (supp_xs, supp_ys, query_xs, query_ys) in enumerate(tasks):
        # Inputs must live on the same device as the model.
        supp_xs = supp_xs.to(device)
        supp_ys = supp_ys.to(device)
        query_xs = query_xs.to(device)
        query_ys = query_ys.to(device)

        # NOTE(review): no torch.no_grad() here -- forward() performs
        # num_local_update local updates and presumably needs autograd for
        # them; confirm before wrapping this call.
        query_set_y_pred = melu.forward(supp_xs, supp_ys, query_xs, num_local_update)

        # Detach so each stored loss does not keep this task's autograd graph
        # alive until the end of the evaluation.
        loss_q = l1(query_set_y_pred, query_ys).detach()
        print("testing - iterator:{} - l1:{} ".format(iterator, loss_q))
        losses_q.append(loss_q)

        y_true = query_ys.cpu().detach().numpy()
        y_pred = query_set_y_pred.cpu().detach().numpy()
        ndcgs1.append(ndcg_at_1(y_true, y_pred))
        ndcgs3.append(ndcg_at_3(y_true, y_pred))

        # Drop the device copies before moving the next task.
        del supp_xs, supp_ys, query_xs, query_ys, query_set_y_pred

    # calculate metrics
    print(losses_q)
    print("======================================")
    mean_loss = torch.stack(losses_q).mean(0)
    print("mean of l1 (mae): ", mean_loss)
    print("======================================")
    n1 = np.array(ndcgs1).mean()
    print("nDCG1: ", n1)
    n3 = np.array(ndcgs3).mean()
    print("nDCG3: ", n3)
    return mean_loss.item(), float(n1), float(n3)