|
|
@@ -11,6 +11,10 @@ from fast_adapt import fast_adapt |
|
|
|
import gc |
|
|
|
from learn2learn.optim.transforms import KroneckerTransform |
|
|
|
import argparse |
|
|
|
from clustering import ClustringModule, Trainer |
|
|
|
import numpy as np |
|
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
|
print("==============") |
|
|
@@ -46,8 +50,6 @@ def parse_args(): |
|
|
|
parser.add_argument('--epochs', type=int, default=config['num_epoch'], |
|
|
|
help='number of gpu to run the code') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# parser.add_argument('--data_root', type=str, default="./movielens/ml-1m", help='path to data root') |
|
|
|
# parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use') |
|
|
|
# parser.add_argument('--test', action='store_true', default=False, help='num of workers to use') |
|
|
@@ -74,6 +76,7 @@ def parse_args(): |
|
|
|
# print('Running on device: {}'.format(args.device)) |
|
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
args = parse_args() |
|
|
|
print(args) |
|
|
@@ -81,7 +84,8 @@ if __name__ == '__main__': |
|
|
|
if config['use_cuda']: |
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) |
|
|
|
master_path= "/media/external_10TB/10TB/maheri/melu_data5" |
|
|
|
master_path = "/media/external_10TB/10TB/maheri/define_task_melu_data" |
|
|
|
config['master_path'] = master_path |
|
|
|
|
|
|
|
# DATA GENERATION |
|
|
|
print("DATA GENERATION PHASE") |
|
|
@@ -101,7 +105,7 @@ if __name__ == '__main__': |
|
|
|
fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
head = torch.nn.Sequential(fc1,fc2,linear_out) |
|
|
|
head = torch.nn.Sequential(fc1, fc2, linear_out) |
|
|
|
|
|
|
|
if use_cuda: |
|
|
|
emb = EmbeddingModule(config).cuda() |
|
|
@@ -118,20 +122,24 @@ if __name__ == '__main__': |
|
|
|
elif args.transformer == "linear": |
|
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
|
|
|
|
|
trainer = Trainer(config) |
|
|
|
|
|
|
|
# define meta algorithm |
|
|
|
if args.meta_algo == "maml": |
|
|
|
head = l2l.algorithms.MAML(head, lr=args.lr_inner,first_order=args.first_order) |
|
|
|
trainer = l2l.algorithms.MAML(trainer, lr=args.lr_inner, first_order=args.first_order) |
|
|
|
elif args.meta_algo == 'metasgd': |
|
|
|
head = l2l.algorithms.MetaSGD(head, lr=args.lr_inner,first_order=args.first_order) |
|
|
|
trainer = l2l.algorithms.MetaSGD(trainer, lr=args.lr_inner, first_order=args.first_order) |
|
|
|
elif args.meta_algo == 'gbml': |
|
|
|
head = l2l.algorithms.GBML(head, transform=transform, lr=args.lr_inner,adapt_transform=args.adapt_transform, first_order=args.first_order) |
|
|
|
trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=args.lr_inner, |
|
|
|
adapt_transform=args.adapt_transform, |
|
|
|
first_order=args.first_order) |
|
|
|
|
|
|
|
if use_cuda: |
|
|
|
head.cuda() |
|
|
|
trainer.cuda() |
|
|
|
|
|
|
|
# Setup optimization |
|
|
|
print("SETUP OPTIMIZATION PHASE") |
|
|
|
all_parameters = list(emb.parameters()) + list(head.parameters()) |
|
|
|
all_parameters = list(emb.parameters()) + list(trainer.parameters()) |
|
|
|
optimizer = torch.optim.Adam(all_parameters, lr=args.lr_meta) |
|
|
|
# loss = torch.nn.MSELoss(reduction='mean') |
|
|
|
|
|
|
@@ -142,85 +150,68 @@ if __name__ == '__main__': |
|
|
|
supp_ys_s = [] |
|
|
|
query_xs_s = [] |
|
|
|
query_ys_s = [] |
|
|
|
for idx in range(training_set_size): |
|
|
|
supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb"))) |
|
|
|
supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb"))) |
|
|
|
query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb"))) |
|
|
|
query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb"))) |
|
|
|
total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) |
|
|
|
del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) |
|
|
|
training_set_size = len(total_dataset) |
|
|
|
|
|
|
|
batch_size = config['batch_size'] |
|
|
|
# torch.cuda.empty_cache() |
|
|
|
|
|
|
|
random.shuffle(total_dataset) |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
a, b, c, d = zip(*total_dataset) |
|
|
|
|
|
|
|
print("\n\n\n") |
|
|
|
|
|
|
|
for iteration in range(args.epochs): |
|
|
|
|
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
random.shuffle(indexes) |
|
|
|
|
|
|
|
for i in range(num_batch): |
|
|
|
optimizer.zero_grad() |
|
|
|
meta_train_error = 0.0 |
|
|
|
meta_train_accuracy = 0.0 |
|
|
|
meta_valid_error = 0.0 |
|
|
|
meta_valid_accuracy = 0.0 |
|
|
|
meta_test_error = 0.0 |
|
|
|
meta_test_accuracy = 0.0 |
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
print("EPOCH: ", iteration, " BATCH: ", i) |
|
|
|
supp_xs = list(a[batch_size * i:batch_size * (i + 1)]) |
|
|
|
supp_ys = list(b[batch_size * i:batch_size * (i + 1)]) |
|
|
|
query_xs = list(c[batch_size * i:batch_size * (i + 1)]) |
|
|
|
query_ys = list(d[batch_size * i:batch_size * (i + 1)]) |
|
|
|
supp_xs, supp_ys, query_xs, query_ys = [], [], [], [] |
|
|
|
for idx in range(batch_size * i, batch_size * (i + 1)): |
|
|
|
supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
query_xs.append( |
|
|
|
pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
query_ys.append( |
|
|
|
pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
batch_sz = len(supp_xs) |
|
|
|
|
|
|
|
# if use_cuda: |
|
|
|
# for j in range(batch_size): |
|
|
|
# supp_xs[j] = supp_xs[j].cuda() |
|
|
|
# supp_ys[j] = supp_ys[j].cuda() |
|
|
|
# query_xs[j] = query_xs[j].cuda() |
|
|
|
# query_ys[j] = query_ys[j].cuda() |
|
|
|
if use_cuda: |
|
|
|
for j in range(batch_size): |
|
|
|
supp_xs[j] = supp_xs[j].cuda() |
|
|
|
supp_ys[j] = supp_ys[j].cuda() |
|
|
|
query_xs[j] = query_xs[j].cuda() |
|
|
|
query_ys[j] = query_ys[j].cuda() |
|
|
|
|
|
|
|
for task in range(batch_sz): |
|
|
|
# print("EPOCH: ", iteration," BATCH: ",i, "TASK: ",task) |
|
|
|
# Compute meta-training loss |
|
|
|
# if use_cuda: |
|
|
|
sxs = supp_xs[task].cuda() |
|
|
|
qxs = query_xs[task].cuda() |
|
|
|
sys = supp_ys[task].cuda() |
|
|
|
qys = query_ys[task].cuda() |
|
|
|
# sxs = supp_xs[task].cuda() |
|
|
|
# qxs = query_xs[task].cuda() |
|
|
|
# sys = supp_ys[task].cuda() |
|
|
|
# qys = query_ys[task].cuda() |
|
|
|
|
|
|
|
learner = head.clone() |
|
|
|
temp_sxs = emb(sxs) |
|
|
|
temp_qxs = emb(qxs) |
|
|
|
learner = trainer.clone() |
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
|
temp_qxs = emb(query_xs[task]) |
|
|
|
|
|
|
|
evaluation_error = fast_adapt(learner, |
|
|
|
temp_sxs, |
|
|
|
temp_qxs, |
|
|
|
sys, |
|
|
|
qys, |
|
|
|
args.inner) |
|
|
|
# config['inner']) |
|
|
|
temp_sxs, |
|
|
|
temp_qxs, |
|
|
|
supp_ys[task], |
|
|
|
query_ys[task], |
|
|
|
args.inner) |
|
|
|
|
|
|
|
evaluation_error.backward() |
|
|
|
meta_train_error += evaluation_error.item() |
|
|
|
|
|
|
|
del(sxs,qxs,sys,qys) |
|
|
|
supp_xs[task].cpu() |
|
|
|
query_xs[task].cpu() |
|
|
|
supp_ys[task].cpu() |
|
|
|
query_ys[task].cpu() |
|
|
|
|
|
|
|
# supp_xs[task].cpu() |
|
|
|
# query_xs[task].cpu() |
|
|
|
# supp_ys[task].cpu() |
|
|
|
# query_ys[task].cpu() |
|
|
|
|
|
|
|
# Print some metrics |
|
|
|
print('Iteration', iteration) |
|
|
|
print('Meta Train Error', meta_train_error / batch_sz) |
|
|
|
# print('Meta Train Accuracy', meta_train_accuracy / batch_sz) |
|
|
|
# print('Meta Valid Error', meta_valid_error / batch_sz) |
|
|
|
# print('Meta Valid Accuracy', meta_valid_accuracy / batch_sz) |
|
|
|
# print('Meta Test Error', meta_test_error / batch_sz) |
|
|
|
# print('Meta Test Accuracy', meta_test_accuracy / batch_sz) |
|
|
|
|
|
|
|
# Average the accumulated gradients and optimize |
|
|
|
for p in all_parameters: |
|
|
@@ -228,40 +219,63 @@ if __name__ == '__main__': |
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
# torch.cuda.empty_cache() |
|
|
|
del(supp_xs,supp_ys,query_xs,query_ys) |
|
|
|
del (supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs) |
|
|
|
gc.collect() |
|
|
|
print("===============================================\n") |
|
|
|
|
|
|
|
if iteration % 2 == 0: |
|
|
|
# testing |
|
|
|
print("start of test phase") |
|
|
|
trainer.eval() |
|
|
|
|
|
|
|
with open("results.txt", "a") as f: |
|
|
|
f.write("epoch:{}\n".format(iteration)) |
|
|
|
|
|
|
|
for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
test_dataset = None |
|
|
|
test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
supp_xs_s = [] |
|
|
|
supp_ys_s = [] |
|
|
|
query_xs_s = [] |
|
|
|
query_ys_s = [] |
|
|
|
gc.collect() |
|
|
|
|
|
|
|
print("===================== " + test_state + " =====================") |
|
|
|
mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],num_epoch=config['num_epoch'],test_state=test_state) |
|
|
|
with open("results.txt", "a") as f: |
|
|
|
f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
|
print("===================================================") |
|
|
|
del (test_dataset) |
|
|
|
gc.collect() |
|
|
|
|
|
|
|
trainer.train() |
|
|
|
with open("results.txt", "a") as f: |
|
|
|
f.write("\n") |
|
|
|
print("\n\n\n") |
|
|
|
|
|
|
|
# save model |
|
|
|
final_model = torch.nn.Sequential(emb,head) |
|
|
|
torch.save(final_model.state_dict(), master_path + "/models_gbml.pkl") |
|
|
|
# final_model = torch.nn.Sequential(emb, head) |
|
|
|
# torch.save(final_model.state_dict(), master_path + "/models_gbml.pkl") |
|
|
|
|
|
|
|
# testing |
|
|
|
print("start of test phase") |
|
|
|
for test_state in ['warm_state', 'user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
test_dataset = None |
|
|
|
test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
supp_xs_s = [] |
|
|
|
supp_ys_s = [] |
|
|
|
query_xs_s = [] |
|
|
|
query_ys_s = [] |
|
|
|
for idx in range(test_set_size): |
|
|
|
supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) |
|
|
|
del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) |
|
|
|
|
|
|
|
print("===================== " + test_state + " =====================") |
|
|
|
test(emb,head, test_dataset, batch_size=config['batch_size'], num_epoch=args.epochs,adaptation_step=args.inner_eval) |
|
|
|
print("===================================================\n\n\n") |
|
|
|
print(args) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# print("start of test phase") |
|
|
|
# for test_state in ['warm_state', 'user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
# test_dataset = None |
|
|
|
# test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
# supp_xs_s = [] |
|
|
|
# supp_ys_s = [] |
|
|
|
# query_xs_s = [] |
|
|
|
# query_ys_s = [] |
|
|
|
# for idx in range(test_set_size): |
|
|
|
# supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
# supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
# query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
# query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, idx), "rb"))) |
|
|
|
# test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) |
|
|
|
# del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) |
|
|
|
# |
|
|
|
# print("===================== " + test_state + " =====================") |
|
|
|
# test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=args.epochs, |
|
|
|
# adaptation_step=args.inner_eval) |
|
|
|
# print("===================================================\n\n\n") |
|
|
|
# print(args) |