- import os
- import torch
- import torch.nn as nn
- from ray import tune
- import pickle
- from embedding_module import EmbeddingModule
- import learn2learn as l2l
- import random
- from fast_adapt import fast_adapt
- import gc
- from learn2learn.optim.transforms import KroneckerTransform
- from hyper_testing import hyper_test
- from clustering import Trainer
- from Head import Head
- import numpy as np
- from sklearn.cluster import KMeans
-
-
- # Data directory is passed in via `data_dir`; the original hard-coded path is kept for reference:
- # master_path = "/media/external_10TB/10TB/maheri/melu_data5"
-
- def load_data(data_dir=None, test_state='warm_state'):
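-     """Load the `test_state` split, where each task idx is stored as four pickles
-     (supp_x, supp_y, query_x, query_y). The split is shuffled and its first 30% is returned
-     as a validation set; the train and test slots of the returned triple are left as None."""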
- # training_set_size = int(len(os.listdir("{}/warm_state".format(data_dir))) / 4)
- # supp_xs_s = []
- # supp_ys_s = []
- # query_xs_s = []
- # query_ys_s = []
- # for idx in range(training_set_size):
- # supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(data_dir, idx), "rb")))
- # supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(data_dir, idx), "rb")))
- # query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(data_dir, idx), "rb")))
- # query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(data_dir, idx), "rb")))
- # total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
- # del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
- # trainset = total_dataset
-
- test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
- supp_xs_s = []
- supp_ys_s = []
- query_xs_s = []
- query_ys_s = []
- for idx in range(test_set_size):
- supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
- supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
- query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
- query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
- test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
- del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
-
- random.shuffle(test_dataset)
- # random.shuffle(trainset)
- val_size = int(test_set_size * 0.3)
- validationset = test_dataset[:val_size]
- # testset = test_dataset[val_size:]
-
- return None, validationset, None
-
-
- def data_batching_new(indexes, C_distribs, batch_size, training_set_size, num_clusters, config):
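-     """Cluster-aware ordering of training tasks.
-
-     Each task index is assigned to a cluster drawn from its soft cluster distribution in
-     `C_distribs` (re-weighted by config['distribution_power']); batches are then filled by
-     sampling clusters with probability proportional to len(cluster) ** config['data_selection_pow'].
-     When a sampled cluster has no unused tasks left, the slot is back-filled from tasks that
-     missed the previous ordering or from already-used tasks of that cluster.
-     Returns a flat array of task indexes of length num_batch * batch_size."""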
- probs = np.squeeze(C_distribs)
-     probs = probs ** config['distribution_power']
-     probs = probs / probs.sum(axis=1, keepdims=True)
- cs = [np.random.choice(num_clusters, p=i) for i in probs]
- num_batch = int(training_set_size / batch_size)
- res = [[] for i in range(num_batch)]
- clas = [[] for i in range(num_clusters)]
- clas_temp = [[] for i in range(num_clusters)]
-
- for idx, c in zip(indexes, cs):
- clas[c].append(idx)
-
- for i in range(num_clusters):
- random.shuffle(clas[i])
-
- # t = np.array([len(i) for i in clas])
- t = np.array([len(i) ** config['data_selection_pow'] for i in clas])
- t = t / t.sum()
-
-     # Indexes that did not fit into the previous ordering's full batches; used as back-fill.
-     dif = list(set(range(training_set_size)) - set(indexes[0:(num_batch * batch_size)]))
- cnt = 0
-
- for i in range(len(res)):
- for j in range(batch_size):
- temp = np.random.choice(num_clusters, p=t)
- if len(clas[temp]) > 0:
- selected = clas[temp].pop(0)
- res[i].append(selected)
- clas_temp[temp].append(selected)
- else:
- # res[i].append(indexes[training_set_size-1-cnt])
- if len(dif) > 0:
- if random.random() < 0.5 or len(clas_temp[temp]) == 0:
- res[i].append(dif.pop(0))
- else:
- selected = clas_temp[temp].pop(0)
- clas_temp[temp].append(selected)
- res[i].append(selected)
- else:
- selected = clas_temp[temp].pop(0)
- res[i].append(selected)
-
- cnt = cnt + 1
-
- print("data_batching : ", cnt)
- res = np.random.permutation(res)
- final_result = np.array(res).flatten()
- return final_result
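-
- # Illustrative call (hypothetical values): `C_distribs` is expected to hold one soft
- # cluster-assignment row per task, in the same order as `indexes`.
- #
- #   order = data_batching_new(indexes=list(np.arange(1000)),
- #                             C_distribs=C_distribs,   # e.g. 1000 arrays of shape (1, num_clusters)
- #                             batch_size=32, training_set_size=1000, num_clusters=7,
- #                             config={'distribution_power': 1.0, 'data_selection_pow': 1.0})
- #   # `order` is then a flat array of 31 * 32 task indexes, possibly with repeats when a
- #   # sampled cluster runs out of unused tasks.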
-
-
- def train_melu(conf, checkpoint_dir=None, data_dir=None):
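-     """Ray Tune function-trainable: builds the embedding module, the clustering Trainer and the
-     meta-learned Head, meta-trains them on the warm_state tasks with cluster-aware batching,
-     and reports validation loss / NDCG@1 / NDCG@3 back to Tune after every epoch."""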
- config = conf
- master_path = data_dir
-
- emb = EmbeddingModule(conf).cuda()
-
- transform = None
- if conf['transformer'] == "kronoker":
- transform = KroneckerTransform(l2l.nn.KroneckerLinear)
- elif conf['transformer'] == "linear":
- transform = l2l.optim.ModuleTransform(torch.nn.Linear)
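-     # `transform` is only consumed by the GBML branch below; MAML and MetaSGD ignore it.
-     # The config value "kronoker" is matched as spelled, so the search space must use the same string.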
-
- trainer = Trainer(conf)
- trainer.cuda()
-
- head = Head(config)
-
- # define meta algorithm
- if conf['meta_algo'] == "maml":
- head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
- elif conf['meta_algo'] == 'metasgd':
- head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
- elif conf['meta_algo'] == 'gbml':
- head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
- adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
- head.cuda()
-
- criterion = nn.MSELoss()
- all_parameters = list(emb.parameters()) + list(trainer.parameters()) + list(head.parameters())
- optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])
-
-     # Determine the number of warm_state training tasks; task tensors are loaded lazily,
-     # batch by batch, inside the epoch loop below.
-     print("LOAD DATASET PHASE")
-     training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
-
-     if checkpoint_dir:
-         # Checkpoint restoring is not implemented for this trainable.
-         print("checkpoint_dir was provided but checkpoint loading is not implemented:", checkpoint_dir)
-         # model_state, optimizer_state = torch.load(
-         #     os.path.join(checkpoint_dir, "checkpoint"))
-         # net.load_state_dict(model_state)
-         # optimizer.load_state_dict(optimizer_state)
-
- # loading data
- # _, validation_dataset, _ = load_data(data_dir, test_state=conf['test_state'])
- batch_size = conf['batch_size']
- # num_batch = int(len(train_dataset) / batch_size)
-
- # a, b, c, d = zip(*train_dataset)
-
- C_distribs = []
- indexes = list(np.arange(training_set_size))
- all_test_users = []
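-     # Bookkeeping across epochs: `C_distribs` collects each task's soft cluster assignment during
-     # an epoch and is consumed by data_batching_new to order the next epoch's batches.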
-
- for iteration in range(conf['num_epoch']): # loop over the dataset multiple times
- print("iteration:", iteration)
- num_batch = int(training_set_size / batch_size)
-         if iteration == 0:
-             print("re-initializing cluster centroids from a warm-up pass ...")
-             indexes = list(np.arange(training_set_size))
-             supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
-             # Warm-up sample: embed at most 2500 tasks to seed the k-means centroids.
-             for idx in range(min(2500, training_set_size)):
- supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
- supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
- query_xs.append(
- pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
- query_ys.append(
- pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
- batch_sz = len(supp_xs)
-
- user_embeddings = []
-
- for task in range(batch_sz):
- # Compute meta-training loss
- supp_xs[task] = supp_xs[task].cuda()
- supp_ys[task] = supp_ys[task].cuda()
-
-                 temp_sxs = emb(supp_xs[task])
-                 y = supp_ys[task].view(-1, 1)
-                 # The cluster module is assumed to return (assignment, per-task mean embedding, extra);
-                 # only the mean embedding is kept for k-means.
-                 _, mean_task, _ = trainer.cluster_module(temp_sxs, y)
-                 user_embeddings.append(mean_task.detach().cpu().numpy())
-
- supp_xs[task] = supp_xs[task].cpu()
- supp_ys[task] = supp_ys[task].cpu()
-
-             # Seed the clustering module: k-means over the warm-up tasks' mean embeddings
-             # overwrites the cluster module's centroid array.
-             user_embeddings = np.array(user_embeddings).reshape(len(user_embeddings), -1)
-             kmeans_model = KMeans(n_clusters=conf['cluster_k'], init="k-means++").fit(user_embeddings)
-             trainer.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()
-
-         if iteration > 0:
- indexes = data_batching_new(indexes, C_distribs, batch_size, training_set_size, conf['cluster_k'], conf)
- else:
- random.shuffle(indexes)
-
- C_distribs = []
-
- for i in range(num_batch):
- optimizer.zero_grad()
- meta_train_error = 0.0
- meta_cluster_error = 0.0
-
- supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
-
- for idx in range(batch_size * i, batch_size * (i + 1)):
- supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
- supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
- query_xs.append(
- pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
- query_ys.append(
- pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
- batch_sz = len(supp_xs)
-
- for task in range(batch_sz):
- # Compute meta-training loss
- supp_xs[task] = supp_xs[task].cuda()
- supp_ys[task] = supp_ys[task].cuda()
- query_xs[task] = query_xs[task].cuda()
- query_ys[task] = query_ys[task].cuda()
-
- learner = head.clone()
- temp_sxs = emb(supp_xs[task])
- temp_qxs = emb(query_xs[task])
-
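-                 # fast_adapt appears to return (query-set loss after `inner` adaptation steps,
-                 # the task's soft cluster distribution, and a cluster/KL regularization term).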
- evaluation_error, c, K_LOSS = fast_adapt(learner,
- temp_sxs,
- temp_qxs,
- supp_ys[task],
- query_ys[task],
- conf['inner'],
- trainer=trainer,
- test=False,
- iteration=iteration
- )
-
- C_distribs.append(c.detach().cpu().numpy())
- meta_cluster_error += K_LOSS
-
- evaluation_error.backward(retain_graph=True)
- meta_train_error += evaluation_error.item()
-
- supp_xs[task] = supp_xs[task].cpu()
- supp_ys[task] = supp_ys[task].cpu()
- query_xs[task] = query_xs[task].cpu()
- query_ys[task] = query_ys[task].cpu()
- ################################################
-
- # Print some metrics
- print('Iteration', iteration)
- print('Meta Train Error', meta_train_error / batch_sz)
- print('KL Train Error', round(meta_cluster_error / batch_sz, 4), "\t", C_distribs[-1])
-
- # Average the accumulated gradients and optimize
-             for p in all_parameters:
-                 if p.grad is not None:  # some parameters may not receive a gradient every batch
-                     p.grad.data.mul_(1.0 / batch_sz)
- optimizer.step()
-
-         # Evaluate on the `test_state` split after each epoch (loss, NDCG@1, NDCG@3) and report to Tune.
- val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, head, trainer, batch_size, master_path, conf['test_state'],
- adaptation_step=conf['inner'], num_epoch=iteration)
-
- # with tune.checkpoint_dir(epoch) as checkpoint_dir:
- # path = os.path.join(checkpoint_dir, "checkpoint")
- # torch.save((net.state_dict(), optimizer.state_dict()), path)
-
- tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)
-
- print("Finished Training")