"""Meta-train the clustering MeLU model on the warm-state split and periodically
evaluate it on the user/item cold-start splits."""
import argparse
import gc
import os
import pickle
import random
from collections import deque

import learn2learn as l2l
import numpy as np
import torch
from learn2learn.optim.transforms import KroneckerTransform
from torch.nn import functional as F

from clustering import ClustringModule, Trainer
from data_generation import generate
from embedding_module import EmbeddingModule
from fast_adapt import fast_adapt
from learnToLearnTest import test
from options import config


def data_batching(indexes, C_distribs, batch_size, training_set_size, num_clusters):
    """Build a cluster-aware ordering of training-task indexes.

    Each task in ``indexes`` is hard-assigned to one cluster by sampling from its
    row of ``C_distribs``. Batches are then filled by repeatedly picking a cluster
    with probability proportional to its initial size and taking that cluster's
    next task. When the picked cluster is exhausted, a random task from outside
    the first ``num_batch * batch_size`` entries of ``indexes`` is used (it may
    repeat). If there are no such leftover tasks, the next task of a random
    non-empty cluster is used instead.

    Args:
        indexes: task indexes, aligned element-wise with the rows of ``C_distribs``.
        C_distribs: per-task cluster probability vectors; assumes ``np.squeeze``
            turns them into a (num_tasks, num_clusters) array -- TODO confirm.
        batch_size: number of tasks per batch.
        training_set_size: total number of training tasks.
        num_clusters: number of clusters.

    Returns:
        1-D numpy array of ``(training_set_size // batch_size) * batch_size`` task
        indexes; whole batches are shuffled with respect to each other.

    Raises:
        ValueError: if the clusters run out of tasks before all batches are full.
    """
    probs = np.squeeze(C_distribs)
    # Hard-assign every task to one cluster by sampling from its soft distribution.
    cs = [np.random.choice(num_clusters, p=p) for p in probs]
    num_batch = training_set_size // batch_size

    clas = [deque() for _ in range(num_clusters)]
    for idx, c in zip(indexes, cs):
        clas[c].append(idx)

    # Cluster-sampling weights, fixed from the initial cluster sizes.
    t = np.array([len(members) for members in clas])
    t = t / t.sum()

    # Tasks that sequential batching of `indexes` would leave out.
    dif = list(set(range(training_set_size)) - set(indexes[:num_batch * batch_size]))

    res = [[] for _ in range(num_batch)]
    for batch in res:
        for _ in range(batch_size):
            temp = np.random.choice(num_clusters, p=t)
            if clas[temp]:
                batch.append(clas[temp].popleft())
            elif dif:
                batch.append(random.choice(dif))
            else:
                # `dif` is empty whenever training_set_size is a multiple of
                # batch_size, so draw from any cluster that still has tasks.
                non_empty = [members for members in clas if members]
                if not non_empty:
                    raise ValueError(
                        "not enough tasks to fill {} batches of size {}".format(num_batch, batch_size))
                batch.append(random.choice(non_empty).popleft())

    res = np.random.permutation(res)
    return np.array(res).flatten()


def parse_args():
    """Parse the command-line options of this training script.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    print("==============")
    # No positional argument: the first positional parameter of ArgumentParser is
    # `prog`, and passing a list there made the usage line read "usage: []".
    parser = argparse.ArgumentParser(description='Fast Context Adaptation via Meta-Learning (CAVIA), '
                                                 'Classification experiments.')
    print("==============\n")

    # NOTE(review): --lr_meta, --inner and --tasks_per_metaupdate are not read by the
    # training loop in this file (it uses config['lr'], config['inner'] and
    # config['batch_size']); they may still be read by test() through `args`.
    parser.add_argument('--seed', type=int, default=53, help='random seed for python, numpy and torch')
    parser.add_argument('--task', type=str, default='multi', help='problem setting: sine or celeba')
    parser.add_argument('--tasks_per_metaupdate', type=int, default=32,
                        help='number of tasks in each batch per meta-update')

    parser.add_argument('--lr_inner', type=float, default=5e-6, help='inner-loop learning rate (per task)')
    parser.add_argument('--lr_meta', type=float, default=5e-5,
                        help='outer-loop learning rate (used with Adam optimiser)')

    parser.add_argument('--inner', type=int, default=1,
                        help='number of gradient steps in inner loop (during training)')
    parser.add_argument('--inner_eval', type=int, default=1,
                        help='number of gradient updates at test time (for evaluation)')

    parser.add_argument('--first_order', action='store_true', default=False,
                        help='run first order approximation of CAVIA')
    parser.add_argument('--adapt_transform', action='store_true', default=False, help='run adaptation transform')
    parser.add_argument('--transformer', type=str, default="kronoker",
                        help='transformer type used by GBML: "kronoker" or "linear"')
    # Lower-cased so the spellings shown in the help text (MAML/MetaSGD/GBML) are accepted.
    parser.add_argument('--meta_algo', type=str.lower, default="metasgd", choices=['maml', 'metasgd', 'gbml'],
                        help='MAML/MetaSGD/GBML (case-insensitive)')
    parser.add_argument('--gpu', type=int, default=0, help='index of the GPU to expose via CUDA_VISIBLE_DEVICES')
    parser.add_argument('--epochs', type=int, default=config['num_epoch'], help='number of training epochs')

    return parser.parse_args()


def kl_loss(C_distribs):
    """Return the self-training KL loss for a batch of soft cluster assignments.

    With q the assignments and f_j = sum_i q_ij the soft cluster frequencies, the
    target is p_ij = (q_ij^2 / f_j) / sum_j'(q_ij'^2 / f_j') (the target
    distribution of Deep Embedded Clustering). The loss is KL(P || Q) averaged
    over the batch.

    Args:
        C_distribs: list of tensors, one per task, each holding that task's k
            cluster probabilities (assumed shape (k,) or (1, k) -- TODO confirm).

    Returns:
        Scalar loss tensor.
    """
    # batch * k. reshape (not squeeze) so a single-task batch, or k == 1, keeps both axes.
    C_distribs = torch.stack(C_distribs)
    C_distribs = C_distribs.reshape(C_distribs.shape[0], -1)
    C_distribs_sq = torch.pow(C_distribs, 2)
    # 1 * k
    C_distribs_sum = torch.sum(C_distribs, dim=0, keepdim=True)
    # batch * k
    temp = C_distribs_sq / C_distribs_sum
    # batch * 1
    temp_sum = torch.sum(temp, dim=1, keepdim=True)
    target_distribs = temp / temp_sum
    # NOTE(review): target_distribs is not detached, so gradients also flow through the target.
    return F.kl_div(C_distribs.log(), target_distribs, reduction='batchmean')


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    # NOTE: pickle.load can execute arbitrary code; only use on locally generated, trusted data.
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_task(master_path, state, idx, parts=("supp_x", "supp_y", "query_x", "query_y")):
    """Return the pickled pieces of task `idx` in split `state`, in the order given by `parts`."""
    return tuple(_load_pickle("{}/{}/{}_{}.pkl".format(master_path, state, part, idx)) for part in parts)


def _init_cluster_centroids(emb, cluster_module, master_path, indexes, num_clusters, device, max_tasks=2500):
    """Set `cluster_module.array` to KMeans centroids of task embeddings.

    A task embedding is cluster_module.aggregate(cluster_module.input_to_hidden(emb(supp_x)))
    computed over the support set of each of the first `max_tasks` warm-state tasks in `indexes`.
    """
    from sklearn.cluster import KMeans

    user_embeddings = []
    # Pure feature extraction: no autograd graph is needed.
    with torch.no_grad():
        for idx in indexes[:max_tasks]:
            (supp_x,) = _load_task(master_path, "warm_state", idx, parts=("supp_x",))
            task_embed = cluster_module.input_to_hidden(emb(supp_x.to(device)))
            mean_task = cluster_module.aggregate(task_embed)
            user_embeddings.append(mean_task.detach().cpu().numpy())

    kmeans_model = KMeans(n_clusters=num_clusters, init="k-means++").fit(np.array(user_embeddings))
    cluster_module.array.data = torch.as_tensor(kmeans_model.cluster_centers_, dtype=torch.float32).to(device)


if __name__ == '__main__':
    args = parse_args()
    print(args)

    if config['use_cuda']:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Honour --seed so the python/numpy/torch RNGs used below start from a known state.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    master_path = "/media/external_10TB/10TB/maheri/define_task_melu_data2"
    config['master_path'] = master_path

    # DATA GENERATION
    print("DATA GENERATION PHASE")
    if not os.path.exists("{}/".format(master_path)):
        os.makedirs(master_path)
        # preparing dataset. It needs about 22GB of your hard disk space.
        generate(master_path)

    # TRAINING
    print("TRAINING PHASE")
    use_cuda = config['use_cuda']
    device = torch.device("cuda" if use_cuda else "cpu")
    emb = EmbeddingModule(config).to(device)

    # META LEARNING
    print("META LEARNING PHASE")

    # Optional gradient transform; only passed to GBML.
    transform = None
    if args.transformer == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif args.transformer == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    # Keep a handle on the unwrapped Trainer: its cluster_module is accessed
    # directly during centroid initialisation.
    base_trainer = Trainer(config)
    if args.meta_algo == "maml":
        trainer = l2l.algorithms.MAML(base_trainer, lr=args.lr_inner, first_order=args.first_order)
    elif args.meta_algo == 'metasgd':
        trainer = l2l.algorithms.MetaSGD(base_trainer, lr=args.lr_inner, first_order=args.first_order)
    else:  # 'gbml' -- parse_args restricts --meta_algo to these three choices
        trainer = l2l.algorithms.GBML(base_trainer, transform=transform, lr=args.lr_inner,
                                      adapt_transform=args.adapt_transform, first_order=args.first_order)
    trainer.to(device)

    # Setup optimization
    print("SETUP OPTIMIZATION PHASE")
    all_parameters = list(emb.parameters()) + list(trainer.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])

    # Load training dataset.
    print("LOAD DATASET PHASE")
    # Each task is stored as four pickles (supp_x, supp_y, query_x, query_y).
    training_set_size = len(os.listdir("{}/warm_state".format(master_path))) // 4
    batch_size = config['batch_size']
    # Trailing tasks that do not fill a whole batch are skipped each epoch.
    num_batch = training_set_size // batch_size

    print("\n\n\n")
    for iteration in range(args.epochs):
        if iteration == 0:
            print("changing cluster centroids started ...")
            _init_cluster_centroids(emb, base_trainer.cluster_module, master_path,
                                    range(training_set_size), config['cluster_k'], device)

        indexes = list(np.arange(training_set_size))
        random.shuffle(indexes)

        for i in range(num_batch):
            meta_train_error = 0.0
            meta_cluster_error = 0.0
            optimizer.zero_grad()
            print("EPOCH: ", iteration, " BATCH: ", i)

            batch = [_load_task(master_path, "warm_state", idx)
                     for idx in indexes[batch_size * i:batch_size * (i + 1)]]
            batch_sz = len(batch)

            for supp_x, supp_y, query_x, query_y in batch:
                if use_cuda:
                    supp_x, supp_y = supp_x.cuda(), supp_y.cuda()
                    query_x, query_y = query_x.cuda(), query_y.cuda()

                # Compute meta-training loss on a fresh clone of the meta-learner.
                learner = trainer.clone()
                temp_sxs = emb(supp_x)
                temp_qxs = emb(query_x)
                evaluation_error, _, k_loss = fast_adapt(learner, temp_sxs, temp_qxs, supp_y, query_y,
                                                         config['inner'], epoch=iteration)
                # Gradients accumulate over the tasks of the batch and are averaged below.
                evaluation_error.backward(retain_graph=True)
                meta_train_error += evaluation_error.item()
                # NOTE(review): if k_loss is a tensor that requires grad, this keeps its graph
                # alive for the whole batch; consider float(k_loss) once its type is confirmed.
                meta_cluster_error += k_loss

            # Print some metrics
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / batch_sz)
            print('KL Train Error', meta_cluster_error / batch_sz)

            # Average the accumulated gradients and optimize; parameters that received
            # no gradient in this batch have grad None and are skipped.
            for p in all_parameters:
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()
            print("===============================================\n")

        if iteration % 2 == 0 and iteration != 0:
            # testing
            print("start of test phase")
            # NOTE(review): only `trainer` is switched to eval mode; `emb` stays in train mode.
            trainer.eval()

            with open("results2.txt", "a") as f:
                f.write("epoch:{}\n".format(iteration))

            for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
                gc.collect()
                print("===================== " + test_state + " =====================")
                # No pre-loaded dataset is passed (None); test() receives test_state and args instead.
                mse, ndc1, ndc3 = test(emb, trainer, None, batch_size=config['batch_size'],
                                       num_epoch=config['num_epoch'], test_state=test_state, args=args)
                with open("results2.txt", "a") as f:
                    f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
                print("===================================================")
                gc.collect()

            trainer.train()
            with open("results2.txt", "a") as f:
                f.write("\n")
            print("\n\n\n")

    # NOTE(review): nothing is checkpointed -- emb and trainer are discarded when the script exits.