|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import os
- import torch
- import torch.nn as nn
- from ray import tune
- import pickle
- # from options import config
- from embedding_module import EmbeddingModule
- import learn2learn as l2l
- import random
- from fast_adapt import fast_adapt
- import gc
- from learn2learn.optim.transforms import KroneckerTransform
- from hyper_testing import hyper_test
- from clustering import Trainer
-
-
- # Define paths (for data)
- # master_path= "/media/external_10TB/10TB/maheri/melu_data5"
-
def load_data(data_dir=None, test_state='warm_state'):
    """Load meta-learning task splits from pickled support/query files.

    Each task is stored on disk as four pickles per index:
    ``supp_x_{i}.pkl``, ``supp_y_{i}.pkl``, ``query_x_{i}.pkl``,
    ``query_y_{i}.pkl``.  Training tasks come from ``{data_dir}/warm_state``;
    evaluation tasks come from ``{data_dir}/{test_state}`` and are split
    30% / 70% into validation and test after shuffling.

    Args:
        data_dir: Root directory holding the per-state task folders.
        test_state: Subfolder name for the evaluation split
            (e.g. 'warm_state', 'user_cold_state').

    Returns:
        (trainset, validationset, testset) — each a list of
        (supp_x, supp_y, query_x, query_y) tuples.
    """

    def _load_split(state):
        # One task = 4 files, so the task count is the file count / 4.
        split_dir = "{}/{}".format(data_dir, state)
        n_tasks = int(len(os.listdir(split_dir)) / 4)
        tasks = []
        for idx in range(n_tasks):
            parts = []
            for prefix in ("supp_x", "supp_y", "query_x", "query_y"):
                # 'with' closes each handle; the original leaked every
                # open file (pickle.load(open(...)) with no close).
                path = "{}/{}_{}.pkl".format(split_dir, prefix, idx)
                with open(path, "rb") as f:
                    parts.append(pickle.load(f))
            tasks.append(tuple(parts))
        return tasks

    trainset = _load_split('warm_state')
    test_dataset = _load_split(test_state)

    # Shuffle both splits, then carve 30% of the evaluation tasks off as
    # a validation set (same ratio as the original implementation).
    random.shuffle(test_dataset)
    random.shuffle(trainset)
    val_size = int(len(test_dataset) * 0.3)
    validationset = test_dataset[:val_size]
    testset = test_dataset[val_size:]

    return trainset, validationset, testset
-
-
def train_melu(conf, checkpoint_dir=None, data_dir=None):
    """Meta-train the recommendation model and report metrics to Ray Tune.

    Builds an embedding module plus a meta-learned trainer (MAML / MetaSGD /
    GBML per ``conf['meta_algo']``), runs ``conf['num_epoch']`` epochs of
    batched episodic training on tasks loaded from ``data_dir``, and after
    each epoch reports validation loss/NDCG via ``tune.report``.

    Args:
        conf: Hyper-parameter dict (embedding dims, lr, local_lr, batch_size,
            meta_algo, transformer, inner, test_state, num_epoch, ...).
        checkpoint_dir: Ray Tune checkpoint directory (restore currently
            disabled — see note below).
        data_dir: Root directory of the pickled task data (see load_data).
    """
    emb = EmbeddingModule(conf).cuda()

    # Optional parameter transform, only consumed by the GBML wrapper below.
    # NOTE: the config value "kronoker" (sic) is kept as-is for
    # backward compatibility with existing config files.
    transform = None
    if conf['transformer'] == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(conf)

    # Wrap the base trainer in the chosen meta-learning algorithm.
    if conf['meta_algo'] == "maml":
        trainer = l2l.algorithms.MAML(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        trainer = l2l.algorithms.MetaSGD(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=conf['local_lr'],
                                      adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
    trainer.cuda()

    # One optimizer over both the embedding and the meta-learner parameters.
    all_parameters = list(emb.parameters()) + list(trainer.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])

    if checkpoint_dir:
        # Checkpoint restore is intentionally disabled; it previously
        # crashed (TODO: re-enable once the state-dict layout is fixed).
        print("in checkpoint - bug happened")
        # model_state, optimizer_state = torch.load(
        #     os.path.join(checkpoint_dir, "checkpoint"))
        # net.load_state_dict(model_state)
        # optimizer.load_state_dict(optimizer_state)

    train_dataset, validation_dataset, test_dataset = load_data(data_dir, test_state=conf['test_state'])
    batch_size = conf['batch_size']
    # Tail tasks that don't fill a whole batch are dropped, as before.
    num_batch = int(len(train_dataset) / batch_size)

    # Unzip once so each batch is four parallel slices.
    a, b, c, d = zip(*train_dataset)

    for epoch in range(conf['num_epoch']):  # loop over the dataset multiple times
        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0

            supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
            supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
            query_xs = list(c[batch_size * i:batch_size * (i + 1)])
            query_ys = list(d[batch_size * i:batch_size * (i + 1)])
            batch_sz = len(supp_xs)

            # Inner loop: adapt a clone per task and accumulate gradients.
            for task in range(batch_sz):
                sxs = supp_xs[task].cuda()
                qxs = query_xs[task].cuda()
                sys = supp_ys[task].cuda()
                qys = query_ys[task].cuda()

                learner = trainer.clone()
                temp_sxs = emb(sxs)
                temp_qxs = emb(qxs)

                evaluation_error = fast_adapt(learner,
                                              temp_sxs,
                                              temp_qxs,
                                              sys,
                                              qys,
                                              conf['inner'])

                # backward() accumulates task gradients into the shared params.
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()

                # Free GPU copies as early as possible to keep memory flat.
                del (sxs, qxs, sys, qys)
                supp_xs[task].cpu()
                query_xs[task].cpu()
                supp_ys[task].cpu()
                query_ys[task].cpu()

            # Average the accumulated gradients over the batch, then step.
            # Guard against params that received no gradient this batch
            # (p.grad is None) — the original crashed with AttributeError.
            for p in all_parameters:
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()
            del (supp_xs, supp_ys, query_xs, query_ys)
            gc.collect()

        # Evaluate on the held-out validation tasks after each epoch.
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, trainer, validation_dataset, adaptation_step=conf['inner'])

        # with tune.checkpoint_dir(epoch) as checkpoint_dir:
        #     path = os.path.join(checkpoint_dir, "checkpoint")
        #     torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)
    print("Finished Training")
|