import os
import torch
import pickle
import random
from options import config, states
from torch.nn import functional as F
from torch.nn import L1Loss
# import matchzoo as mz
import numpy as np
from fast_adapt import fast_adapt
from sklearn.metrics import ndcg_score
import gc


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code; these task files
    are assumed to be locally generated, trusted data — never point this at
    untrusted input.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _evaluate_task(embedding, head, supp_xs, supp_ys, query_xs, query_ys, l1):
    """Adapt a clone of ``head`` on one task and score it on that task's query set.

    Returns a ``(l1_loss, ndcg_at_1, ndcg_at_3)`` tuple of Python floats. All
    intermediate tensors (and any autograd graph) are local to this helper, so
    they are released when it returns.
    """
    # Clone so adaptation on this task does not modify the shared head.
    learner = head.clone()
    temp_sxs = embedding(supp_xs)
    temp_qxs = embedding(query_xs)
    # config['inner'] is passed as fast_adapt's sixth positional argument —
    # presumably the number of inner-loop adaptation steps (previously
    # args.inner_eval); confirm against fast_adapt's signature.
    _, predictions = fast_adapt(learner,
                                temp_sxs,
                                temp_qxs,
                                supp_ys,
                                query_ys,
                                config['inner'],
                                get_predictions=True,
                                epoch=0)
    predictions = predictions.view(-1)
    loss_q = float(l1(predictions, query_ys))

    y_true = query_ys.cpu().detach().numpy()
    y_pred = predictions.cpu().detach().numpy()
    # Each task's query set is scored as a single ranking (one "sample").
    ndcg1 = float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False))
    ndcg3 = float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False))
    return loss_q, ndcg1, ndcg3


def test(embedding, head, total_dataset, batch_size, num_epoch, test_state=None, args=None):
    """Evaluate ``embedding`` + ``head`` on the pickled test tasks of ``test_state``.

    Tasks are read from ``<config['master_path']>/<test_state>/`` as four files
    per task: ``supp_x_i.pkl``, ``supp_y_i.pkl``, ``query_x_i.pkl`` and
    ``query_y_i.pkl``. The task count is the number of files in that directory
    divided by four. Tasks are visited in a random order.

    For each task, a clone of ``head`` is adapted on the support set via
    ``fast_adapt``. Its query predictions are then scored with mean L1 loss
    and with nDCG@1 / nDCG@3 (scikit-learn ``ndcg_score``). Tasks whose tensors
    raise ``IndexError`` while being moved to the GPU are skipped with a message.

    Args:
        embedding: callable applied to the support and query inputs.
        head: model exposing ``.clone()``; cloned once per task.
        total_dataset, batch_size, num_epoch, args: unused; kept for
            interface compatibility with existing callers.
        test_state: name of the sub-directory of ``config['master_path']``
            holding the test tasks.

    Returns:
        ``(mean_l1, mean_ndcg1, mean_ndcg3)``. Each value is NaN if no task was
        evaluated (previously an empty directory crashed with UnboundLocalError).

    NOTE(review): requires CUDA (tensors are moved with ``.cuda()``).
    """
    master_path = config['master_path']
    task_dir = "{}/{}".format(master_path, test_state)
    # Four pickle files per task -> integer number of tasks.
    test_set_size = len(os.listdir(task_dir)) // 4
    indexes = list(np.arange(test_set_size))
    random.shuffle(indexes)

    l1 = L1Loss(reduction='mean')  # loop-invariant; build once
    losses = []
    ndcgs1 = []
    ndcgs3 = []
    for iterator in indexes:
        cpu_tensors = [
            _load_pickle("{}/{}_{}.pkl".format(task_dir, prefix, iterator))
            for prefix in ("supp_x", "supp_y", "query_x", "query_y")
        ]
        try:
            gpu_tensors = [tensor.cuda() for tensor in cpu_tensors]
        except IndexError:
            print("index error in test method")
            continue

        loss_q, ndcg1, ndcg3 = _evaluate_task(embedding, head, *gpu_tensors, l1)
        # Drop this task's tensors before the next task is loaded, to keep
        # peak (GPU) memory at one task's worth.
        del cpu_tensors, gpu_tensors

        losses.append(loss_q)
        ndcgs1.append(ndcg1)
        ndcgs3.append(ndcg3)

    # Explicit NaN for "nothing evaluated" avoids numpy's empty-mean warnings.
    mean_loss = np.mean(losses) if losses else float("nan")
    print("mean of mae: ", mean_loss)
    n1 = np.mean(ndcgs1) if ndcgs1 else float("nan")
    print("nDCG1: ", n1)
    n3 = np.mean(ndcgs3) if ndcgs3 else float("nan")
    print("nDCG3: ", n3)

    gc.collect()
    return mean_loss, n1, n3