from ray.tune.schedulers import ASHAScheduler | |||||
from ray.tune import CLIReporter | |||||
from ray import tune | |||||
from functools import partial | |||||
from hyper_tunning import train_metatl | |||||
import argparse | |||||
import numpy as np | |||||
import torch | |||||
import random | |||||
from trainer import * | |||||
from utils import * | |||||
from sampler import * | |||||
import copy | |||||
def get_params(): | |||||
args = argparse.ArgumentParser() | |||||
args.add_argument("-data", "--dataset", default="electronics", type=str) | |||||
args.add_argument("-seed", "--seed", default=None, type=int) | |||||
args.add_argument("-K", "--K", default=3, type=int) #NUMBER OF SHOT | |||||
# args.add_argument("-dim", "--embed_dim", default=100, type=int) | |||||
args.add_argument("-bs", "--batch_size", default=1024, type=int) | |||||
# args.add_argument("-lr", "--learning_rate", default=0.001, type=float) | |||||
args.add_argument("-epo", "--epoch", default=1000, type=int) | |||||
# args.add_argument("-prt_epo", "--print_epoch", default=100, type=int) | |||||
# args.add_argument("-eval_epo", "--eval_epoch", default=1000, type=int) | |||||
# args.add_argument("-b", "--beta", default=5, type=float) | |||||
# args.add_argument("-m", "--margin", default=1, type=float) | |||||
# args.add_argument("-p", "--dropout_p", default=0.5, type=float) | |||||
# args.add_argument("-gpu", "--device", default=1, type=int) | |||||
args = args.parse_args() | |||||
params = {} | |||||
for k, v in vars(args).items(): | |||||
params[k] = v | |||||
params['device'] = torch.device('cuda:0') | |||||
return params, args | |||||
def main(num_samples, gpus_per_trial=2): | |||||
params, args = get_params() | |||||
if params['seed'] is not None: | |||||
SEED = params['seed'] | |||||
torch.manual_seed(SEED) | |||||
torch.cuda.manual_seed(SEED) | |||||
torch.backends.cudnn.deterministic = True | |||||
np.random.seed(SEED) | |||||
random.seed(SEED) | |||||
user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(args.dataset, args.K) | |||||
batch_size = params['batch_size'] | |||||
# sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=batch_size, maxlen=args.K, n_workers=1) | |||||
# sampler_test = DataLoader(user_input_test, user_test, itemnum, params) | |||||
# sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params) | |||||
config = { | |||||
# "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | |||||
# "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)), | |||||
# "lr": tune.loguniform(1e-4, 1e-1), | |||||
# "batch_size": tune.choice([2, 4, 8, 16]) | |||||
"embed_dim" : tune.choice([50,75,100,125,150,200]), | |||||
# "batch_size" : tune.choice([128,256,512,1024,2048]), | |||||
"learning_rate" : tune.choice([0.1,0.01,0.005,0.001,0.0001]), | |||||
"beta" : tune.choice([0.1,1,5,10]), | |||||
"margin" : tune.choice([1]), | |||||
# "sampler":sampler, | |||||
# "sampler_test":sampler_test, | |||||
# "sampler_valid":sampler_valid, | |||||
"itemnum":itemnum, | |||||
"params":params, | |||||
} | |||||
scheduler = ASHAScheduler( | |||||
metric="MRR", | |||||
mode="max", | |||||
max_t=params['epoch'], | |||||
grace_period=200, | |||||
reduction_factor=2) | |||||
reporter = CLIReporter( | |||||
# parameter_columns=["l1", "l2", "lr", "batch_size"], | |||||
metric_columns=["MRR","NDCG10","NDCG5","NDCG1","Hits10","Hits5","Hits1","training_iteration"]) | |||||
result = tune.run( | |||||
train_metatl, | |||||
resources_per_trial={"cpu": 4, "gpu": gpus_per_trial}, | |||||
config=config, | |||||
num_samples=num_samples, | |||||
scheduler=scheduler, | |||||
progress_reporter=reporter, | |||||
log_to_file=True, | |||||
# resume=True, | |||||
local_dir="./ray_local_dir", | |||||
name="metatl_rnn1", | |||||
) | |||||
best_trial = result.get_best_trial("MRR", "max", "last") | |||||
print("Best trial config: {}".format(best_trial.config)) | |||||
print("Best trial final validation loss: {}".format( | |||||
best_trial.last_result["loss"])) | |||||
print("Best trial final validation MRR: {}".format( | |||||
best_trial.last_result["MRR"])) | |||||
print("Best trial final validation NDCG@1: {}".format( | |||||
best_trial.last_result["NDCG@1"])) | |||||
# | |||||
print("=======================================================") | |||||
print(result.results_df) | |||||
print("=======================================================\n") | |||||
# best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"]) | |||||
# device = "cpu" | |||||
# if torch.cuda.is_available(): | |||||
# device = "cuda:0" | |||||
# if gpus_per_trial > 1: | |||||
# best_trained_model = nn.DataParallel(best_trained_model) | |||||
# best_trained_model.to(device) | |||||
# | |||||
# best_checkpoint_dir = best_trial.checkpoint.value | |||||
# model_state, optimizer_state = torch.load(os.path.join( | |||||
# best_checkpoint_dir, "checkpoint")) | |||||
# best_trained_model.load_state_dict(model_state) | |||||
# | |||||
# test_acc = test_accuracy(best_trained_model, device) | |||||
# print("Best trial test set accuracy: {}".format(test_acc)) | |||||
if __name__ == "__main__": | |||||
# You can change the number of GPUs per trial here: | |||||
main(num_samples=150, gpus_per_trial=1) |
import os | |||||
import torch | |||||
import torch.nn as nn | |||||
from ray import tune | |||||
import pickle | |||||
import random | |||||
import gc | |||||
from trainer import Trainer | |||||
import numpy as np | |||||
from utils import * | |||||
from sampler import * | |||||
import os | |||||
def train_metatl(conf,checkpoint_dir=None): | |||||
SEED = conf["params"]['seed'] | |||||
torch.manual_seed(SEED) | |||||
torch.cuda.manual_seed(SEED) | |||||
torch.backends.cudnn.deterministic = True | |||||
np.random.seed(SEED) | |||||
random.seed(SEED) | |||||
params = conf['params'] | |||||
user_train, usernum_train, itemnum, user_input_test, user_test, user_input_valid, user_valid = data_load(params['dataset'], params['K']) | |||||
sampler = WarpSampler(user_train, usernum_train, itemnum, batch_size=params['batch_size'], maxlen=params['K'], n_workers=1) | |||||
sampler_test = DataLoader(user_input_test, user_test, itemnum, params) | |||||
sampler_valid = DataLoader(user_input_valid, user_valid, itemnum, params) | |||||
ps = { | |||||
"batch_size" : conf["params"]['batch_size'], | |||||
"learning_rate" : conf['learning_rate'], | |||||
"epoch" : conf["params"]['epoch'], | |||||
"beta" : conf['beta'], | |||||
"embed_dim" : conf['embed_dim'], | |||||
"margin" : conf['margin'], | |||||
"K" : conf["params"]['K'], | |||||
} | |||||
trainer = Trainer([sampler, sampler_valid, sampler_test], conf["itemnum"], ps) | |||||
# trainer.train() | |||||
for epoch in range(ps['epoch']): | |||||
for e in range(100): | |||||
# sample one batch from data_loader | |||||
train_task, curr_rel = trainer.train_data_loader.next_batch() | |||||
loss, _, _ = trainer.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | |||||
# do evaluation on specific epoch | |||||
valid_data = trainer.eval(istest=False, epoch=(-1)) | |||||
# print('Epoch {} Testing...'.format(e)) | |||||
# test_data = self.eval(istest=True, epoch=e) | |||||
if checkpoint_dir: | |||||
model_state, optimizer_state = torch.load( | |||||
os.path.join(checkpoint_dir, "checkpoint")) | |||||
trainer.MetaTL.load_state_dict(model_state) | |||||
trainer.optimizer.load_state_dict(optimizer_state) | |||||
with tune.checkpoint_dir(epoch) as checkpoint_dir: | |||||
path = os.path.join(checkpoint_dir, "checkpoint") | |||||
torch.save((trainer.MetaTL.state_dict(), trainer.optimizer.state_dict()), path) | |||||
tune.report( | |||||
MRR=valid_data["MRR"], NDCG10=valid_data['NDCG@10'], NDCG5=valid_data["NDCG@5"], NDCG1=valid_data["NDCG@1"], | |||||
Hits10=valid_data["Hits@10"], Hits5=valid_data["Hits@5"], Hits1=valid_data["Hits@1"], | |||||
training_iteration=epoch*100 | |||||
) | |||||
class Embedding(nn.Module): | class Embedding(nn.Module): | ||||
def __init__(self, num_ent, parameter): | def __init__(self, num_ent, parameter): | ||||
super(Embedding, self).__init__() | super(Embedding, self).__init__() | ||||
self.device = parameter['device'] | |||||
self.device = torch.device('cuda:0') | |||||
self.es = parameter['embed_dim'] | self.es = parameter['embed_dim'] | ||||
self.embedding = nn.Embedding(num_ent + 1, self.es) | self.embedding = nn.Embedding(num_ent + 1, self.es) | ||||
super(MetaLearner, self).__init__() | super(MetaLearner, self).__init__() | ||||
self.embed_size = embed_size | self.embed_size = embed_size | ||||
self.K = K | self.K = K | ||||
self.out_size = out_size | |||||
self.hidden_size = out_size | |||||
self.rnn = nn.LSTM(embed_size,self.hidden_size,1) | |||||
# self.out_size = out_size | |||||
# self.hidden_size = out_size | |||||
self.out_size = embed_size | |||||
self.hidden_size = embed_size | |||||
self.rnn = nn.LSTM(embed_size,self.hidden_size,2,dropout=0.2) | |||||
# nn.init.xavier_normal_(self.rnn.all_weights) | # nn.init.xavier_normal_(self.rnn.all_weights) | ||||
def forward(self, inputs): | def forward(self, inputs): | ||||
class MetaTL(nn.Module): | class MetaTL(nn.Module): | ||||
def __init__(self, itemnum, parameter): | def __init__(self, itemnum, parameter): | ||||
super(MetaTL, self).__init__() | super(MetaTL, self).__init__() | ||||
self.device = parameter['device'] | |||||
self.device = torch.device('cuda:0') | |||||
self.beta = parameter['beta'] | self.beta = parameter['beta'] | ||||
self.dropout_p = parameter['dropout_p'] | |||||
# self.dropout_p = parameter['dropout_p'] | |||||
self.embed_dim = parameter['embed_dim'] | self.embed_dim = parameter['embed_dim'] | ||||
self.margin = parameter['margin'] | self.margin = parameter['margin'] | ||||
self.embedding = Embedding(itemnum, parameter) | self.embedding = Embedding(itemnum, parameter) | ||||
self.relation_learner = MetaLearner(parameter['K'] - 1, embed_size=100, num_hidden1=500, | |||||
num_hidden2=200, out_size=100, dropout_p=self.dropout_p) | |||||
self.relation_learner = MetaLearner(parameter['K'] - 1, embed_size=self.embed_dim, num_hidden1=500, | |||||
num_hidden2=200, out_size=100, dropout_p=0) | |||||
self.embedding_learner = EmbeddingLearner() | self.embedding_learner = EmbeddingLearner() | ||||
self.loss_func = nn.MarginRankingLoss(self.margin) | self.loss_func = nn.MarginRankingLoss(self.margin) |
self.batch_size = parameter['batch_size'] | self.batch_size = parameter['batch_size'] | ||||
self.learning_rate = parameter['learning_rate'] | self.learning_rate = parameter['learning_rate'] | ||||
self.epoch = parameter['epoch'] | self.epoch = parameter['epoch'] | ||||
self.print_epoch = parameter['print_epoch'] | |||||
self.eval_epoch = parameter['eval_epoch'] | |||||
self.device = parameter['device'] | |||||
# self.print_epoch = parameter['print_epoch'] | |||||
# self.eval_epoch = parameter['eval_epoch'] | |||||
self.eval_epoch = 50 | |||||
self.device = torch.device('cuda:0') | |||||
self.MetaTL = MetaTL(itemnum, parameter) | self.MetaTL = MetaTL(itemnum, parameter) | ||||
self.MetaTL.to(self.device) | self.MetaTL.to(self.device) | ||||
train_task, curr_rel = self.train_data_loader.next_batch() | train_task, curr_rel = self.train_data_loader.next_batch() | ||||
loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | loss, _, _ = self.do_one_step(train_task, iseval=False, curr_rel=curr_rel) | ||||
# print the loss on specific epoch | # print the loss on specific epoch | ||||
if e % self.print_epoch == 0: | |||||
loss_num = loss.item() | |||||
print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num)) | |||||
# if e % self.print_epoch == 0: | |||||
# loss_num = loss.item() | |||||
# print("Epoch: {}\tLoss: {:.4f}".format(e, loss_num)) | |||||
# do evaluation on specific epoch | # do evaluation on specific epoch | ||||
if e % self.eval_epoch == 0 and e != 0: | if e % self.eval_epoch == 0 and e != 0: | ||||
print('Epoch {} Validating...'.format(e)) | print('Epoch {} Validating...'.format(e)) | ||||
# print current temp data dynamically | # print current temp data dynamically | ||||
for k in data.keys(): | for k in data.keys(): | ||||
temp[k] = data[k] / t | temp[k] = data[k] / t | ||||
sys.stdout.write("{}\tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( | |||||
t, temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1'])) | |||||
sys.stdout.flush() | |||||
# sys.stdout.write("{}\tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( | |||||
# t, temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1'])) | |||||
# sys.stdout.flush() | |||||
# print overall evaluation result and return it | # print overall evaluation result and return it | ||||
for k in data.keys(): | for k in data.keys(): | ||||
if istest: | if istest: | ||||
print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( | |||||
print("TEST: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( | |||||
temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1'])) | temp['MRR'], temp['NDCG@10'], temp['NDCG@5'], temp['NDCG@1'], temp['Hits@10'], temp['Hits@5'], temp['Hits@1'])) | ||||
else: | else: | ||||
print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( | print("VALID: \tMRR: {:.3f}\tNDCG@10: {:.3f}\tNDCG@5: {:.3f}\tNDCG@1: {:.3f}\tHits@10: {:.3f}\tHits@5: {:.3f}\tHits@1: {:.3f}\r".format( |
user_train = defaultdict(list) | user_train = defaultdict(list) | ||||
# assume user/item index starting from 1 | # assume user/item index starting from 1 | ||||
f = open('data/%s/%s_train.csv' % (fname, fname), 'r') | |||||
f = open('/home/maheri/metaTL/data/%s/%s_train.csv' % (fname, fname), 'r') | |||||
for line in f: | for line in f: | ||||
u, i, t = line.rstrip().split('\t') | u, i, t = line.rstrip().split('\t') | ||||
u = int(u) | u = int(u) | ||||
User_test_new = defaultdict(list) | User_test_new = defaultdict(list) | ||||
f = open('data/%s/%s_test_new_user.csv' % (fname, fname), 'r') | |||||
f = open('/home/maheri/metaTL/data/%s/%s_test_new_user.csv' % (fname, fname), 'r') | |||||
for line in f: | for line in f: | ||||
u, i, t = line.rstrip().split('\t') | u, i, t = line.rstrip().split('\t') | ||||
u = int(u) | u = int(u) |