import gc
import os
import pickle
import random

import numpy as np
from sklearn.metrics import ndcg_score
from torch.nn import L1Loss

from fast_adapt import fast_adapt


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only point this at task
    # files that come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


def hyper_test(embedding, head, trainer, batch_size, master_path, test_state,
               adaptation_step, num_epoch=None):
    """Evaluate the model on every task stored under ``master_path/test_state``.

    The task directory is expected to hold four pickle files per task
    (``supp_x_<i>.pkl``, ``supp_y_<i>.pkl``, ``query_x_<i>.pkl`` and
    ``query_y_<i>.pkl``); the number of tasks is the directory entry count
    divided by four. For each task a clone of ``head`` is adapted through
    ``fast_adapt`` and its query predictions are scored with the mean L1 loss
    and NDCG at k=1 and k=3.

    ``head`` and ``trainer`` are switched to eval mode for the duration of the
    call and are always switched back to train mode, even if a task raises.

    Args:
        embedding: Callable applied to the support and query inputs before
            adaptation.
        head: Model exposing ``clone()``, ``eval()`` and ``train()``; a fresh
            clone is adapted for every task.
        trainer: Module exposing ``eval()`` and ``train()``; forwarded to
            ``fast_adapt``.
        batch_size: Unused; kept so existing callers keep working.
        master_path: Root directory of the dataset.
        test_state: Sub-directory of ``master_path`` holding the task files.
        adaptation_step: Forwarded to ``fast_adapt``.
        num_epoch: Forwarded to ``fast_adapt`` as ``iteration``.

    Returns:
        A tuple ``(mean_l1_loss, mean_ndcg_at_1, mean_ndcg_at_3)``. When no
        task could be evaluated the fallbacks ``(100, 0, 0)`` are returned.
    """
    task_dir = "{}/{}".format(master_path, test_state)
    test_set_size = len(os.listdir(task_dir)) // 4

    # NOTE(review): the shuffled order is never consumed below (tasks are
    # visited by file index). The shuffle is kept only because it advances the
    # global `random` state exactly as before, which seeded callers may rely on.
    random.shuffle(list(range(test_set_size)))

    losses_q = []
    ndcgs_at_1 = []
    ndcgs_at_3 = []
    l1 = L1Loss(reduction='mean')

    head.eval()
    trainer.eval()
    try:
        for task_idx in range(test_set_size):
            supp_x = _load_pickle("{}/supp_x_{}.pkl".format(task_dir, task_idx))
            supp_y = _load_pickle("{}/supp_y_{}.pkl".format(task_dir, task_idx))
            query_x = _load_pickle("{}/query_x_{}.pkl".format(task_dir, task_idx))
            query_y = _load_pickle("{}/query_y_{}.pkl".format(task_dir, task_idx))

            try:
                supp_xs = supp_x.cuda()
                supp_ys = supp_y.cuda()
                query_xs = query_x.cuda()
                query_ys = query_y.cuda()
            except IndexError:
                # Best-effort: skip a task whose tensors cannot be moved.
                print("index error in test method")
                continue

            learner = head.clone()
            temp_sxs = embedding(supp_xs)
            temp_qxs = embedding(query_xs)

            # Only the predictions are needed; the second result is ignored.
            predictions, _ = fast_adapt(
                learner,
                temp_sxs,
                temp_qxs,
                supp_ys,
                query_ys,
                adaptation_step,
                get_predictions=True,
                trainer=trainer,
                test=True,
                iteration=num_epoch,
            )

            predictions = predictions.view(-1)
            # Put the targets wherever the predictions live so the loss never
            # mixes devices.
            loss_q = l1(predictions, query_ys.to(predictions.device))
            losses_q.append(float(loss_q))

            y_true = query_ys.cpu().detach().numpy()
            y_pred = predictions.cpu().detach().numpy()
            ndcgs_at_1.append(
                float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False))
            )
            ndcgs_at_3.append(
                float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False))
            )

            # Drop references to the per-task tensors before loading the next task.
            del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred, loss_q
    finally:
        # Always hand the models back in training mode, even on failure.
        head.train()
        trainer.train()
        gc.collect()

    # calculate metrics
    # The mean of an empty array is NaN (with a warning), not an exception, so
    # the "nothing evaluated" fallbacks have to be selected explicitly.
    if losses_q:
        mean_loss = np.array(losses_q).mean()
    else:
        mean_loss = 100
    if ndcgs_at_1:
        ndcg1 = np.array(ndcgs_at_1).mean()
        ndcg3 = np.array(ndcgs_at_3).mean()
    else:
        ndcg1 = 0
        ndcg3 = 0

    return mean_loss, ndcg1, ndcg3