args.add_argument("-bs", "--batch_size", default=1024, type=int) | args.add_argument("-bs", "--batch_size", default=1024, type=int) | ||||
# args.add_argument("-lr", "--learning_rate", default=0.001, type=float) | # args.add_argument("-lr", "--learning_rate", default=0.001, type=float) | ||||
args.add_argument("-epo", "--epoch", default=1000, type=int) | |||||
args.add_argument("-epo", "--epoch", default=100000, type=int) | |||||
# args.add_argument("-prt_epo", "--print_epoch", default=100, type=int) | # args.add_argument("-prt_epo", "--print_epoch", default=100, type=int) | ||||
# args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int) | # args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int) | ||||
# "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | ||||
# "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | ||||
# "lr": tune.loguniform(1e-4, 1e-1), | # "lr": tune.loguniform(1e-4, 1e-1), | ||||
# "batch_size": tune.choice([2, 4, 8, 16]) | |||||
"embed_dim" : tune.choice([50,75,100,125,150,200]), | |||||
"embed_dim" : tune.choice([50,75,100,125,150,200,300]), | |||||
# "batch_size" : tune.choice([128,256,512,1024,2048]), | # "batch_size" : tune.choice([128,256,512,1024,2048]), | ||||
"learning_rate" : tune.choice([0.1,0.01,0.005,0.001,0.0001]), | |||||
"beta" : tune.choice([0.1,1,5,10]), | |||||
"margin" : tune.choice([1]), | |||||
"learning_rate" : tune.choice([0.1,0.01,0.004,0.005,0.007,0.001,0.0001]), | |||||
"beta" : tune.choice([0.05,0.1,1,4,4.5,5,5.5,6,10]), | |||||
"margin" : tune.choice([1,0.9,0.8,1.1,1.2]), | |||||
# "sampler":sampler, | # "sampler":sampler, | ||||
# "sampler_test":sampler_test, | # "sampler_test":sampler_test, | ||||
progress_reporter=reporter, | progress_reporter=reporter, | ||||
log_to_file=True, | log_to_file=True, | ||||
# resume=True, | # resume=True, | ||||
local_dir="./ray_local_dir", | |||||
local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir", | |||||
name="metatl_rnn1", | name="metatl_rnn1", | ||||
) | ) | ||||
trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps) | trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps) | ||||
# trainer.train() | # trainer.train() | ||||
if checkpoint_dir: | |||||
print("===================== using checkpoint =====================") | |||||
model_state, optimizer_state = torch.load( | |||||
os.path.join(checkpoint_dir, "checkpoint")) | |||||
trainer.MetaTL.load_state_dict(model_state) | |||||
trainer.optimizer.load_state_dict(optimizer_state) | |||||
for epoch in range(ps['epoch']): | |||||
for epoch in range(int(ps['epoch']/1000)): | |||||
for e in range(100): | |||||
for e in range(1000): | |||||
# sample one batch from data_loader | # sample one batch from data_loader | ||||
train_task, curr_rel = trainer.train_data_loader.next_batch() | train_task, curr_rel = trainer.train_data_loader.next_batch() | ||||
loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | ||||
# print('Epoch {} Testing...'.format(e)) | # print('Epoch {} Testing...'.format(e)) | ||||
# test_data = self.eval(istest=True, epoch=e) | # test_data = self.eval(istest=True, epoch=e) | ||||
if checkpoint_dir: | |||||
model_state, optimizer_state = torch.load( | |||||
os.path.join(checkpoint_dir, "checkpoint")) | |||||
trainer.MetaTL.load_state_dict(model_state) | |||||
trainer.optimizer.load_state_dict(optimizer_state) | |||||
with tune.checkpoint_dir(epoch) as checkpoint_dir: | with tune.checkpoint_dir(epoch) as checkpoint_dir: | ||||
path = os.path.join(checkpoint_dir, "checkpoint") | path = os.path.join(checkpoint_dir, "checkpoint") | ||||
tune.report( | tune.report( | ||||
MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"], | MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"], | ||||
Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"], | Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"], | ||||
training_iteration=epoch*100 | |||||
training_iteration=epoch*1000 | |||||
) | ) | ||||
self.epoch = parameter['epoch'] | self.epoch = parameter['epoch'] | ||||
# self.print_epoch = parameter['print_epoch'] | # self.print_epoch = parameter['print_epoch'] | ||||
# self.eval_epoch = parameter['eval_epoch'] | # self.eval_epoch = parameter['eval_epoch'] | ||||
self.eval_epoch = 50 | |||||
self.eval_epoch = 1000 | |||||
self.device = torch.device('cuda:0') | self.device = torch.device('cuda:0') | ||||
self.MetaTL = MetaTL(itemnum, parameter) | self.MetaTL = MetaTL(itemnum, parameter) |