import os
import torch
import torch.nn as nn
from ray import tune
import pickle
from options import config
from embedding_module import EmbeddingModule
import learn2learn as l2l
import random
from fast_adapt import fast_adapt
import gc
from learn2learn.optim.transforms import KroneckerTransform
from hyper_testing import hyper_test
from clustering import Trainer

# Example data_dir used previously:
# master_path = "/media/external_10TB/10TB/maheri/melu_data5"

# Each task on disk is stored as four pickles, in this load order.
_TASK_PARTS = ("supp_x", "supp_y", "query_x", "query_y")


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only point data_dir at trusted, locally generated data.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_split(data_dir, state):
    """Load every task stored under ``{data_dir}/{state}``.

    The number of tasks is the directory's file count divided by the number of
    parts per task. Task ``idx`` is read from ``{part}_{idx}.pkl`` for each part
    in ``_TASK_PARTS``.

    Returns:
        list of ``(supp_x, supp_y, query_x, query_y)`` tuples ordered by task index.
    """
    split_dir = "{}/{}".format(data_dir, state)
    num_tasks = len(os.listdir(split_dir)) // len(_TASK_PARTS)
    return [
        tuple(
            _load_pickle("{}/{}_{}.pkl".format(split_dir, part, idx))
            for part in _TASK_PARTS
        )
        for idx in range(num_tasks)
    ]


def load_data(data_dir=None, test_state='warm_state'):
    """Load the training tasks and split the ``test_state`` tasks into validation/test.

    Args:
        data_dir: root directory containing the ``warm_state`` directory and the
            ``test_state`` directory of pickled tasks.
        test_state: name of the directory whose tasks form the validation/test data.

    Returns:
        (trainset, validationset, testset): lists of
        ``(supp_x, supp_y, query_x, query_y)`` tuples. Both lists are shuffled
        in place. The first 20% of the shuffled test tasks become the
        validation set and the rest the test set.
    """
    trainset = _load_split(data_dir, "warm_state")
    # NOTE(review): with the default test_state='warm_state' the validation/test
    # tasks are read from the same directory as the training tasks — confirm
    # this overlap is intended.
    test_dataset = _load_split(data_dir, test_state)

    # Shuffle order (test first, then train) kept so RNG consumption is unchanged.
    random.shuffle(test_dataset)
    random.shuffle(trainset)

    val_size = int(len(test_dataset) * 0.2)
    validationset = test_dataset[:val_size]
    testset = test_dataset[val_size:]
    return trainset, validationset, testset


def train_melu(conf, checkpoint_dir=None, data_dir=None):
    """Run one Ray Tune trial: meta-train the embedding module and the clustering Trainer.

    Args:
        conf: hyper-parameter dict for this trial. Keys read here: 'transformer',
            'meta_algo', 'local_lr', 'first_order', 'adapt_transform' (gbml only),
            'lr', 'test_state', 'batch_size', 'inner'.
        checkpoint_dir: Ray Tune checkpoint directory. Restoring is not implemented.
        data_dir: root directory of the pickled task splits (see load_data).

    After every epoch, the validation loss, ndcg1 and ndcg3 from hyper_test are
    reported through tune.report.
    """
    print("inajm1:", checkpoint_dir)

    emb = EmbeddingModule(config).cuda()

    # Optional learned gradient transform (only consumed by GBML).
    transform = None
    if conf['transformer'] == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(config)
    # Wrap the base learner in the selected meta-learning algorithm.
    # NOTE(review): any other 'meta_algo' value leaves trainer unwrapped, and the
    # trainer.clone() call below presumably fails — consider raising early.
    if conf['meta_algo'] == "maml":
        trainer = l2l.algorithms.MAML(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        trainer = l2l.algorithms.MetaSGD(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=conf['local_lr'],
                                      adapt_transform=conf['adapt_transform'],
                                      first_order=conf['first_order'])
    trainer.cuda()

    # One optimizer over both the embedding and the meta-learner parameters.
    all_parameters = list(emb.parameters()) + list(trainer.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])

    if checkpoint_dir:
        # TODO: checkpoint restore/save is disabled; the previous (commented-out)
        # implementation referenced a `net` module that no longer exists.
        print("in checkpoint - bug happened")

    train_dataset, validation_dataset, _ = load_data(data_dir, test_state=conf['test_state'])

    batch_size = conf['batch_size']
    # Trailing tasks that do not fill a complete batch are skipped every epoch.
    num_batch = len(train_dataset) // batch_size
    supp_x_all, supp_y_all, query_x_all, query_y_all = zip(*train_dataset)

    for epoch in range(config['num_epoch']):  # loop over the dataset multiple times
        for i in range(num_batch):
            optimizer.zero_grad()

            batch = slice(batch_size * i, batch_size * (i + 1))
            supp_xs = supp_x_all[batch]
            supp_ys = supp_y_all[batch]
            query_xs = query_x_all[batch]
            query_ys = query_y_all[batch]
            batch_sz = len(supp_xs)

            # Accumulate meta-gradients over all tasks of the batch.
            for task in range(batch_sz):
                # .cuda() returns GPU copies; the CPU tensors in the dataset are untouched,
                # so nothing needs to be moved back afterwards.
                sxs = supp_xs[task].cuda()
                qxs = query_xs[task].cuda()
                sys_ = supp_ys[task].cuda()
                qys = query_ys[task].cuda()

                learner = trainer.clone()
                temp_sxs = emb(sxs)
                temp_qxs = emb(qxs)
                evaluation_error = fast_adapt(learner, temp_sxs, temp_qxs, sys_, qys, conf['inner'])
                evaluation_error.backward()
                del sxs, qxs, sys_, qys

            # Average the accumulated gradients and take one optimizer step per batch.
            # Parameters that received no gradient have grad None and are skipped.
            for p in all_parameters:
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

            del supp_xs, supp_ys, query_xs, query_ys
            gc.collect()

        # Evaluate on the validation tasks after every epoch.
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, trainer, validation_dataset,
                                                    adaptation_step=conf['inner'])
        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)

    print("Finished Training")