- import os
- import torch
- import torch.nn as nn
- from ray import tune
- import pickle
- from options import config
- from embedding_module import EmbeddingModule
- import learn2learn as l2l
- import random
- from fast_adapt import fast_adapt
- import gc
- from learn2learn.optim.transforms import KroneckerTransform
- from hyper_testing import hyper_test
-
# Define paths (for data)
master_path = "/media/external_10TB/10TB/maheri/melu_data5"


def _load_tasks(data_dir, state):
    """Load every task stored under ``<data_dir>/<state>``.

    Each task is persisted as four pickle files (``supp_x_i``, ``supp_y_i``,
    ``query_x_i``, ``query_y_i``), so the task count is the directory's file
    count divided by 4.  Returns a list of
    ``(support_x, support_y, query_x, query_y)`` tuples.
    """
    num_tasks = len(os.listdir("{}/{}".format(data_dir, state))) // 4
    tasks = []
    for idx in range(num_tasks):
        parts = []
        for prefix in ("supp_x", "supp_y", "query_x", "query_y"):
            path = "{}/{}/{}_{}.pkl".format(data_dir, state, prefix, idx)
            # `with` guarantees each handle is closed; the original code
            # leaked one open file descriptor per pickle file.
            with open(path, "rb") as fh:
                parts.append(pickle.load(fh))
        tasks.append(tuple(parts))
    return tasks


def load_data(data_dir=master_path, test_state='warm_state'):
    """Return ``(trainset, validationset, testset)`` of meta-learning tasks.

    Training tasks always come from ``warm_state``.  Evaluation tasks come
    from ``test_state`` and are shuffled in place, with 20% held out as the
    validation split and the remaining 80% as the test split.

    NOTE(review): the pickle files are trusted local artifacts; never point
    ``data_dir`` at untrusted data (``pickle.load`` executes arbitrary code).
    """
    trainset = _load_tasks(data_dir, 'warm_state')

    test_dataset = _load_tasks(data_dir, test_state)
    random.shuffle(test_dataset)
    val_size = int(len(test_dataset) * 0.2)
    validationset = test_dataset[:val_size]
    testset = test_dataset[val_size:]

    return trainset, validationset, testset
-
-
def train_melu(conf, checkpoint_dir=None, data_dir=None):
    """Ray Tune trainable: meta-train a MeLU-style recommender with learn2learn.

    Builds an embedding module plus a small MLP head, wraps the head in the
    gradient-based meta-learning algorithm selected by ``conf['meta_algo']``
    (MAML / MetaSGD / GBML), then runs episodic training: for every task in a
    batch the head is cloned, adapted on the support set, and evaluated on the
    query set; the gradients accumulated across tasks are averaged before one
    outer optimizer step.  After each epoch the model is evaluated on the
    validation split, checkpointed, and its metrics reported to Ray Tune.

    Parameters
    ----------
    conf : dict
        Hyperparameters sampled by Ray Tune: ``lr``, ``local_lr``,
        ``batch_size``, ``meta_algo``, ``transformer``, ``adapt_transform``,
        ``first_order``, ``inner`` (adaptation steps), ``test_state``.
        Fixed architecture sizes come from the module-level ``config``.
    checkpoint_dir : str, optional
        When supplied by Ray Tune, model/optimizer state is resumed from
        ``<checkpoint_dir>/checkpoint``.
    data_dir : str, optional
        Dataset root forwarded to ``load_data``.
    """
    embedding_dim = config['embedding_dim']  # NOTE(review): unused local, kept as-is
    print("inajm1:",checkpoint_dir)  # debug trace: shows whether this trial resumes from a checkpoint
    # Head dimensions come from the global `config`, not the searched `conf`.
    # NOTE(review): fc1_in_dim assumes the embedding module emits 8
    # concatenated embeddings of size embedding_dim — confirm against
    # EmbeddingModule.
    fc1_in_dim = config['embedding_dim'] * 8
    fc2_in_dim = config['first_fc_hidden_dim']
    fc2_out_dim = config['second_fc_hidden_dim']

    # Two hidden layers + scalar output (predicted rating).
    # NOTE(review): no activations between the Linear layers here; if
    # non-linearities are expected they must live elsewhere — confirm.
    fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
    fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
    linear_out = torch.nn.Linear(fc2_out_dim, 1)
    head = torch.nn.Sequential(fc1, fc2, linear_out)

    emb = EmbeddingModule(config).cuda()

    # Optional parameter transform; only consumed by the GBML branch below.
    transform = None
    if conf['transformer'] == "kronoker":  # (sic) "kronecker" — spelling must match the search space
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    # define meta algorithm
    if conf['meta_algo'] == "maml":
        head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
                                   adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
    head.cuda()
    net = nn.Sequential(emb,head)

    criterion = nn.MSELoss()  # NOTE(review): unused here; the task loss is presumably computed inside fast_adapt — confirm
    # The outer (meta) optimizer updates embedding and head parameters jointly.
    all_parameters = list(emb.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])

    # Resume from a Ray Tune checkpoint if one was handed to this trial.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # loading data
    train_dataset,validation_dataset,test_dataset = load_data(data_dir,test_state=conf['test_state'])
    print(conf['test_state'])
    batch_size = conf['batch_size']
    num_batch = int(len(train_dataset) / batch_size)  # tasks beyond the last full batch are dropped
    # Unzip the task tuples into parallel sequences:
    # a = support x, b = support y, c = query x, d = query y.
    a, b, c, d = zip(*train_dataset)

    for epoch in range(config['num_epoch']):  # loop over the dataset multiple times
        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0

            # print("EPOCH: ", epoch, " BATCH: ", i)
            supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
            supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
            query_xs = list(c[batch_size * i:batch_size * (i + 1)])
            query_ys = list(d[batch_size * i:batch_size * (i + 1)])
            batch_sz = len(supp_xs)

            # iterate over all tasks
            for task in range(batch_sz):
                sxs = supp_xs[task].cuda()
                qxs = query_xs[task].cuda()
                sys = supp_ys[task].cuda()
                qys = query_ys[task].cuda()

                # clone() yields a task-specific learner whose adaptation
                # still backpropagates into the shared meta-parameters.
                learner = head.clone()
                temp_sxs = emb(sxs)
                temp_qxs = emb(qxs)

                # Adapt on the support set, evaluate on the query set;
                # conf['inner'] is the number of inner-loop adaptation steps.
                evaluation_error = fast_adapt(learner,
                                              temp_sxs,
                                              temp_qxs,
                                              sys,
                                              qys,
                                              conf['inner'])

                # Accumulate gradients across all tasks of this batch.
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()

                # Release references to this task's GPU tensors.
                # NOTE(review): Tensor.cpu() returns a copy and does not move
                # the stored tensor in place — these four calls look like
                # no-ops; confirm the intent.
                del(sxs,qxs,sys,qys)
                supp_xs[task].cpu()
                query_xs[task].cpu()
                supp_ys[task].cpu()
                query_ys[task].cpu()

            # Average the accumulated gradients and optimize (After each batch we will update params)
            for p in all_parameters:
                p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()
            del (supp_xs, supp_ys, query_xs, query_ys)
            gc.collect()

        # test results on the validation data
        val_loss,val_ndcg1,val_ndcg3 = hyper_test(emb,head,validation_dataset,adaptation_step=conf['inner'])

        # Persist an epoch checkpoint so Tune can pause/resume this trial.
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        # Report epoch-level metrics to the Tune scheduler / search algorithm.
        tune.report(loss=val_loss, ndcg1=val_ndcg1,ndcg3=val_ndcg3)
    print("Finished Training")