|
|
|
|
|
|
|
|
help='outer-loop learning rate (used with Adam optimiser)') |
|
|
help='outer-loop learning rate (used with Adam optimiser)') |
|
|
# parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate') |
|
|
# parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate') |
|
|
|
|
|
|
|
|
parser.add_argument('--inner', type=int, default=2, |
|
|
|
|
|
|
|
|
parser.add_argument('--inner', type=int, default=1, |
|
|
help='number of gradient steps in inner loop (during training)') |
|
|
help='number of gradient steps in inner loop (during training)') |
|
|
parser.add_argument('--inner_eval', type=int, default=2, |
|
|
|
|
|
|
|
|
parser.add_argument('--inner_eval', type=int, default=1, |
|
|
help='number of gradient updates at test time (for evaluation)') |
|
|
help='number of gradient updates at test time (for evaluation)') |
|
|
|
|
|
|
|
|
parser.add_argument('--first_order', action='store_true', default=False, |
|
|
parser.add_argument('--first_order', action='store_true', default=False, |
|
|
|
|
|
|
|
|
fc2_out_dim = config['second_fc_hidden_dim'] |
|
|
fc2_out_dim = config['second_fc_hidden_dim'] |
|
|
use_cuda = config['use_cuda'] |
|
|
use_cuda = config['use_cuda'] |
|
|
|
|
|
|
|
|
fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
|
|
fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
|
|
linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
|
|
head = torch.nn.Sequential(fc1, fc2, linear_out) |
|
|
|
|
|
|
|
|
# fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
|
|
# fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
|
|
# linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
|
|
# head = torch.nn.Sequential(fc1, fc2, linear_out) |
|
|
|
|
|
|
|
|
if use_cuda: |
|
|
if use_cuda: |
|
|
emb = EmbeddingModule(config).cuda() |
|
|
emb = EmbeddingModule(config).cuda() |
|
|
|
|
|
|
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
|
|
|
|
|
|
trainer = Trainer(config) |
|
|
trainer = Trainer(config) |
|
|
|
|
|
tr = trainer |
|
|
|
|
|
|
|
|
# define meta algorithm |
|
|
# define meta algorithm |
|
|
if args.meta_algo == "maml": |
|
|
if args.meta_algo == "maml": |
|
|
|
|
|
|
|
|
# Setup optimization |
|
|
# Setup optimization |
|
|
print("SETUP OPTIMIZATION PHASE") |
|
|
print("SETUP OPTIMIZATION PHASE") |
|
|
all_parameters = list(emb.parameters()) + list(trainer.parameters()) |
|
|
all_parameters = list(emb.parameters()) + list(trainer.parameters()) |
|
|
optimizer = torch.optim.Adam(all_parameters, lr=args.lr_meta) |
|
|
|
|
|
|
|
|
optimizer = torch.optim.Adam(all_parameters, lr=config['lr']) |
|
|
# loss = torch.nn.MSELoss(reduction='mean') |
|
|
# loss = torch.nn.MSELoss(reduction='mean') |
|
|
|
|
|
|
|
|
# Load training dataset. |
|
|
# Load training dataset. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n\n") |
|
|
print("\n\n\n") |
|
|
|
|
|
|
|
|
for iteration in range(args.epochs): |
|
|
|
|
|
|
|
|
for iteration in range(config['num_epoch']): |
|
|
|
|
|
|
|
|
|
|
|
if iteration == 1:
    # On the second epoch (iteration index 1), re-initialise the cluster
    # centroids of the trainer's cluster module: embed the support set of a
    # sample of training tasks, aggregate each task into one vector, run
    # k-means over those vectors and overwrite the centroid tensor with the
    # resulting cluster centres.
    from sklearn.cluster import KMeans

    print("changing cluster centroids started ...")

    def _load_pickle(path):
        """Unpickle `path`, closing the file handle afterwards."""
        with open(path, "rb") as fh:
            return pickle.load(fh)

    indexes = list(np.arange(training_set_size))
    supp_xs, supp_ys, query_xs, query_ys = [], [], [], []

    # Use at most 2500 tasks; the bound keeps a smaller training set from
    # indexing past the end of `indexes`.
    num_centroid_tasks = min(2500, training_set_size)
    for idx in range(num_centroid_tasks):
        task_id = indexes[idx]
        supp_xs.append(_load_pickle("{}/warm_state/supp_x_{}.pkl".format(master_path, task_id)))
        supp_ys.append(_load_pickle("{}/warm_state/supp_y_{}.pkl".format(master_path, task_id)))
        # NOTE(review): the query sets are not read anywhere in this block;
        # they are still loaded because code outside this view may rely on
        # `query_xs` / `query_ys` being populated -- confirm and drop if not.
        query_xs.append(_load_pickle("{}/warm_state/query_x_{}.pkl".format(master_path, task_id)))
        query_ys.append(_load_pickle("{}/warm_state/query_y_{}.pkl".format(master_path, task_id)))
    batch_sz = len(supp_xs)

    # Run on whatever device the centroid tensor already lives on instead of
    # calling .cuda() unconditionally, so CPU-only runs work as well.
    # Assumes `emb` and `tr` share that device -- TODO confirm.
    centroid_device = tr.cluster_module.array.device

    user_embeddings = []
    for task in range(batch_sz):
        supp_x = supp_xs[task].to(centroid_device)
        supp_y = supp_ys[task].to(centroid_device)

        # Concatenate each embedded support item with its label as an extra
        # column before feeding the pair to the cluster module.
        temp_sxs = emb(supp_x)
        y = supp_y.view(-1, 1)
        input_pairs = torch.cat((temp_sxs, y), dim=1)
        task_embed = tr.cluster_module.input_to_hidden(input_pairs)

        # todo : may be useless
        mean_task = tr.cluster_module.aggregate(task_embed)
        user_embeddings.append(mean_task.detach().cpu().numpy())

        # Keep the cached task tensors on the CPU once they have been used.
        supp_xs[task] = supp_x.cpu()
        supp_ys[task] = supp_y.cpu()

    user_embeddings = np.array(user_embeddings)
    kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings)
    tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).to(centroid_device)
|
|
|
|
|
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
num_batch = int(training_set_size / batch_size) |
|
|
indexes = list(np.arange(training_set_size)) |
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
|
|
|
|
|
temp_qxs, |
|
|
temp_qxs, |
|
|
supp_ys[task], |
|
|
supp_ys[task], |
|
|
query_ys[task], |
|
|
query_ys[task], |
|
|
args.inner) |
|
|
|
|
|
|
|
|
config['inner'], |
|
|
|
|
|
epoch=iteration) |
|
|
|
|
|
|
|
|
evaluation_error.backward() |
|
|
evaluation_error.backward() |
|
|
meta_train_error += evaluation_error.item() |
|
|
meta_train_error += evaluation_error.item() |
|
|
|
|
|
|
|
|
gc.collect() |
|
|
gc.collect() |
|
|
print("===============================================\n") |
|
|
print("===============================================\n") |
|
|
|
|
|
|
|
|
if iteration % 2 == 0: |
|
|
|
|
|
|
|
|
if iteration % 2 == 0 and iteration != 0: |
|
|
# testing |
|
|
# testing |
|
|
print("start of test phase") |
|
|
print("start of test phase") |
|
|
trainer.eval() |
|
|
trainer.eval() |
|
|
|
|
|
|
|
|
gc.collect() |
|
|
gc.collect() |
|
|
|
|
|
|
|
|
print("===================== " + test_state + " =====================") |
|
|
print("===================== " + test_state + " =====================") |
|
|
mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],num_epoch=config['num_epoch'],test_state=test_state,args=args) |
|
|
|
|
|
|
|
|
mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'], |
|
|
|
|
|
num_epoch=config['num_epoch'], test_state=test_state, args=args) |
|
|
with open("results2.txt", "a") as f: |
|
|
with open("results2.txt", "a") as f: |
|
|
f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
print("===================================================") |
|
|
print("===================================================") |