|
|
@@ -74,7 +74,7 @@ def parse_args(): |
|
|
|
help='run adaptation transform') |
|
|
|
parser.add_argument('--transformer', type=str, default="kronoker", |
|
|
|
help='transformer type') |
|
|
|
parser.add_argument('--meta_algo', type=str, default="gbml", |
|
|
|
parser.add_argument('--meta_algo', type=str, default="metasgd", |
|
|
|
help='MAML/MetaSGD/GBML') |
|
|
|
parser.add_argument('--gpu', type=int, default=0, |
|
|
|
help='number of gpu to run the code') |
|
|
@@ -163,11 +163,6 @@ if __name__ == '__main__': |
|
|
|
fc2_out_dim = config['second_fc_hidden_dim'] |
|
|
|
use_cuda = config['use_cuda'] |
|
|
|
|
|
|
|
# fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
# fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
# linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
# head = torch.nn.Sequential(fc1, fc2, linear_out) |
|
|
|
|
|
|
|
if use_cuda: |
|
|
|
emb = EmbeddingModule(config).cuda() |
|
|
|
else: |
|
|
@@ -244,7 +239,8 @@ if __name__ == '__main__': |
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
|
# temp_qxs = emb(query_xs[task]) |
|
|
|
y = supp_ys[task].view(-1, 1) |
|
|
|
input_pairs = torch.cat((temp_sxs, y), dim=1) |
|
|
|
# input_pairs = torch.cat((temp_sxs, y), dim=1) |
|
|
|
input_pairs = temp_sxs |
|
|
|
task_embed = tr.cluster_module.input_to_hidden(input_pairs) |
|
|
|
|
|
|
|
# todo : may be useless |
|
|
@@ -263,7 +259,9 @@ if __name__ == '__main__': |
|
|
|
if iteration > 0: |
|
|
|
# indexes = data_batching(indexes, C_distribs, batch_size, training_set_size, config['cluster_k']) |
|
|
|
# random.shuffle(indexes) |
|
|
|
C_distribs = [] |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
random.shuffle(indexes) |
|
|
|
else: |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
@@ -293,12 +291,6 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
C_distribs = [] |
|
|
|
for task in range(batch_sz): |
|
|
|
# Compute meta-training loss |
|
|
|
# sxs = supp_xs[task].cuda() |
|
|
|
# qxs = query_xs[task].cuda() |
|
|
|
# sys = supp_ys[task].cuda() |
|
|
|
# qys = query_ys[task].cuda() |
|
|
|
|
|
|
|
learner = trainer.clone() |
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
|
temp_qxs = emb(query_xs[task]) |
|
|
@@ -316,11 +308,6 @@ if __name__ == '__main__': |
|
|
|
meta_train_error += evaluation_error.item() |
|
|
|
meta_cluster_error += k_loss |
|
|
|
|
|
|
|
# supp_xs[task].cpu() |
|
|
|
# query_xs[task].cpu() |
|
|
|
# supp_ys[task].cpu() |
|
|
|
# query_ys[task].cpu() |
|
|
|
|
|
|
|
# Print some metrics |
|
|
|
print('Iteration', iteration) |
|
|
|
print('Meta Train Error', meta_train_error / batch_sz) |
|
|
@@ -330,49 +317,46 @@ if __name__ == '__main__': |
|
|
|
# clustering_loss.backward() |
|
|
|
# print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy()) |
|
|
|
|
|
|
|
# if i != (num_batch - 1): |
|
|
|
# C_distribs = [] |
|
|
|
|
|
|
|
# Average the accumulated gradients and optimize |
|
|
|
for p in all_parameters: |
|
|
|
p.grad.data.mul_(1.0 / batch_sz) |
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
# torch.cuda.empty_cache() |
|
|
|
del (supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs) |
|
|
|
gc.collect() |
|
|
|
# del (supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs) |
|
|
|
# gc.collect() |
|
|
|
print("===============================================\n") |
|
|
|
|
|
|
|
# if iteration % 2 == 0 and iteration != 0: |
|
|
|
# # testing |
|
|
|
# print("start of test phase") |
|
|
|
# trainer.eval() |
|
|
|
# |
|
|
|
# with open("results2.txt", "a") as f: |
|
|
|
# f.write("epoch:{}\n".format(iteration)) |
|
|
|
# |
|
|
|
# for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
# test_dataset = None |
|
|
|
# test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
# supp_xs_s = [] |
|
|
|
# supp_ys_s = [] |
|
|
|
# query_xs_s = [] |
|
|
|
# query_ys_s = [] |
|
|
|
# gc.collect() |
|
|
|
# |
|
|
|
# print("===================== " + test_state + " =====================") |
|
|
|
# mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'], |
|
|
|
# num_epoch=config['num_epoch'], test_state=test_state, args=args) |
|
|
|
# with open("results2.txt", "a") as f: |
|
|
|
# f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
|
# print("===================================================") |
|
|
|
# del (test_dataset) |
|
|
|
# gc.collect() |
|
|
|
# |
|
|
|
# trainer.train() |
|
|
|
# with open("results2.txt", "a") as f: |
|
|
|
# f.write("\n") |
|
|
|
# print("\n\n\n") |
|
|
|
if iteration % 2 == 0 and iteration != 0: |
|
|
|
# testing |
|
|
|
print("start of test phase") |
|
|
|
trainer.eval() |
|
|
|
|
|
|
|
with open("results2.txt", "a") as f: |
|
|
|
f.write("epoch:{}\n".format(iteration)) |
|
|
|
|
|
|
|
for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
test_dataset = None |
|
|
|
test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
supp_xs_s = [] |
|
|
|
supp_ys_s = [] |
|
|
|
query_xs_s = [] |
|
|
|
query_ys_s = [] |
|
|
|
gc.collect() |
|
|
|
|
|
|
|
print("===================== " + test_state + " =====================") |
|
|
|
mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'], |
|
|
|
num_epoch=config['num_epoch'], test_state=test_state, args=args) |
|
|
|
with open("results2.txt", "a") as f: |
|
|
|
f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
|
print("===================================================") |
|
|
|
del (test_dataset) |
|
|
|
gc.collect() |
|
|
|
|
|
|
|
trainer.train() |
|
|
|
with open("results2.txt", "a") as f: |
|
|
|
f.write("\n") |
|
|
|
print("\n\n\n") |
|
|
|
|
|
|
|
# save model |
|
|
|
# final_model = torch.nn.Sequential(emb, head) |