import gc
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
from ray import tune

from sampler import *
from trainer import Trainer
from utils import *
-
-
def train_metatl(conf, checkpoint_dir=None):
    """Ray Tune trainable: train a MetaTL model and report validation metrics.

    Args:
        conf: Tune config dict. Expects conf['params'] with fixed settings
            ('seed', 'dataset', 'K', 'batch_size', 'epoch'), the tuned
            hyperparameters at the top level ('learning_rate', 'beta',
            'embed_dim', 'margin'), and conf['itemnum'].
        checkpoint_dir: optional Tune checkpoint directory to resume from.
    """
    params = conf['params']

    # Fix every RNG source so a trial is reproducible for a given seed.
    seed = params['seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)

    # Load the dataset split and build train/valid/test samplers.
    (user_train, usernum_train, itemnum, user_input_test, user_test,
     user_input_valid, user_valid) = data_load(params['dataset'], params['K'])
    sampler = WarpSampler(user_train, usernum_train, itemnum,
                          batch_size=params['batch_size'],
                          maxlen=params['K'], n_workers=1)
    sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
    sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

    # Merge the fixed settings with the hyperparameters being tuned.
    ps = {
        "batch_size": params['batch_size'],
        "learning_rate": conf['learning_rate'],
        "epoch": params['epoch'],
        "beta": conf['beta'],
        "embed_dim": conf['embed_dim'],
        "margin": conf['margin'],
        "K": params['K'],
    }

    # NOTE(review): conf["itemnum"] is used here instead of the `itemnum`
    # returned by data_load above — presumably they agree; confirm upstream.
    trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps)

    # Resume model and optimizer state from a Tune checkpoint, if provided.
    if checkpoint_dir:
        print("===================== using checkpoint =====================")
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        trainer.MetaTL.load_state_dict(model_state)
        trainer.optimizer.load_state_dict(optimizer_state)

    # Train in chunks of 1000 steps; validate, checkpoint, and report after
    # each chunk. Epochs below 1000 steps yield zero chunks, as before.
    for epoch in range(ps['epoch'] // 1000):

        for _ in range(1000):
            # Sample one batch and take a single optimization step.
            train_task, curr_rel = trainer.train_data_loader.next_batch()
            loss, _, _ = trainer.do_one_step(train_task, iseval=False,
                                             curr_rel=curr_rel)

        # Evaluate on the validation split (epoch=-1 is the sentinel the
        # original passed — presumably "no epoch label"; see Trainer.eval).
        valid_data = trainer.eval(istest=False, epoch=(-1))

        # Persist model + optimizer state so Tune can pause/resume the trial.
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((trainer.MetaTL.state_dict(),
                        trainer.optimizer.state_dict()), path)

        # Report validation metrics back to Tune for scheduling/selection.
        tune.report(
            MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'],
            NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"],
            Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"],
            Hits1=valid_data["Hits@1"],
            training_iteration=epoch * 1000
        )