@@ -61,23 +61,35 @@ class MeLU(torch.nn.Module): | |||
self.fast_weights = OrderedDict() | |||
def forward(self, support_set_x, support_set_y, query_set_x, num_local_update): | |||
# this line added my maheri | |||
self.keep_weight = deepcopy(self.model.state_dict()) | |||
for idx in range(num_local_update): | |||
if idx > 0: | |||
self.model.load_state_dict(self.fast_weights) | |||
# weight_for_local_update = list(self.model.state_dict().values()) | |||
weight_for_local_update = list(self.model.state_dict().values()) | |||
support_set_y_pred = self.model(support_set_x) | |||
loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1)) | |||
self.model.zero_grad() | |||
grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True) | |||
# local update | |||
for i in range(self.weight_len): | |||
if self.weight_name[i] in self.local_update_target_weight_name: | |||
self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i] | |||
else: | |||
self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] | |||
self.model.load_state_dict(self.fast_weights) | |||
# self.fast_weights = OrderedDict() | |||
query_set_y_pred = self.model(query_set_x) | |||
self.model.load_state_dict(self.keep_weight) | |||
return query_set_y_pred | |||
def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update): | |||
@@ -98,6 +110,7 @@ class MeLU(torch.nn.Module): | |||
losses_q.backward() | |||
self.meta_optim.step() | |||
self.store_parameters() | |||
return | |||
def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update): |
@@ -5,43 +5,65 @@ import pickle | |||
from MeLU import MeLU | |||
from options import config | |||
from model_training import training | |||
from model_test import test | |||
from data_generation import generate | |||
from evidence_candidate import selection | |||
if __name__ == "__main__": | |||
# master_path= "./ml" | |||
master_path = "/media/external_3TB/3TB/rafie/maheri/melr" | |||
# master_path = "/media/external_10TB/10TB/pourmand/ml" | |||
master_path = "/media/external_10TB/10TB/maheri/melu_data" | |||
if not os.path.exists("{}/".format(master_path)): | |||
print("inajm") | |||
print("generating data phase started") | |||
os.mkdir("{}/".format(master_path)) | |||
# preparing dataset. It needs about 22GB of your hard disk space. | |||
generate(master_path) | |||
# # training model. | |||
# melu = MeLU(config) | |||
# model_filename = "{}/models.pkl".format(master_path) | |||
# if not os.path.exists(model_filename): | |||
# # Load training dataset. | |||
# training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4) | |||
# supp_xs_s = [] | |||
# supp_ys_s = [] | |||
# query_xs_s = [] | |||
# query_ys_s = [] | |||
# for idx in range(training_set_size): | |||
# supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb"))) | |||
# supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb"))) | |||
# query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb"))) | |||
# query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb"))) | |||
# total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) | |||
# del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) | |||
# training(melu, total_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], model_save=True, model_filename=model_filename) | |||
# else: | |||
# trained_state_dict = torch.load(model_filename) | |||
# melu.load_state_dict(trained_state_dict) | |||
# | |||
# # selecting evidence candidates. | |||
# training model. | |||
melu = MeLU(config) | |||
model_filename = "{}/models2.pkl".format(master_path) | |||
if not os.path.exists(model_filename): | |||
print("training phase started") | |||
# Load training dataset. | |||
training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4) | |||
supp_xs_s = [] | |||
supp_ys_s = [] | |||
query_xs_s = [] | |||
query_ys_s = [] | |||
for idx in range(training_set_size): | |||
supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb"))) | |||
supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb"))) | |||
query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb"))) | |||
query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb"))) | |||
total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) | |||
del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) | |||
training(melu, total_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'], model_save=True, model_filename=model_filename) | |||
else: | |||
trained_state_dict = torch.load(model_filename) | |||
melu.load_state_dict(trained_state_dict) | |||
print("training finished") | |||
# selecting evidence candidates. | |||
# evidence_candidate_list = selection(melu, master_path, config['num_candidate']) | |||
# for movie, score in evidence_candidate_list: | |||
# print(movie, score) | |||
print("start of test phase") | |||
test_dataset = None | |||
test_set_size = int(len(os.listdir("{}/user_cold_state".format(master_path))) / 4) | |||
supp_xs_s = [] | |||
supp_ys_s = [] | |||
query_xs_s = [] | |||
query_ys_s = [] | |||
for idx in range(test_set_size): | |||
supp_xs_s.append(pickle.load(open("{}/user_cold_state/supp_x_{}.pkl".format(master_path, idx), "rb"))) | |||
supp_ys_s.append(pickle.load(open("{}/user_cold_state/supp_y_{}.pkl".format(master_path, idx), "rb"))) | |||
query_xs_s.append(pickle.load(open("{}/user_cold_state/query_x_{}.pkl".format(master_path, idx), "rb"))) | |||
query_ys_s.append(pickle.load(open("{}/user_cold_state/query_y_{}.pkl".format(master_path, idx), "rb"))) | |||
test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) | |||
del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) | |||
model_filename = "{}/models_test.pkl".format(master_path) | |||
test(melu, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch']) |
@@ -0,0 +1,89 @@ | |||
import os | |||
import torch | |||
import pickle | |||
import random | |||
from MeLU import MeLU | |||
from options import config, states | |||
from torch.nn import functional as F | |||
from torch.nn import L1Loss | |||
# from pytorchltr.evaluation import ndcg | |||
import matchzoo as mz | |||
import numpy as np | |||
def test(melu, total_dataset, batch_size, num_epoch): | |||
if config['use_cuda']: | |||
melu.cuda() | |||
test_set_size = len(total_dataset) | |||
trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models2.pkl") | |||
melu.load_state_dict(trained_state_dict) | |||
melu.eval() | |||
random.shuffle(total_dataset) | |||
a, b, c, d = zip(*total_dataset) | |||
losses_q = [] | |||
predictions = None | |||
predictions_size = None | |||
# y_true = [] | |||
# y_pred = [] | |||
ndcgs1 = [] | |||
ndcgs3 = [] | |||
for iterator in range(test_set_size): | |||
# trained_state_dict = torch.load("/media/external_10TB/10TB/maheri/melu_data/models.pkl") | |||
# melu.load_state_dict(trained_state_dict) | |||
# melu.eval() | |||
try: | |||
supp_xs = a[iterator].cuda() | |||
supp_ys = b[iterator].cuda() | |||
query_xs = c[iterator].cuda() | |||
query_ys = d[iterator].cuda() | |||
except IndexError: | |||
print("index error in test method") | |||
continue | |||
num_local_update = config['inner'] | |||
query_set_y_pred = melu.forward(supp_xs, supp_ys, query_xs, num_local_update) | |||
l1 = L1Loss(reduction='mean') | |||
loss_q = l1(query_set_y_pred, query_ys) | |||
print("testing - iterator:{} - l1:{} ".format(iterator,loss_q)) | |||
losses_q.append(loss_q) | |||
# if predictions is None: | |||
# predictions = query_set_y_pred | |||
# predictions_size = torch.FloatTensor(len(query_set_y_pred)) | |||
# else: | |||
# predictions = torch.cat((predictions,query_set_y_pred),0) | |||
# predictions_size = torch.cat((predictions_size,torch.FloatTensor(len(query_set_y_pred))),0) | |||
# y_true.append(query_ys.cpu().detach().numpy()) | |||
# y_pred.append(query_set_y_pred.cpu().detach().numpy()) | |||
y_true = query_ys.cpu().detach().numpy() | |||
y_pred = query_set_y_pred.cpu().detach().numpy() | |||
ndcgs1.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(y_true,y_pred)) | |||
ndcgs3.append(mz.metrics.NormalizedDiscountedCumulativeGain(k=3)(y_true, y_pred)) | |||
del supp_xs, supp_ys, query_xs, query_ys | |||
# calculate metrics | |||
print(losses_q) | |||
print("======================================") | |||
losses_q = torch.stack(losses_q).mean(0) | |||
print("mean of mse: ",losses_q) | |||
print("======================================") | |||
# n1 = ndcg(d, predictions.cuda(), predictions_size.cuda(), k=1) | |||
# n1 = mz.metrics.NormalizedDiscountedCumulativeGain(k=1)(np.array(y_true),np.array(y_pred)) | |||
n1 = np.array(ndcgs1).mean() | |||
print("nDCG1: ",n1) | |||
n3 = np.array(ndcgs3).mean() | |||
print("nDCG3: ", n3) | |||
@@ -11,13 +11,16 @@ def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_ | |||
if config['use_cuda']: | |||
melu.cuda() | |||
print("mode: " + str(config['use_cuda'])) | |||
training_set_size = len(total_dataset) | |||
melu.train() | |||
for _ in range(num_epoch): | |||
for epoch in range(num_epoch): | |||
random.shuffle(total_dataset) | |||
num_batch = int(training_set_size / batch_size) | |||
a,b,c,d = zip(*total_dataset) | |||
for i in range(num_batch): | |||
print("training - epoch:{} - batch:{}".format(epoch,i)) | |||
try: | |||
supp_xs = list(a[batch_size*i:batch_size*(i+1)]) | |||
supp_ys = list(b[batch_size*i:batch_size*(i+1)]) | |||
@@ -26,6 +29,7 @@ def training(melu, total_dataset, batch_size, num_epoch, model_save=True, model_ | |||
except IndexError: | |||
continue | |||
melu.global_update(supp_xs, supp_ys, query_xs, query_ys, config['inner']) | |||
del supp_xs,supp_ys,query_xs,query_ys | |||
if model_save: | |||
torch.save(melu.state_dict(), model_filename) |