@@ -22,7 +22,7 @@ def get_params(): | |||
args.add_argument("-bs", "--batch_size", default=1024, type=int) | |||
# args.add_argument("-lr", "--learning_rate", default=0.001, type=float) | |||
args.add_argument("-epo", "--epoch", default=1000, type=int) | |||
args.add_argument("-epo", "--epoch", default=100000, type=int) | |||
# args.add_argument("-prt_epo", "--print_epoch", default=100, type=int) | |||
# args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int) | |||
@@ -65,12 +65,12 @@ def main(num_samples, gpus_per_trial=2): | |||
# "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | |||
# "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | |||
# "lr": tune.loguniform(1e-4, 1e-1), | |||
# "batch_size": tune.choice([2, 4, 8, 16]) | |||
"embed_dim" : tune.choice([50,75,100,125,150,200]), | |||
"embed_dim" : tune.choice([50,75,100,125,150,200,300]), | |||
# "batch_size" : tune.choice([128,256,512,1024,2048]), | |||
"learning_rate" : tune.choice([0.1,0.01,0.005,0.001,0.0001]), | |||
"beta" : tune.choice([0.1,1,5,10]), | |||
"margin" : tune.choice([1]), | |||
"learning_rate" : tune.choice([0.1,0.01,0.004,0.005,0.007,0.001,0.0001]), | |||
"beta" : tune.choice([0.05,0.1,1,4,4.5,5,5.5,6,10]), | |||
"margin" : tune.choice([1,0.9,0.8,1.1,1.2]), | |||
# "sampler":sampler, | |||
# "sampler_test":sampler_test, | |||
@@ -98,7 +98,7 @@ def main(num_samples, gpus_per_trial=2): | |||
progress_reporter=reporter, | |||
log_to_file=True, | |||
# resume=True, | |||
local_dir="./ray_local_dir", | |||
local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir", | |||
name="metatl_rnn1", | |||
) | |||
@@ -41,10 +41,16 @@ def train_metatl(conf,checkpoint_dir=None): | |||
trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps) | |||
# trainer.train() | |||
if checkpoint_dir: | |||
print("===================== using checkpoint =====================") | |||
model_state, optimizer_state = torch.load( | |||
os.path.join(checkpoint_dir, "checkpoint")) | |||
trainer.MetaTL.load_state_dict(model_state) | |||
trainer.optimizer.load_state_dict(optimizer_state) | |||
for epoch in range(ps['epoch']): | |||
for epoch in range(int(ps['epoch']/1000)): | |||
for e in range(100): | |||
for e in range(1000): | |||
# sample one batch from data_loader | |||
train_task, curr_rel = trainer.train_data_loader.next_batch() | |||
loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | |||
@@ -55,11 +61,6 @@ def train_metatl(conf,checkpoint_dir=None): | |||
# print('Epoch {} Testing...'.format(e)) | |||
# test_data = self.eval(istest=True, epoch=e) | |||
if checkpoint_dir: | |||
model_state, optimizer_state = torch.load( | |||
os.path.join(checkpoint_dir, "checkpoint")) | |||
trainer.MetaTL.load_state_dict(model_state) | |||
trainer.optimizer.load_state_dict(optimizer_state) | |||
with tune.checkpoint_dir(epoch) as checkpoint_dir: | |||
path = os.path.join(checkpoint_dir, "checkpoint") | |||
@@ -68,6 +69,6 @@ def train_metatl(conf,checkpoint_dir=None): | |||
tune.report( | |||
MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"], | |||
Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"], | |||
training_iteration=epoch*100 | |||
training_iteration=epoch*1000 | |||
) | |||
@@ -20,7 +20,7 @@ class Trainer: | |||
self.epoch = parameter['epoch'] | |||
# self.print_epoch = parameter['print_epoch'] | |||
# self.eval_epoch = parameter['eval_epoch'] | |||
self.eval_epoch = 50 | |||
self.eval_epoch = 1000 | |||
self.device = torch.device('cuda:0') | |||
self.MetaTL = MetaTL(itemnum, parameter) |