# --- MeLU.py (fragment) ---
        self.fast_weights = OrderedDict()

    def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
        # keep a copy of the global weights so they can be restored after
        # local adaptation (added by maheri)
        self.keep_weight = deepcopy(self.model.state_dict())
        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            self.model.zero_grad()
            # create_graph=True keeps the local gradients differentiable,
            # so the meta-update can back-propagate through this step
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            # local update: gradient step only on the designated target weights
            for i in range(self.weight_len):
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        self.model.load_state_dict(self.fast_weights)
        query_set_y_pred = self.model(query_set_x)
        # restore the global weights before returning
        self.model.load_state_dict(self.keep_weight)
        return query_set_y_pred
    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
        # meta-update: average the query-set losses over the task batch, then
        # step the meta-optimizer. The loss computation below is reconstructed
        # from forward() above; the original lines were elided (device moves
        # are omitted in this reconstruction).
        losses_q = []
        for i in range(len(support_set_xs)):
            query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update)
            losses_q.append(F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1)))
        losses_q = torch.stack(losses_q).mean(0)
        self.meta_optim.zero_grad()
        losses_q.backward()
        self.meta_optim.step()
        self.store_parameters()
        return
    def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update):
        ...  # body not shown in this fragment
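
# ---------------------------------------------------------------------------
# Illustration (not part of the original repo): a minimal, self-contained
# sketch of the MAML-style local update that forward() above implements,
# using a toy nn.Linear model and synthetic data. `toy_local_update` is a
# hypothetical name; this mirrors the pattern, it is not the repo's API.
import torch
import torch.nn.functional as F
from torch import nn
from collections import OrderedDict
from copy import deepcopy

def toy_local_update(model, x, y, local_lr=0.01, num_local_update=1):
    keep_weight = deepcopy(model.state_dict())   # global weights, restored by the caller
    fast_weights = OrderedDict(model.state_dict())
    for _ in range(num_local_update):
        model.load_state_dict(fast_weights)
        loss = F.mse_loss(model(x), y.view(-1, 1))
        # create_graph=True so a meta-loss could differentiate through this step
        grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        for (name, w), g in zip(list(fast_weights.items()), grad):
            fast_weights[name] = w - local_lr * g
    model.load_state_dict(fast_weights)          # adapted weights, used on the query set
    return keep_weight

# usage: adapt on a tiny synthetic support set, then restore the global weights
_model = nn.Linear(4, 1)
_keep = toy_local_update(_model, torch.randn(8, 4), torch.randn(8))
_model.load_state_dict(_keep)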

# --- main script (fragment) ---
import os
import pickle
import torch

from MeLU import MeLU
from options import config
from model_training import training
from model_test import test
from data_generation import generate
from evidence_candidate import selection

if __name__ == "__main__":
    # master_path = "./ml"
    # master_path = "/media/external_3TB/3TB/rafie/maheri/melr"
    # master_path = "/media/external_10TB/10TB/pourmand/ml"
    master_path = "/media/external_10TB/10TB/maheri/melu_data"

    if not os.path.exists("{}/".format(master_path)):
        print("data generation phase started")
        os.mkdir("{}/".format(master_path))
        # prepare the dataset; it needs about 22GB of hard disk space
        generate(master_path)
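
    # Illustration (not part of the original repo): a hedged sanity check that
    # the generated layout holds four pickles per task; the directory names
    # come from the loading code below.
    for state in ("warm_state", "user_cold_state"):
        state_dir = "{}/{}".format(master_path, state)
        if os.path.exists(state_dir):
            print("{}: {} tasks".format(state, len(os.listdir(state_dir)) // 4))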

    # training model
    melu = MeLU(config)
    model_filename = "{}/models2.pkl".format(master_path)
    if not os.path.exists(model_filename):
        print("training phase started")
        # load the training tasks: four pickles per task (a factored-out
        # version of this loop is sketched below as load_task_split)
        training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
        supp_xs_s = []
        supp_ys_s = []
        query_xs_s = []
        query_ys_s = []
        for idx in range(training_set_size):
            supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb")))
            supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb")))
            query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb")))
            query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb")))
        total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
        del supp_xs_s, supp_ys_s, query_xs_s, query_ys_s
        training(melu, total_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], model_save=True, model_filename=model_filename)
    else:
        trained_state_dict = torch.load(model_filename)
        melu.load_state_dict(trained_state_dict)
    print("training finished")

    # selecting evidence candidates
    # evidence_candidate_list = selection(melu, master_path, config['num_candidate'])
    # for movie, score in evidence_candidate_list:
    #     print(movie, score)
print("start of test phase") | |||||
test_dataset = None | |||||
test_set_size = int(len(os.listdir("{}/user_cold_state".format(master_path))) / 4) | |||||
supp_xs_s = [] | |||||
supp_ys_s = [] | |||||
query_xs_s = [] | |||||
query_ys_s = [] | |||||
for idx in range(test_set_size): | |||||
supp_xs_s.append(pickle.load(open("{}/user_cold_state/supp_x_{}.pkl".format(master_path, idx), "rb"))) | |||||
supp_ys_s.append(pickle.load(open("{}/user_cold_state/supp_y_{}.pkl".format(master_path, idx), "rb"))) | |||||
query_xs_s.append(pickle.load(open("{}/user_cold_state/query_x_{}.pkl".format(master_path, idx), "rb"))) | |||||
query_ys_s.append(pickle.load(open("{}/user_cold_state/query_y_{}.pkl".format(master_path, idx), "rb"))) | |||||
test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) | |||||
del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) | |||||
model_filename = "{}/models_test.pkl".format(master_path) | |||||
test(melu, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch']) |

# --- model_test.py ---
import os
import torch
import pickle
import random
from MeLU import MeLU
from options import config, states
from torch.nn import functional as F
from torch.nn import L1Loss
# from pytorchltr.evaluation import ndcg
import matchzoo as mz
import numpy as np

def test(melu, total_dataset, batch_size, num_epoch):
    if config['use_cuda']:
        melu.cuda()
    test_set_size = len(total_dataset)
    # evaluate the checkpoint written by the training phase
    trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models2.pkl")
    melu.load_state_dict(trained_state_dict)
    melu.eval()
    random.shuffle(total_dataset)
    a, b, c, d = zip(*total_dataset)
    losses_q = []
    ndcgs1 = []
    ndcgs3 = []
    for iterator in range(test_set_size):
        try:
            supp_xs = a[iterator].cuda()
            supp_ys = b[iterator].cuda()
            query_xs = c[iterator].cuda()
            query_ys = d[iterator].cuda()
        except IndexError:
            print("index error in test method")
            continue

        num_local_update = config['inner']
        query_set_y_pred = melu.forward(supp_xs, supp_ys, query_xs, num_local_update)
        l1 = L1Loss(reduction='mean')
        # match the (N, 1) prediction shape; without the view, broadcasting
        # would compute the loss over an N x N matrix
        loss_q = l1(query_set_y_pred, query_ys.view(-1, 1))
        print("testing - iterator:{} - l1:{}".format(iterator, loss_q))
        losses_q.append(loss_q)
        y_true = query_ys.cpu().detach().numpy()
        y_pred = query_set_y_pred.cpu().detach().numpy()
        ndcgs1.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true, y_pred))
        ndcgs3.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred))
        del supp_xs, supp_ys, query_xs, query_ys

    # calculate metrics
    print(losses_q)
    print("======================================")
    losses_q = torch.stack(losses_q).mean(0)
    print("mean L1 loss: ", losses_q)
    print("======================================")
    n1 = np.array(ndcgs1).mean()
    print("nDCG@1: ", n1)
    n3 = np.array(ndcgs3).mean()
    print("nDCG@3: ", n3)

# --- model_training.py (fragment) ---
import random
import torch
from options import config

# signature reconstructed from the call in the main script above
def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_filename=None):
    if config['use_cuda']:
        melu.cuda()
    print("use_cuda: " + str(config['use_cuda']))
    training_set_size = len(total_dataset)
    melu.train()
    for epoch in range(num_epoch):
        random.shuffle(total_dataset)
        num_batch = int(training_set_size / batch_size)  # remainder tasks are dropped
        a, b, c, d = zip(*total_dataset)
        for i in range(num_batch):
            print("training - epoch:{} - batch:{}".format(epoch, i))
            try:
                supp_xs = list(a[batch_size*i:batch_size*(i+1)])
                supp_ys = list(b[batch_size*i:batch_size*(i+1)])
                # query slices reconstructed to mirror the support ones;
                # global_update below needs them but the lines were elided
                query_xs = list(c[batch_size*i:batch_size*(i+1)])
                query_ys = list(d[batch_size*i:batch_size*(i+1)])
            except IndexError:
                continue
            melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner'])
            del supp_xs, supp_ys, query_xs, query_ys

    if model_save:
        torch.save(melu.state_dict(), model_filename)
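
# ---------------------------------------------------------------------------
# Illustration (not part of the original repo): a hedged usage sketch of
# training(), mirroring the call in the main script; `tasks` stands in for
# the list of (supp_x, supp_y, query_x, query_y) tuples built there.
if __name__ == "__main__":
    from MeLU import MeLU
    melu = MeLU(config)
    tasks = []  # fill with task tuples, e.g. via the load_task_split sketch
    if tasks:
        training(melu, tasks, batch_size=config['batch_size'],
                 num_epoch=config['num_epoch'], model_save=True,
                 model_filename="models2.pkl")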