@@ -61,23 +61,35 @@ class MeLU(torch.nn.Module):
        self.fast_weights = OrderedDict()

    def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
        # this line added by maheri
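        # snapshot the current global parameters so forward() can restore them after local adaptation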
        self.keep_weight = deepcopy(self.model.state_dict())

        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            self.model.zero_grad()
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
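            # create_graph=True keeps the graph of this inner step, so the meta-update
            # can backpropagate through the adaptation (second-order, MAML-style)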
            # local update
            for i in range(self.weight_len):
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        self.model.load_state_dict(self.fast_weights)
        # self.fast_weights = OrderedDict()
        query_set_y_pred = self.model(query_set_x)
        self.model.load_state_dict(self.keep_weight)
        return query_set_y_pred

    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
@@ -98,6 +110,7 @@ class MeLU(torch.nn.Module):
        losses_q.backward()
        self.meta_optim.step()
        self.store_parameters()
        return

    def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update):
@@ -5,43 +5,65 @@ import pickle
from MeLU import MeLU
from options import config
from model_training import training
from model_test import test
from data_generation import generate
from evidence_candidate import selection

if __name__ == "__main__":
    # master_path = "./ml"
    # master_path = "/media/external_3TB/3TB/rafie/maheri/melr"
    # master_path = "/media/external_10TB/10TB/pourmand/ml"
    master_path = "/media/external_10TB/10TB/maheri/melu_data"
    if not os.path.exists("{}/".format(master_path)):
        print("data generation phase started")
        os.mkdir("{}/".format(master_path))
        # preparing dataset. It needs about 22GB of your hard disk space.
        generate(master_path)
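    # generate() writes four pickles per task -- supp_x_{idx}, supp_y_{idx}, query_x_{idx},
    # query_y_{idx} -- into per-state folders such as warm_state/ and user_cold_state/,
    # which is why the set sizes below divide the directory listing by 4.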

    # training model.
    melu = MeLU(config)
    model_filename = "{}/models2.pkl".format(master_path)
    if not os.path.exists(model_filename):
        print("training phase started")
        # Load training dataset.
        training_set_size = len(os.listdir("{}/warm_state".format(master_path))) // 4
        supp_xs_s = []
        supp_ys_s = []
        query_xs_s = []
        query_ys_s = []
        for idx in range(training_set_size):
            supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb")))
            supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb")))
            query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb")))
            query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb")))
        total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
        del supp_xs_s, supp_ys_s, query_xs_s, query_ys_s
        training(melu, total_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], model_save=True, model_filename=model_filename)
    else:
        trained_state_dict = torch.load(model_filename)
        melu.load_state_dict(trained_state_dict)
    print("training finished")

    # selecting evidence candidates.
    # evidence_candidate_list = selection(melu, master_path, config['num_candidate'])
    # for movie, score in evidence_candidate_list:
    #     print(movie, score)

    print("test phase started")
    test_set_size = len(os.listdir("{}/user_cold_state".format(master_path))) // 4
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(test_set_size):
        supp_xs_s.append(pickle.load(open("{}/user_cold_state/supp_x_{}.pkl".format(master_path, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/user_cold_state/supp_y_{}.pkl".format(master_path, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/user_cold_state/query_x_{}.pkl".format(master_path, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/user_cold_state/query_y_{}.pkl".format(master_path, idx), "rb")))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del supp_xs_s, supp_ys_s, query_xs_s, query_ys_s
    test(melu, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'])
@@ -0,0 +1,89 @@
import os
import torch
import pickle
import random
from MeLU import MeLU
from options import config, states
from torch.nn import functional as F
from torch.nn import L1Loss
import matchzoo as mz
import numpy as np

def test(melu, total_dataset, batch_size, num_epoch):
    if config['use_cuda']:
        melu.cuda()

    test_set_size = len(total_dataset)
    trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models2.pkl")
    melu.load_state_dict(trained_state_dict)
    melu.eval()
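    # eval() only switches layer modes (e.g. dropout); the inner-loop adaptation
    # inside melu.forward() still computes gradients on the support set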

    random.shuffle(total_dataset)
    a, b, c, d = zip(*total_dataset)

    losses_q = []
    ndcgs1 = []
    ndcgs3 = []

    for iterator in range(test_set_size):
        try:
            supp_xs = a[iterator].cuda()
            supp_ys = b[iterator].cuda()
            query_xs = c[iterator].cuda()
            query_ys = d[iterator].cuda()
        except IndexError:
            print("index error in test method")
            continue

        num_local_update = config['inner']
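        # adapt on the support set for config['inner'] gradient steps, then score the query set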
        query_set_y_pred = melu.forward(supp_xs, supp_ys, query_xs, num_local_update)

        l1 = L1Loss(reduction='mean')
        loss_q = l1(query_set_y_pred, query_ys.view(-1, 1))
        print("testing - iterator:{} - l1:{}".format(iterator, loss_q))
        losses_q.append(loss_q)
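        # per-task NDCG@k on the query-set ratings, averaged over all test tasks after the loop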
        y_true = query_ys.cpu().detach().numpy()
        y_pred = query_set_y_pred.cpu().detach().numpy().flatten()
        ndcgs1.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true, y_pred))
        ndcgs3.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred))

        del supp_xs, supp_ys, query_xs, query_ys
    # calculate metrics
    print(losses_q)
    print("======================================")
    losses_q = torch.stack(losses_q).mean(0)
    print("mean of L1 loss (MAE):", losses_q)
    print("======================================")

    n1 = np.array(ndcgs1).mean()
    print("nDCG@1:", n1)
    n3 = np.array(ndcgs3).mean()
    print("nDCG@3:", n3)
@@ -11,13 +11,16 @@ def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_
    if config['use_cuda']:
        melu.cuda()
    print("use_cuda: " + str(config['use_cuda']))

    training_set_size = len(total_dataset)
    melu.train()
    for epoch in range(num_epoch):
        random.shuffle(total_dataset)
        num_batch = training_set_size // batch_size
        a, b, c, d = zip(*total_dataset)
        for i in range(num_batch):
| print("training - epoch:{} - batch:{}".format(epoch,i)) | |||
| try: | |||
| supp_xs = list(a[batch_size*i:batch_size*(i+1)]) | |||
| supp_ys = list(b[batch_size*i:batch_size*(i+1)]) | |||
@@ -26,6 +29,7 @@ def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_
            except IndexError:
                continue
            melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner'])
            del supp_xs, supp_ys, query_xs, query_ys
    if model_save:
        torch.save(melu.state_dict(), model_filename)