Browse Source

efficient and generalized test method

master
mohamad maheri 2 years ago
parent
commit
f4744ac74a
3 changed files with 15 additions and 33 deletions
  1. 2
    0
      MeLU.py
  2. 6
    6
      main.py
  3. 7
    27
      model_test.py

+ 2
- 0
MeLU.py View File

@@ -89,6 +89,8 @@ class MeLU(torch.nn.Module):
query_set_y_pred = self.model(query_set_x)
self.model.load_state_dict(self.keep_weight)

del weight_for_local_update,loss,grad,support_set_y_pred


return query_set_y_pred


+ 6
- 6
main.py View File

@@ -51,19 +51,19 @@ if __name__ == "__main__":


print("start of test phase")
test_state = 'user_and_item_cold_state'
test_dataset = None
test_set_size = int(len(os.listdir("{}/user_cold_state".format(master_path))) / 4)
test_set_size = int(len(os.listdir("{}/{}".format(master_path,test_state))) / 4)
supp_xs_s = []
supp_ys_s = []
query_xs_s = []
query_ys_s = []
for idx in range(test_set_size):
supp_xs_s.append(pickle.load(open("{}/user_cold_state/supp_x_{}.pkl".format(master_path, idx), "rb")))
supp_ys_s.append(pickle.load(open("{}/user_cold_state/supp_y_{}.pkl".format(master_path, idx), "rb")))
query_xs_s.append(pickle.load(open("{}/user_cold_state/query_x_{}.pkl".format(master_path, idx), "rb")))
query_ys_s.append(pickle.load(open("{}/user_cold_state/query_y_{}.pkl".format(master_path, idx), "rb")))
supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path,test_state, idx), "rb")))
supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path,test_state, idx), "rb")))
query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path,test_state, idx), "rb")))
query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path,test_state, idx), "rb")))
test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)

model_filename = "{}/models_test.pkl".format(master_path)
test(melu, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'])

+ 7
- 27
model_test.py View File

@@ -23,20 +23,11 @@ def test(melu, total_dataset, batch_size, num_epoch):

random.shuffle(total_dataset)
a, b, c, d = zip(*total_dataset)

losses_q = []
predictions = None
predictions_size = None

# y_true = []
# y_pred = []
ndcgs1 = []
ndcgs3 = []

for iterator in range(test_set_size):
# trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models.pkl")
# melu.load_state_dict(trained_state_dict)
# melu.eval()

try:
supp_xs = a[iterator].cuda()
@@ -53,35 +44,24 @@ def test(melu, total_dataset, batch_size, num_epoch):
l1 = L1Loss(reduction='mean')
loss_q = l1(query_set_y_pred, query_ys)
print("testing - iterator:{} - l1:{} ".format(iterator,loss_q))
losses_q.append(loss_q)

# if predictions is None:
# predictions = query_set_y_pred
# predictions_size = torch.FloatTensor(len(query_set_y_pred))
# else:
# predictions = torch.cat((predictions,query_set_y_pred),0)
# predictions_size = torch.cat((predictions_size,torch.FloatTensor(len(query_set_y_pred))),0)
# y_true.append(query_ys.cpu().detach().numpy())
# y_pred.append(query_set_y_pred.cpu().detach().numpy())
losses_q.append(float(loss_q))

y_true = query_ys.cpu().detach().numpy()
y_pred = query_set_y_pred.cpu().detach().numpy()
ndcgs1.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true,y_pred))
ndcgs3.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred))
ndcgs1.append(float(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true,y_pred)))
ndcgs3.append(float(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred)))

del supp_xs, supp_ys, query_xs, query_ys
del supp_xs, supp_ys, query_xs, query_ys,query_set_y_pred,y_true,y_pred,loss_q
torch.cuda.empty_cache()


# calculate metrics
print(losses_q)
print("======================================")
losses_q = torch.stack(losses_q).mean(0)
# losses_q = torch.stack(losses_q).mean(0)
losses_q = np.array(losses_q).mean()
print("mean of mse: ",losses_q)
print("======================================")

# n1 = ndcg(d, predictions.cuda(), predictions_size.cuda(), k=1)
# n1 = mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(np.array(y_true),np.array(y_pred))
n1 = np.array(ndcgs1).mean()
print("nDCG1: ",n1)
n3 = np.array(ndcgs3).mean()

Loading…
Cancel
Save