import gc
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn

import learn2learn as l2l
from learn2learn.optim.transforms import KroneckerTransform
from ray import tune
from sklearn.cluster import KMeans

from clustering import Trainer
from embedding_module import EmbeddingModule
from fast_adapt import fast_adapt
from Head import Head
from hyper_testing import hyper_test

# Define paths (for data)
# master_path = "/media/external_10TB/10TB/maheri/melu_data5"


def load_data(data_dir=None, test_state='warm_state'):
    # Training-set loading (the same four-pickle loop, fixed to "warm_state")
    # is intentionally disabled here; only the held-out split is materialised.
    test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(test_set_size):
        supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del supp_xs_s, supp_ys_s, query_xs_s, query_ys_s
    random.shuffle(test_dataset)

    # Hold out 30% of the tasks as a validation set; the remainder
    # (the commented-out testset) is unused here.
    val_size = int(test_set_size * 0.3)
    validationset = test_dataset[:val_size]
    # testset = test_dataset[val_size:]
    return None, validationset, None


def data_batching_new(indexes, C_distribs, batch_size, training_set_size, num_clusters, config):
    # Sharpen each task's soft cluster distribution with a temperature-like
    # power, then renormalise row-wise.
    probs = np.squeeze(C_distribs)
    probs = np.array(probs) ** config['distribution_power'] / np.sum(
        np.array(probs) ** config['distribution_power'], axis=1, keepdims=True)
    # Draw one hard cluster assignment per task.
    cs = [np.random.choice(num_clusters, p=i) for i in probs]

    num_batch = int(training_set_size / batch_size)
    res = [[] for i in range(num_batch)]
    clas = [[] for i in range(num_clusters)]
    clas_temp = [[] for i in range(num_clusters)]

    for idx, c in zip(indexes, cs):
        clas[c].append(idx)
    for i in range(num_clusters):
        random.shuffle(clas[i])

    # Cluster sampling weights: cluster size raised to a configurable power.
    # t = np.array([len(i) for i in clas])
    t = np.array([len(i) ** config['data_selection_pow'] for i in clas])
    t = t / t.sum()

    # Tasks that were dropped from the previous epoch's full batches.
    dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)]))
    cnt = 0
    for i in range(len(res)):
        for j in range(batch_size):
            temp = np.random.choice(num_clusters, p=t)
            if len(clas[temp]) > 0:
                selected = clas[temp].pop(0)
                res[i].append(selected)
                clas_temp[temp].append(selected)
            else:
                # Cluster exhausted: fall back to a leftover task, or re-use a
                # task already drawn from this cluster.
                # res[i].append(indexes[training_set_size-1-cnt])
                if len(dif) > 0:
                    if random.random() < 0.5 or len(clas_temp[temp]) == 0:
                        res[i].append(dif.pop(0))
                    else:
                        selected = clas_temp[temp].pop(0)
                        clas_temp[temp].append(selected)
                        res[i].append(selected)
                else:
                    selected = clas_temp[temp].pop(0)
                    res[i].append(selected)
                cnt = cnt + 1

    print("data_batching : ", cnt)
    res = np.random.permutation(res)
    final_result = np.array(res).flatten()
    return final_result
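
# A minimal sketch (not called anywhere in the pipeline) of the distribution
# sharpening performed at the top of data_batching_new; the probabilities
# below are made-up placeholders.
def _demo_distribution_sharpening(power=2.0):
    # Two tasks with soft assignments over three clusters.
    probs = np.array([[0.5, 0.3, 0.2],
                      [0.1, 0.1, 0.8]])
    # power > 1 concentrates mass on each task's dominant cluster;
    # power < 1 flattens the distribution; power == 1 is a no-op.
    sharpened = probs ** power / np.sum(probs ** power, axis=1, keepdims=True)
    print(sharpened)  # e.g. the first row becomes roughly [0.66, 0.24, 0.11]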
def train_melu(conf, checkpoint_dir=None, data_dir=None):
    config = conf
    master_path = data_dir

    emb = EmbeddingModule(conf).cuda()

    # Optional parameter transform for GBML.  The "kronoker" spelling is kept
    # as-is because it matches the config values used elsewhere.
    transform = None
    if conf['transformer'] == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(conf)
    trainer.cuda()

    head = Head(config)
    # Wrap the head with the chosen meta-learning algorithm.
    if conf['meta_algo'] == "maml":
        head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
                                   adapt_transform=conf['adapt_transform'],
                                   first_order=conf['first_order'])
    head.cuda()

    criterion = nn.MSELoss()
    all_parameters = list(emb.parameters()) + list(trainer.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])

    # Load training dataset (each task is stored as four pickles).
    print("LOAD DATASET PHASE")
    training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)

    if checkpoint_dir:
        # Checkpoint restoring is not implemented yet.
        print("in checkpoint - bug happened")

    # _, validation_dataset, _ = load_data(data_dir, test_state=conf['test_state'])
    batch_size = conf['batch_size']

    C_distribs = []
    indexes = list(np.arange(training_set_size))

    for iteration in range(conf['num_epoch']):  # loop over the dataset multiple times
        print("iteration:", iteration)
        num_batch = int(training_set_size / batch_size)

        if iteration == 0:
            # Re-initialise the cluster centroids with k-means over the task
            # embeddings of the first 2500 training tasks.
            print("changing cluster centroids started ...")
            indexes = list(np.arange(training_set_size))
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(0, 2500):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))

            batch_sz = len(supp_xs)
            user_embeddings = []
            for task in range(batch_sz):
                # Embed the support set and take the task-level mean embedding.
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                temp_sxs = emb(supp_xs[task])
                y = supp_ys[task].view(-1, 1)
                _, mean_task, _ = trainer.cluster_module(temp_sxs, y)
                user_embeddings.append(mean_task.detach().cpu().numpy())
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()

            user_embeddings = np.array(user_embeddings)
            kmeans_model = KMeans(n_clusters=conf['cluster_k'], init="k-means++").fit(user_embeddings)
            trainer.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()

        if iteration > 0:
            # Rebuild the batch order from the previous epoch's cluster distributions.
            indexes = data_batching_new(indexes, C_distribs, batch_size, training_set_size,
                                        conf['cluster_k'], conf)
        else:
            random.shuffle(indexes)
        C_distribs = []
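        # Meta-training over one epoch: each batch loads its tasks from disk,
        # every task adapts a clone of the head on its support set, the
        # query-set loss is backpropagated into the shared parameters, and the
        # accumulated gradients are averaged over the batch before the
        # meta-update below.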
        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0
            meta_cluster_error = 0.0

            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(batch_size * i, batch_size * (i + 1)):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            for task in range(batch_sz):
                # Compute the meta-training loss for one task.
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                query_xs[task] = query_xs[task].cuda()
                query_ys[task] = query_ys[task].cuda()

                learner = head.clone()
                temp_sxs = emb(supp_xs[task])
                temp_qxs = emb(query_xs[task])

                evaluation_error, c, K_LOSS = fast_adapt(learner,
                                                         temp_sxs,
                                                         temp_qxs,
                                                         supp_ys[task],
                                                         query_ys[task],
                                                         conf['inner'],
                                                         trainer=trainer,
                                                         test=False,
                                                         iteration=iteration)
                C_distribs.append(c.detach().cpu().numpy())
                meta_cluster_error += K_LOSS
                evaluation_error.backward(retain_graph=True)
                meta_train_error += evaluation_error.item()

                # Move tensors back to the CPU to free GPU memory.
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()
                query_xs[task] = query_xs[task].cpu()
                query_ys[task] = query_ys[task].cpu()

            # Print some metrics.
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / batch_sz)
            print('KL Train Error', round(meta_cluster_error / batch_sz, 4), "\t", C_distribs[-1])

            # Average the accumulated gradients and optimize.  The None-guard
            # was previously commented out; parameters that never received a
            # gradient would otherwise raise an AttributeError here.
            for p in all_parameters:
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

        # Test results on the validation data after each epoch.
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, head, trainer, batch_size, master_path,
                                                    conf['test_state'],
                                                    adaptation_step=conf['inner'],
                                                    num_epoch=iteration)
        # Checkpointing is currently disabled:
        # with tune.checkpoint_dir(iteration) as checkpoint_dir:
        #     path = os.path.join(checkpoint_dir, "checkpoint")
        #     torch.save((head.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)

    print("Finished Training")
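
# A minimal launch sketch; this __main__ block is an assumption and not part
# of the original script.  The config keys mirror the ones read above, but
# all values are illustrative placeholders, "/path/to/melu_data" stands in
# for the real dataset directory, and EmbeddingModule / Trainer / Head may
# require additional keys not listed here.
if __name__ == "__main__":
    from functools import partial

    sample_config = {
        'transformer': 'linear',       # "kronoker" | "linear"
        'meta_algo': 'maml',           # "maml" | "metasgd" | "gbml"
        'first_order': False,
        'adapt_transform': False,
        'lr': 5e-5,                    # meta (outer-loop) learning rate
        'local_lr': 5e-6,              # adaptation (inner-loop) learning rate
        'inner': 1,                    # number of adaptation steps
        'batch_size': 32,
        'num_epoch': 30,
        'cluster_k': 7,
        'distribution_power': 1.0,
        'data_selection_pow': 1.0,
        'test_state': 'warm_state',
    }
    tune.run(
        partial(train_melu, data_dir="/path/to/melu_data"),
        config=sample_config,
        resources_per_trial={"cpu": 4, "gpu": 1},
        num_samples=1,
    )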