 mohamad maheri
					
					4 years ago
						mohamad maheri
					
					4 years ago
				| args.add_argument("-bs", "--batch_size", default=1024, type=int) | args.add_argument("-bs", "--batch_size", default=1024, type=int) | ||||
| # args.add_argument("-lr", "--learning_rate", default=0.001, type=float) | # args.add_argument("-lr", "--learning_rate", default=0.001, type=float) | ||||
| args.add_argument("-epo", "--epoch", default=1000, type=int) | args.add_argument("-epo", "--epoch", default=100000, type=int) | ||||
| # args.add_argument("-prt_epo", "--print_epoch", default=100, type=int) | # args.add_argument("-prt_epo", "--print_epoch", default=100, type=int) | ||||
| # args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int) | # args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int) | ||||
| # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | ||||
| # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | ||||
| # "lr": tune.loguniform(1e-4, 1e-1), | # "lr": tune.loguniform(1e-4, 1e-1), | ||||
| # "batch_size": tune.choice([2, 4, 8, 16]) | "embed_dim" : tune.choice([50,75,100,125,150,200,300]), | ||||
| "embed_dim" : tune.choice([50,75,100,125,150,200]), | |||||
| # "batch_size" : tune.choice([128,256,512,1024,2048]), | # "batch_size" : tune.choice([128,256,512,1024,2048]), | ||||
| "learning_rate" : tune.choice([0.1,0.01,0.005,0.001,0.0001]), | "learning_rate" : tune.choice([0.1,0.01,0.004,0.005,0.007,0.001,0.0001]), | ||||
| "beta" : tune.choice([0.1,1,5,10]), | "beta" : tune.choice([0.05,0.1,1,4,4.5,5,5.5,6,10]), | ||||
| "margin" : tune.choice([1]), | "margin" : tune.choice([1,0.9,0.8,1.1,1.2]), | ||||
| # "sampler":sampler, | # "sampler":sampler, | ||||
| # "sampler_test":sampler_test, | # "sampler_test":sampler_test, | ||||
| progress_reporter=reporter, | progress_reporter=reporter, | ||||
| log_to_file=True, | log_to_file=True, | ||||
| # resume=True, | # resume=True, | ||||
| local_dir="./ray_local_dir", | local_dir="/media/external_10TB/10TB/maheri/metaTL_ray/ray_local_dir", | ||||
| name="metatl_rnn1", | name="metatl_rnn1", | ||||
| ) | ) | ||||
| trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps) | trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps) | ||||
| # trainer.train() | # trainer.train() | ||||
| if checkpoint_dir: | |||||
| print("===================== using checkpoint =====================") | |||||
| model_state, optimizer_state = torch.load( | |||||
| os.path.join(checkpoint_dir, "checkpoint")) | |||||
| trainer.MetaTL.load_state_dict(model_state) | |||||
| trainer.optimizer.load_state_dict(optimizer_state) | |||||
| for epoch in range(ps['epoch']): | for epoch in range(int(ps['epoch']/1000)): | ||||
| for e in range(100): | for e in range(1000): | ||||
| # sample one batch from data_loader | # sample one batch from data_loader | ||||
| train_task, curr_rel = trainer.train_data_loader.next_batch() | train_task, curr_rel = trainer.train_data_loader.next_batch() | ||||
| loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | ||||
| # print('Epoch {} Testing...'.format(e)) | # print('Epoch {} Testing...'.format(e)) | ||||
| # test_data = self.eval(istest=True, epoch=e) | # test_data = self.eval(istest=True, epoch=e) | ||||
| if checkpoint_dir: | |||||
| model_state, optimizer_state = torch.load( | |||||
| os.path.join(checkpoint_dir, "checkpoint")) | |||||
| trainer.MetaTL.load_state_dict(model_state) | |||||
| trainer.optimizer.load_state_dict(optimizer_state) | |||||
| with tune.checkpoint_dir(epoch) as checkpoint_dir: | with tune.checkpoint_dir(epoch) as checkpoint_dir: | ||||
| path = os.path.join(checkpoint_dir, "checkpoint") | path = os.path.join(checkpoint_dir, "checkpoint") | ||||
| tune.report( | tune.report( | ||||
| MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"], | MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"], | ||||
| Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"], | Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"], | ||||
| training_iteration=epoch*100 | training_iteration=epoch*1000 | ||||
| ) | ) | ||||
| self.epoch = parameter['epoch'] | self.epoch = parameter['epoch'] | ||||
| # self.print_epoch = parameter['print_epoch'] | # self.print_epoch = parameter['print_epoch'] | ||||
| # self.eval_epoch = parameter['eval_epoch'] | # self.eval_epoch = parameter['eval_epoch'] | ||||
| self.eval_epoch = 50 | self.eval_epoch = 1000 | ||||
| self.device = torch.device('cuda:0') | self.device = torch.device('cuda:0') | ||||
| self.MetaTL = MetaTL(itemnum, parameter) | self.MetaTL = MetaTL(itemnum, parameter) |