123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- from models import *
- import os
- import sys
- import torch
- import shutil
- import logging
- import numpy as np
-
-
class Trainer:
    """Train a MetaTL model and evaluate it with ranking metrics (MRR, Hits@k, NDCG@k)."""

    # Order in which metrics appear in every progress / summary line.
    _REPORT_KEYS = ('MRR', 'NDCG@10', 'NDCG@5', 'NDCG@1', 'Hits@10', 'Hits@5', 'Hits@1')

    def __init__(self, data_loaders, itemnum, parameter):
        """Build the model and optimizer.

        Args:
            data_loaders: indexable of three loaders, in order (train, dev, test).
            itemnum: forwarded unchanged to ``MetaTL``.
            parameter: dict of hyper-parameters; this class reads 'batch_size',
                'learning_rate', 'epoch', 'print_epoch', 'eval_epoch' and 'device',
                and the whole dict is also forwarded to ``MetaTL``.
        """
        self.parameter = parameter
        # data loaders
        self.train_data_loader = data_loaders[0]
        self.dev_data_loader = data_loaders[1]
        self.test_data_loader = data_loaders[2]
        # hyper-parameters
        self.batch_size = parameter['batch_size']
        self.learning_rate = parameter['learning_rate']
        self.epoch = parameter['epoch']
        self.print_epoch = parameter['print_epoch']
        self.eval_epoch = parameter['eval_epoch']
        self.device = parameter['device']

        self.MetaTL = MetaTL(itemnum, parameter)
        self.MetaTL.to(self.device)

        self.optimizer = torch.optim.Adam(self.MetaTL.parameters(), self.learning_rate)

    @staticmethod
    def _format_metrics(metrics):
        """Render metrics as tab-separated 'NAME: value' pairs with 3 decimals."""
        return '\t'.join('{}: {:.3f}'.format(k, metrics[k]) for k in Trainer._REPORT_KEYS)

    def rank_predict(self, data, x, ranks):
        """Rank the positive score among the candidates and accumulate metrics.

        Args:
            data: metric accumulator dict (keys 'MRR', 'Hits@{1,5,10}',
                'NDCG@{1,5,10}'); updated in place.
            x: 1-D score tensor whose LAST element is the positive item's
                score; higher scores are ranked first.
            ranks: list to which the 1-based rank of the positive is appended.
        """
        # the positive score is the last entry of x
        query_idx = x.shape[0] - 1
        # NOTE: torch.sort is not stable by default, so the position of the
        # positive among exactly-tied scores is unspecified.
        _, idx = torch.sort(x, descending=True)
        rank = idx.tolist().index(query_idx) + 1
        ranks.append(rank)

        # With a single relevant item, NDCG@k reduces to 1/log2(rank+1) when rank <= k.
        gain = 1 / np.log2(rank + 1)
        for cutoff in (1, 5, 10):
            if rank <= cutoff:
                data['Hits@{}'.format(cutoff)] += 1
                data['NDCG@{}'.format(cutoff)] += gain
        data['MRR'] += 1.0 / rank

    def do_one_step(self, task, iseval=False, curr_rel=''):
        """Run MetaTL on one task and compute its loss.

        In training mode (``iseval=False``) this also back-propagates and takes
        one optimizer step. In eval mode with an empty ``curr_rel`` nothing is
        computed and ``(0, 0, 0)`` is returned.

        Returns:
            (loss, p_score, n_score) as produced by ``MetaTL`` / its ``loss_func``.
        """
        if iseval and curr_rel == '':
            return 0, 0, 0

        if not iseval:
            self.optimizer.zero_grad()
        p_score, n_score = self.MetaTL(task, iseval, curr_rel)
        y = torch.Tensor([1]).to(self.device)
        loss = self.MetaTL.loss_func(p_score, n_score, y)
        if not iseval:
            loss.backward()
            self.optimizer.step()
        return loss, p_score, n_score

    def train(self):
        """Run the training loop.

        Each "epoch" here is one sampled batch and one optimizer step. The loss
        is printed every ``print_epoch`` steps, and validation + test evaluation
        run every ``eval_epoch`` steps (skipping step 0).
        """
        for e in range(self.epoch):
            train_task, curr_rel = self.train_data_loader.next_batch()
            loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

            if e % self.print_epoch == 0:
                print("Epoch: {}\tLoss: {:.4f}".format(e, loss.item()))

            if e % self.eval_epoch == 0 and e != 0:
                # NOTE(review): evaluation results are only printed; no
                # checkpointing or early stopping is performed here.
                print('Epoch {} Validating...'.format(e))
                self.eval(istest=False, epoch=e)

                print('Epoch {} Testing...'.format(e))
                self.eval(istest=True, epoch=e)

        print('Finish')

    def eval(self, istest=False, epoch=None):
        """Evaluate on the dev (default) or test loader.

        Args:
            istest: evaluate on the test loader instead of the dev loader.
            epoch: accepted for callers' convenience; not used.

        Returns:
            dict of metrics averaged over all evaluated tasks, each rounded to
            3 decimals. All zeros if the loader yields no tasks.
        """
        # NOTE(review): the model is not switched to eval() mode and gradients
        # are not disabled here; MetaTL's evaluation path may depend on
        # autograd, so do not wrap this in torch.no_grad() without checking.

        # Reset MetaTL's rel_q_sharing dict (presumably a per-relation cache
        # filled during evaluation — confirm) so earlier runs are not reused.
        self.MetaTL.rel_q_sharing = dict()

        data_loader = self.test_data_loader if istest else self.dev_data_loader
        # presumably rewinds the loader to its first evaluation task
        data_loader.curr_tri_idx = 0

        data = {'MRR': 0, 'Hits@1': 0, 'Hits@5': 0, 'Hits@10': 0, 'NDCG@1': 0, 'NDCG@5': 0, 'NDCG@10': 0}
        ranks = []

        t = 0
        while True:
            eval_task, curr_rel = data_loader.next_one_on_eval()
            # the loader signals the end of the evaluation tasks with 'EOT'
            if eval_task == 'EOT':
                break
            t += 1

            _, p_score, n_score = self.do_one_step(eval_task, iseval=True, curr_rel=curr_rel)
            # negatives first, positive last: rank_predict relies on this order
            x = torch.cat([n_score, p_score], 1).squeeze()
            self.rank_predict(data, x, ranks)

            # overwrite a single progress line with the running averages
            running = {k: v / t for k, v in data.items()}
            sys.stdout.write("{}\t{}\r".format(t, self._format_metrics(running)))
            sys.stdout.flush()

        if t > 0:
            for k in data:
                data[k] = round(data[k] / t, 3)
        else:
            # no tasks evaluated: averaging would divide by zero, report zeros
            for k in data:
                data[k] = 0.0

        prefix = "TEST" if istest else "VALID"
        print("{}: \t{}\r".format(prefix, self._format_metrics(data)))

        return data
|