import os
from functools import partial

import numpy as np
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

from hyper_tunning import load_data, train_melu
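
# Entry point for the MeLU hyperparameter search: samples configurations for
# train_melu (defined in hyper_tunning.py) and schedules the trials with ASHA.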
def main(num_samples, max_num_epochs=20, gpus_per_trial=2):
    data_dir = os.path.abspath("/media/external_10TB/10TB/maheri/melu_data5")
    # Load the data once up front so problems with the pickles surface before
    # any trial is launched.
    load_data(data_dir)
    config = {
        # "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        # "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        # "lr": tune.loguniform(1e-4, 1e-1),
        # "batch_size": tune.choice([2, 4, 8, 16])
        "transformer": tune.choice(['kronoker']),
        "meta_algo": tune.choice(['gbml']),
        "first_order": tune.choice([False]),
        "adapt_transform": tune.choice([True, False]),
        # "local_lr": tune.choice([5e-6, 5e-4, 5e-3]),
        # "lr": tune.choice([5e-5, 5e-4]),
        "local_lr": tune.loguniform(5e-6, 5e-3),
        "lr": tune.loguniform(5e-5, 5e-3),
        "batch_size": tune.choice([16, 32, 64]),
        "inner": tune.choice([7, 5, 4, 3, 1]),
        "test_state": tune.choice(["user_and_item_cold_state"]),
        # "epochs": tune.choice([5, 10, 20, 25]),
    }
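
    # ASHA stops poorly performing trials early; "loss" is the validation loss
    # that train_melu reports back through tune.report() after each epoch.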
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=6,
        reduction_factor=2)
    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["loss", "ndcg1", "ndcg3", "training_iteration"])
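
    # Each sample draws one configuration from the search space above and runs
    # train_melu as a trial; results are written under
    # ./hyper_tunning_all_cold/melu_all_cold.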
    result = tune.run(
        partial(train_melu, data_dir=data_dir),
        resources_per_trial={"cpu": 4, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        log_to_file=True,
        # resume=True,
        local_dir="./hyper_tunning_all_cold",
        name="melu_all_cold",
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation ndcg1: {}".format(
        best_trial.last_result["ndcg1"]))
    print("Best trial final validation ndcg3: {}".format(
        best_trial.last_result["ndcg3"]))

    print("=======================================================")
    print(result.results_df)
    print("=======================================================\n")
    # best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    # device = "cpu"
    # if torch.cuda.is_available():
    #     device = "cuda:0"
    #     if gpus_per_trial > 1:
    #         best_trained_model = nn.DataParallel(best_trained_model)
    # best_trained_model.to(device)
    #
    # best_checkpoint_dir = best_trial.checkpoint.value
    # model_state, optimizer_state = torch.load(os.path.join(
    #     best_checkpoint_dir, "checkpoint"))
    # best_trained_model.load_state_dict(model_state)
    #
    # test_acc = test_accuracy(best_trained_model, device)
    # print("Best trial test set accuracy: {}".format(test_acc))

if __name__ == "__main__":
    # You can change the number of GPUs per trial here:
    main(num_samples=150, max_num_epochs=25, gpus_per_trial=1)
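
# ------------------------------------------------------------------------------
# hyper_testing.py -- defines hyper_test, imported below via
# `from hyper_testing import hyper_test`
# ------------------------------------------------------------------------------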
import random

import numpy as np
from torch.nn import L1Loss
from sklearn.metrics import ndcg_score

from fast_adapt import fast_adapt
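
# Evaluate the current embedding + head on a set of tasks: adapt a clone of the
# head on each task's support set, then score its query set with L1 loss and
# nDCG@1 / nDCG@3.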
def hyper_test(embedding, head, total_dataset, adaptation_step):
    test_set_size = len(total_dataset)
    random.shuffle(total_dataset)
    a, b, c, d = zip(*total_dataset)
    losses_q = []
    ndcgs11 = []
    ndcgs33 = []

    for iterator in range(test_set_size):
        try:
            supp_xs = a[iterator].cuda()
            supp_ys = b[iterator].cuda()
            query_xs = c[iterator].cuda()
            query_ys = d[iterator].cuda()
        except IndexError:
            print("index error in test method")
            continue

        # adapt a fresh clone of the meta-learned head on this task's support set
        learner = head.clone()
        temp_sxs = embedding(supp_xs)
        temp_qxs = embedding(query_xs)
        evaluation_error, predictions = fast_adapt(learner,
                                                   temp_sxs,
                                                   temp_qxs,
                                                   supp_ys,
                                                   query_ys,
                                                   adaptation_step,
                                                   get_predictions=True)

        l1 = L1Loss(reduction='mean')
        loss_q = l1(predictions.view(-1), query_ys)
        losses_q.append(float(loss_q))

        predictions = predictions.view(-1)
        y_true = query_ys.cpu().detach().numpy()
        y_pred = predictions.cpu().detach().numpy()
        # ndcg_score expects 2-D arrays: one row per ranked list
        ndcgs11.append(float(ndcg_score([y_true], [y_pred], k=1, sample_weight=None, ignore_ties=False)))
        ndcgs33.append(float(ndcg_score([y_true], [y_pred], k=3, sample_weight=None, ignore_ties=False)))

        del supp_xs, supp_ys, query_xs, query_ys, predictions, y_true, y_pred, loss_q
    # average the metrics over all evaluated tasks; fall back to sentinel values
    # if nothing was evaluated (np.mean([]) would silently return NaN)
    loss_mean = float(np.mean(losses_q)) if losses_q else 100.0
    ndcg1 = float(np.mean(ndcgs11)) if ndcgs11 else 0.0
    ndcg3 = float(np.mean(ndcgs33)) if ndcgs33 else 0.0
    return loss_mean, ndcg1, ndcg3
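
# ------------------------------------------------------------------------------
# hyper_tunning.py -- defines load_data and train_melu, imported by the tuning
# entry script above
# ------------------------------------------------------------------------------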
import os
import gc
import pickle
import random

import torch
import torch.nn as nn
import learn2learn as l2l
from learn2learn.optim.transforms import KroneckerTransform
from ray import tune

from options import config
from embedding_module import EmbeddingModule
from fast_adapt import fast_adapt
from hyper_testing import hyper_test

# Define paths (for data)
master_path = "/media/external_10TB/10TB/maheri/melu_data5"
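
# Each task is stored as four pickles per index under <data_dir>/<state>/:
# supp_x_<i>.pkl, supp_y_<i>.pkl, query_x_<i>.pkl, query_y_<i>.pkl
# (hence the division by 4 when counting tasks below).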
def load_data(data_dir=master_path, test_state='warm_state'):
    # training tasks always come from warm_state
    training_set_size = int(len(os.listdir("{}/warm_state".format(data_dir))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(training_set_size):
        supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(data_dir, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(data_dir, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(data_dir, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(data_dir, idx), "rb")))
    total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    trainset = total_dataset

    # evaluation tasks come from the requested test_state; 20% are held out as a
    # validation split for the scheduler, the rest form the test split
    test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(test_set_size):
        supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)

    random.shuffle(test_dataset)
    val_size = int(test_set_size * 0.2)
    validationset = test_dataset[:val_size]
    testset = test_dataset[val_size:]
    return trainset, validationset, testset
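
# Ray Tune trainable. `conf` is the sampled hyperparameter configuration;
# `config` (imported from options) supplies the fixed model dimensions.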
def train_melu(conf, checkpoint_dir=None, data_dir=None):
    embedding_dim = config['embedding_dim']
    print("checkpoint_dir:", checkpoint_dir)
    fc1_in_dim = config['embedding_dim'] * 8
    fc2_in_dim = config['first_fc_hidden_dim']
    fc2_out_dim = config['second_fc_hidden_dim']

    # fully connected decision head on top of the embedding module
    fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
    fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
    linear_out = torch.nn.Linear(fc2_out_dim, 1)
    head = torch.nn.Sequential(fc1, fc2, linear_out)
    emb = EmbeddingModule(config).cuda()

    # define the gradient transform used by GBML
    transform = None
    if conf['transformer'] == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    # define meta algorithm
    if conf['meta_algo'] == "maml":
        head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
                                   adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
    head.cuda()

    net = nn.Sequential(emb, head)
    criterion = nn.MSELoss()
    # the outer-loop optimizer updates the embedding module and the meta-learned head jointly
    all_parameters = list(emb.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])
    # resume from a Tune checkpoint if one is provided
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # load data for the requested test_state
    train_dataset, validation_dataset, test_dataset = load_data(data_dir, test_state=conf['test_state'])
    print(conf['test_state'])
    batch_size = conf['batch_size']
    num_batch = int(len(train_dataset) / batch_size)
    a, b, c, d = zip(*train_dataset)
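
    # Meta-training loop: for every batch of tasks, clone the head, adapt the
    # clone on the support set (inner loop), back-propagate the query-set error,
    # then average the accumulated gradients and take one outer-loop Adam step.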
    for epoch in range(config['num_epoch']):  # loop over the dataset multiple times
        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0
            # print("EPOCH: ", epoch, " BATCH: ", i)
            supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
            supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
            query_xs = list(c[batch_size * i:batch_size * (i + 1)])
            query_ys = list(d[batch_size * i:batch_size * (i + 1)])
            batch_sz = len(supp_xs)

            # iterate over all tasks in the batch
            for task in range(batch_sz):
                sxs = supp_xs[task].cuda()
                qxs = query_xs[task].cuda()
                sys = supp_ys[task].cuda()
                qys = query_ys[task].cuda()

                learner = head.clone()
                temp_sxs = emb(sxs)
                temp_qxs = emb(qxs)
                evaluation_error = fast_adapt(learner,
                                              temp_sxs,
                                              temp_qxs,
                                              sys,
                                              qys,
                                              conf['inner'])
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()

                del (sxs, qxs, sys, qys)
                supp_xs[task].cpu()
                query_xs[task].cpu()
                supp_ys[task].cpu()
                query_ys[task].cpu()

            # Average the accumulated gradients and optimize (parameters are updated once per batch)
            for p in all_parameters:
                p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

            del (supp_xs, supp_ys, query_xs, query_ys)
            gc.collect()
        # evaluate on the validation split, checkpoint, and report back to Tune
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, head, validation_dataset, adaptation_step=conf['inner'])

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)

    print("Finished Training")
                     help='MAML/MetaSGD/GBML')
 parser.add_argument('--gpu', type=int, default=0,
                     help='number of gpu to run the code')
+parser.add_argument('--epochs', type=int, default=config['num_epoch'],
+                    help='number of meta-training epochs')

 # META LEARNING
 print("META LEARNING PHASE")
-# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=True)
 # define transformer
 transform = None
 # define meta algorithm
 if args.meta_algo == "maml":
-    head = l2l.algorithms.MAML(head, lr=config['local_lr'], first_order=args.first_order)
+    head = l2l.algorithms.MAML(head, lr=args.lr_inner, first_order=args.first_order)
 elif args.meta_algo == 'metasgd':
-    head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=args.first_order)
+    head = l2l.algorithms.MetaSGD(head, lr=args.lr_inner, first_order=args.first_order)
 elif args.meta_algo == 'gbml':
-    head = l2l.algorithms.GBML(head, transform=transform, lr=config['local_lr'], adapt_transform=args.adapt_transform, first_order=args.first_order)
+    head = l2l.algorithms.GBML(head, transform=transform, lr=args.lr_inner, adapt_transform=args.adapt_transform, first_order=args.first_order)
 if use_cuda:
     head.cuda()

 # Setup optimization
 print("SETUP OPTIMIZATION PHASE")
 all_parameters = list(emb.parameters()) + list(head.parameters())
-optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
+optimizer = torch.optim.Adam(all_parameters, lr=args.lr_meta)
 # loss = torch.nn.MSELoss(reduction='mean')

 # Load training dataset.
 a, b, c, d = zip(*total_dataset)
 print("\n\n\n")
-for iteration in range(config['num_epoch']):
+for iteration in range(args.epochs):
     for i in range(num_batch):
         optimizer.zero_grad()
         meta_train_error = 0.0

 del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
 print("===================== " + test_state + " =====================")
-test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], adaptation_step=args.inner_eval)
+test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=args.epochs, adaptation_step=args.inner_eval)
 print("===================================================\n\n\n")
 print(args)

                                   temp_qxs,
                                   supp_ys,
                                   query_ys,
-                                  config['inner'],
+                                  # config['inner'],
+                                  adaptation_step,
                                   get_predictions=True)
 l1 = L1Loss(reduction='mean')

 print("nDCG3: ", np.array(ndcgs33).mean())
 # print("nDCG3: ", np.array(ndcgs333).mean())
+print("is there nan? " + str(np.any(np.isnan(ndcgs11))))
+print("is there nan? " + str(np.any(np.isnan(ndcgs33))))
+print("is there nan? " + str(np.any(np.isnan(losses_q))))