import gc
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
from ray import tune

from trainer import Trainer
from utils import *
from sampler import *

# Number of optimisation steps between two validation / checkpoint / report rounds.
_STEPS_PER_ROUND = 1000


def _seed_everything(seed):
    """Seed torch (CPU and CUDA), numpy and ``random``; force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)


def _restore_checkpoint(trainer, checkpoint_dir):
    """Load the ``(model_state, optimizer_state)`` tuple saved by ``train_metatl``.

    Args:
        trainer: The freshly built ``Trainer``; its ``MetaTL`` model and
            ``optimizer`` receive the saved states.
        checkpoint_dir: Directory containing the ``checkpoint`` file.
    """
    print("===================== using checkpoint =====================")
    # map_location="cpu" lets a checkpoint written on a GPU worker be restored on
    # any worker; load_state_dict then copies the tensors onto the devices of the
    # already-constructed model / optimizer parameters.
    model_state, optimizer_state = torch.load(
        os.path.join(checkpoint_dir, "checkpoint"), map_location="cpu")
    trainer.MetaTL.load_state_dict(model_state)
    trainer.optimizer.load_state_dict(optimizer_state)


def train_metatl(conf, checkpoint_dir=None):
    """Ray Tune function trainable: train MetaTL and report validation metrics.

    Args:
        conf: Trial config. ``conf["params"]`` must provide ``seed``, ``dataset``,
            ``K``, ``batch_size`` and ``epoch`` (total number of training steps).
            The top level provides the tuned ``learning_rate``, ``beta``,
            ``embed_dim`` and ``margin``, and optionally ``itemnum``. When
            ``itemnum`` is absent, the item count returned by ``data_load`` is used.
        checkpoint_dir: Directory written by a previous call of this function
            (passed by Ray when restoring a trial). When given, model and optimizer
            states are restored before training starts.

    Every ``_STEPS_PER_ROUND`` training steps the model is evaluated on the
    validation split, a ``(model_state, optimizer_state)`` checkpoint is saved via
    ``tune.checkpoint_dir`` and the metrics are sent with ``tune.report``.
    """
    params = conf["params"]
    _seed_everything(params["seed"])

    (user_train, usernum_train, itemnum,
     user_input_test, user_test,
     user_input_valid, user_valid) = data_load(params["dataset"], params["K"])

    sampler = WarpSampler(user_train, usernum_train, itemnum,
                          batch_size=params["batch_size"], maxlen=params["K"],
                          n_workers=1)
    try:
        sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
        sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

        ps = {
            "batch_size": params["batch_size"],
            "learning_rate": conf["learning_rate"],
            "epoch": params["epoch"],
            "beta": conf["beta"],
            "embed_dim": conf["embed_dim"],
            "margin": conf["margin"],
            "K": params["K"],
        }

        trainer = Trainer([sampler, sampler_valid, sampler_test],
                          conf.get("itemnum", itemnum), ps)

        if checkpoint_dir:
            _restore_checkpoint(trainer, checkpoint_dir)

        # NOTE(review): only whole rounds are run; any remainder of
        # ps["epoch"] % _STEPS_PER_ROUND steps is skipped (original behaviour).
        for epoch in range(int(ps["epoch"] / _STEPS_PER_ROUND)):
            for _ in range(_STEPS_PER_ROUND):
                # sample one batch from the training sampler and take one step
                train_task, curr_rel = trainer.train_data_loader.next_batch()
                trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

            valid_data = trainer.eval(istest=False, epoch=-1)

            with tune.checkpoint_dir(epoch) as round_ckpt_dir:
                path = os.path.join(round_ckpt_dir, "checkpoint")
                torch.save(
                    (trainer.MetaTL.state_dict(), trainer.optimizer.state_dict()),
                    path)

            tune.report(
                MRR=valid_data["MRR"],
                NDCG10=valid_data["NDCG@10"],
                NDCG5=valid_data["NDCG@5"],
                NDCG1=valid_data["NDCG@1"],
                Hits10=valid_data["Hits@10"],
                Hits5=valid_data["Hits@5"],
                Hits1=valid_data["Hits@1"],
                # NOTE(review): this is the step count *before* the round just
                # finished (0 on the first report). It is kept unchanged because
                # Tune schedulers keyed on training_iteration may rely on it.
                training_iteration=epoch * _STEPS_PER_ROUND,
            )
    finally:
        # Best-effort shutdown of the sampler's background worker(s), so a
        # finished, failed or Tune-stopped trial does not leak them. Assumes the
        # project's WarpSampler exposes close(); skipped silently if it does not.
        close = getattr(sampler, "close", None)
        if callable(close):
            close()