```python
import os
import torch
import torch.nn as nn
from ray import tune
import pickle
import random
import gc
from trainer import Trainer
import numpy as np
from utils import *
from sampler import *


def train_metatl(conf, checkpoint_dir=None):

    # Fix all random seeds so that every trial is reproducible.
    SEED = conf["params"]['seed']
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    np.random.seed(SEED)
    random.seed(SEED)

    params = conf['params']
    params['batch_size'] = conf['batch_size']
    params['number_of_neg'] = conf['number_of_neg']

    # Load the dataset and build the train / validation / test samplers.
    user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(params['dataset'], params['K'])
    sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=params['batch_size'], maxlen=params['K'], n_workers=1, params=params)
    sampler_test = DataLoader(user_input_test, user_test, itemnum, params)
    sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params)

    # Hyperparameters for this trial: fixed settings from `params` mixed with
    # values drawn from the Ray Tune search space in `conf`.
    ps = {
        "batch_size": conf['batch_size'],
        "learning_rate": conf['learning_rate'],
        "epoch": conf["params"]['epoch'],
        "beta": conf['beta'],
        "embed_dim": conf['embed_dim'],
        "margin": conf['margin'],
        "K": conf["params"]['K'],
        "number_of_neg": conf["number_of_neg"],
        "loss_function": conf["loss_function"],
        "eval_epoch": conf["eval_epoch"],
        "device": params['device']
    }

    trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps)

    # Restore model and optimizer state when Ray Tune resumes a trial.
    if checkpoint_dir:
        print("===================== using checkpoint =====================")
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        trainer.MetaTL.load_state_dict(model_state)
        trainer.optimizer.load_state_dict(optimizer_state)

    for epoch in range(int(ps['epoch'] / ps['eval_epoch'])):
        for e in range(ps['eval_epoch']):
            # sample one batch from the data loader and take one training step
            train_task, curr_rel = trainer.train_data_loader.next_batch()
            loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel)

        # evaluate on the validation split every eval_epoch steps
        valid_data = trainer.eval(istest=False, epoch=(-1))

        # save a checkpoint and report validation metrics back to Ray Tune
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((trainer.MetaTL.state_dict(), trainer.optimizer.state_dict()), path)

        tune.report(
            MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"],
            Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"],
            training_iteration=epoch * ps['eval_epoch']
        )
```
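For context, here is a minimal sketch of how this trainable could be handed to Ray Tune's older function-based API (`tune.run`), which matches the `tune.checkpoint_dir` / `tune.report` calls used above. The search-space ranges, the loss-function names, and the resource settings are illustrative placeholders rather than the values used in the original experiments, and `params` is assumed to already contain the fixed options (`dataset`, `K`, `epoch`, `seed`, `device`) that `train_metatl` reads.

```python
from ray import tune

def launch_tuning(params, itemnum):
    # Every key that train_metatl reads from `conf` must appear in the config.
    # All ranges and choices below are placeholder assumptions for illustration.
    config = {
        "params": params,          # fixed options: dataset, K, epoch, seed, device, ...
        "itemnum": itemnum,
        "batch_size": tune.choice([128, 256, 512]),
        "learning_rate": tune.loguniform(1e-4, 1e-2),
        "beta": tune.uniform(0.1, 5.0),
        "embed_dim": tune.choice([50, 100, 200]),
        "margin": tune.uniform(0.5, 2.0),
        "number_of_neg": tune.choice([1, 5, 10]),
        "loss_function": tune.choice(["loss_1", "loss_2"]),  # placeholder names
        "eval_epoch": 100,
    }

    # train_metatl is the trainable defined above; Ray Tune passes checkpoint_dir
    # automatically when it resumes a trial.
    analysis = tune.run(
        train_metatl,
        config=config,
        metric="MRR",              # maximise the validation MRR reported by tune.report
        mode="max",
        num_samples=20,
        resources_per_trial={"cpu": 2, "gpu": 1},
    )
    print("Best config found:", analysis.best_config)
```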