import gc
import os
import pickle

import numpy as np
from sklearn.metrics import ndcg_score
from torch.nn import L1Loss

from fast_adapt import fast_adapt


def hyper_test(embedding, head, trainer, batch_size, master_path, test_state, adaptation_step, num_epoch=None):
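    """Evaluate the adapted head on the pickled tasks under ``master_path/test_state``.

    For each task, the head is cloned, adapted on the support set for
    ``adaptation_step`` steps via ``fast_adapt``, and scored on the query set.
    ``batch_size`` is accepted for interface compatibility but is unused here.

    Returns:
        (mean L1 loss, mean nDCG@1, mean nDCG@3) averaged over all tasks.
    """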
    # Each task is serialized as four pickle files (supp_x, supp_y, query_x,
    # query_y), so the number of tasks is the file count divided by four.
    test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)

    losses_q = []
    ndcgs_at_1 = []
    ndcgs_at_3 = []
    l1 = L1Loss(reduction="mean")

    # Switch to evaluation mode for the duration of the test pass.
    head.eval()
    trainer.eval()

    for task_idx in range(test_set_size):
        # Use context managers so the pickle files are closed promptly.
        with open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, task_idx), "rb") as f:
            supp_x = pickle.load(f)
        with open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, task_idx), "rb") as f:
            supp_y = pickle.load(f)
        with open("{}/{}/query_x_{}.pkl".format(master_path, test_state, task_idx), "rb") as f:
            query_x = pickle.load(f)
        with open("{}/{}/query_y_{}.pkl".format(master_path, test_state, task_idx), "rb") as f:
            query_y = pickle.load(f)

        try:
            supp_xs = supp_x.cuda()
            supp_ys = supp_y.cuda()
            query_xs = query_x.cuda()
            query_ys = query_y.cuda()
        except (IndexError, RuntimeError):
            # Skip tasks whose tensors are malformed or cannot be moved to the GPU.
            print("skipping task {} in hyper_test".format(task_idx))
            continue

        # Clone the head so per-task adaptation never touches the shared weights.
        learner = head.clone()
        temp_sxs = embedding(supp_xs)
        temp_qxs = embedding(query_xs)

        # Adapt the cloned head on the support set and predict on the query set.
        # The second return value of fast_adapt is unused here.
        predictions, _ = fast_adapt(
            learner,
            temp_sxs,
            temp_qxs,
            supp_ys,
            query_ys,
            adaptation_step,
            get_predictions=True,
            trainer=trainer,
            test=True,
            iteration=num_epoch,
        )

        # Move everything to the CPU before computing the metrics, so both
        # loss inputs are on the same device.
        predictions = predictions.view(-1).detach().cpu()
        query_ys = query_ys.detach().cpu()
        losses_q.append(float(l1(predictions, query_ys)))

        y_true = query_ys.numpy()
        y_pred = predictions.numpy()
        # ndcg_score expects 2-D arrays of shape (n_samples, n_labels).
        ndcgs_at_1.append(float(ndcg_score([y_true], [y_pred], k=1)))
        ndcgs_at_3.append(float(ndcg_score([y_true], [y_pred], k=3)))

        del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred

    # Average the per-task metrics; fall back to sentinel values if every task
    # was skipped (np.mean on an empty list yields nan, not an exception).
    losses_q = float(np.mean(losses_q)) if losses_q else 100.0
    ndcg1 = float(np.mean(ndcgs_at_1)) if ndcgs_at_1 else 0.0
    ndcg3 = float(np.mean(ndcgs_at_3)) if ndcgs_at_3 else 0.0

    # Restore training mode before handing control back to the training loop.
    head.train()
    trainer.train()
    gc.collect()
    return losses_q, ndcg1, ndcg3
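
# A minimal usage sketch, not part of the original module. It assumes that
# `embedding`, `head`, and `trainer` are the modules built elsewhere in this
# repo and that the pickled tasks live under "<master_path>/<test_state>";
# the names and values below are illustrative only.
#
#   mae, ndcg1, ndcg3 = hyper_test(
#       embedding, head, trainer,
#       batch_size=32,
#       master_path=master_path,
#       test_state="warm_state",
#       adaptation_step=5,
#       num_epoch=epoch,
#   )
#   print("epoch {}: MAE={:.4f} nDCG@1={:.4f} nDCG@3={:.4f}".format(
#       epoch, mae, ndcg1, ndcg3))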