|
|
|
|
|
|
|
|
from torch.nn import functional as F |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def data_batching(indexes, C_distribs, batch_size, training_set_size, num_clusters): |
|
|
|
|
|
probs = np.squeeze(C_distribs) |
|
|
|
|
|
cs = [np.random.choice(num_clusters, p=i) for i in probs] |
|
|
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
|
|
res = [[] for i in range(num_batch)] |
|
|
|
|
|
clas = [[] for i in range(num_clusters)] |
|
|
|
|
|
|
|
|
|
|
|
for idx, c in zip(indexes, cs): |
|
|
|
|
|
clas[c].append(idx) |
|
|
|
|
|
|
|
|
|
|
|
t = np.array([len(i) for i in clas]) |
|
|
|
|
|
t = t / t.sum() |
|
|
|
|
|
|
|
|
|
|
|
dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)])) |
|
|
|
|
|
cnt = 0 |
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(res)): |
|
|
|
|
|
for j in range(batch_size): |
|
|
|
|
|
temp = np.random.choice(num_clusters, p=t) |
|
|
|
|
|
if len(clas[temp]) > 0: |
|
|
|
|
|
res[i].append(clas[temp].pop(0)) |
|
|
|
|
|
else: |
|
|
|
|
|
# res[i].append(indexes[training_set_size-1-cnt]) |
|
|
|
|
|
res[i].append(random.choice(dif)) |
|
|
|
|
|
cnt = cnt + 1 |
|
|
|
|
|
|
|
|
|
|
|
res = np.random.permutation(res) |
|
|
|
|
|
final_result = np.array(res).flatten() |
|
|
|
|
|
return final_result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
def parse_args(): |
|
|
print("==============") |
|
|
print("==============") |
|
|
parser = argparse.ArgumentParser([], description='Fast Context Adaptation via Meta-Learning (CAVIA),' |
|
|
parser = argparse.ArgumentParser([], description='Fast Context Adaptation via Meta-Learning (CAVIA),' |
|
|
|
|
|
|
|
|
print("==============\n") |
|
|
print("==============\n") |
|
|
|
|
|
|
|
|
parser.add_argument('--seed', type=int, default=53) |
|
|
parser.add_argument('--seed', type=int, default=53) |
|
|
# parser.add_argument('--task', type=str, default='multi', help='problem setting: sine or celeba') |
|
|
|
|
|
# parser.add_argument('--tasks_per_metaupdate', type=int, default=32, |
|
|
|
|
|
# help='number of tasks in each batch per meta-update') |
|
|
|
|
|
# |
|
|
|
|
|
# parser.add_argument('--lr_inner', type=float, default=5e-6, help='inner-loop learning rate (per task)') |
|
|
|
|
|
# parser.add_argument('--lr_meta', type=float, default=5e-5, |
|
|
|
|
|
# help='outer-loop learning rate (used with Adam optimiser)') |
|
|
|
|
|
|
|
|
parser.add_argument('--task', type=str, default='multi', help='problem setting: sine or celeba') |
|
|
|
|
|
parser.add_argument('--tasks_per_metaupdate', type=int, default=32, |
|
|
|
|
|
help='number of tasks in each batch per meta-update') |
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument('--lr_inner', type=float, default=5e-6, help='inner-loop learning rate (per task)') |
|
|
|
|
|
parser.add_argument('--lr_meta', type=float, default=5e-5, |
|
|
|
|
|
help='outer-loop learning rate (used with Adam optimiser)') |
|
|
# parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate') |
|
|
# parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate') |
|
|
# |
|
|
|
|
|
# parser.add_argument('--inner', type=int, default=1, |
|
|
|
|
|
# help='number of gradient steps in inner loop (during training)') |
|
|
|
|
|
# parser.add_argument('--inner_eval', type=int, default=1, |
|
|
|
|
|
# help='number of gradient updates at test time (for evaluation)') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument('--inner', type=int, default=1, |
|
|
|
|
|
help='number of gradient steps in inner loop (during training)') |
|
|
|
|
|
parser.add_argument('--inner_eval', type=int, default=1, |
|
|
|
|
|
help='number of gradient updates at test time (for evaluation)') |
|
|
|
|
|
|
|
|
parser.add_argument('--first_order', action='store_true', default=False, |
|
|
parser.add_argument('--first_order', action='store_true', default=False, |
|
|
help='run first order approximation of CAVIA') |
|
|
help='run first order approximation of CAVIA') |
|
|
|
|
|
|
|
|
help='run adaptation transform') |
|
|
help='run adaptation transform') |
|
|
parser.add_argument('--transformer', type=str, default="kronoker", |
|
|
parser.add_argument('--transformer', type=str, default="kronoker", |
|
|
help='transformer type') |
|
|
help='transformer type') |
|
|
parser.add_argument('--meta_algo', type=str, default="metasgd", |
|
|
|
|
|
|
|
|
parser.add_argument('--meta_algo', type=str, default="gbml", |
|
|
help='MAML/MetaSGD/GBML') |
|
|
help='MAML/MetaSGD/GBML') |
|
|
parser.add_argument('--gpu', type=int, default=0, |
|
|
parser.add_argument('--gpu', type=int, default=0, |
|
|
help='number of gpu to run the code') |
|
|
help='number of gpu to run the code') |
|
|
|
|
|
|
|
|
return args |
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def kl_loss(C_distribs): |
|
|
|
|
|
# batchsize * k |
|
|
|
|
|
C_distribs = torch.stack(C_distribs).squeeze() |
|
|
|
|
|
|
|
|
|
|
|
# print("injam:",len(C_distribs)) |
|
|
|
|
|
# print(C_distribs[0].shape) |
|
|
|
|
|
# batchsize * k |
|
|
|
|
|
# print("injam2",C_distribs) |
|
|
|
|
|
C_distribs_sq = torch.pow(C_distribs, 2) |
|
|
|
|
|
# print("injam3",C_distribs_sq) |
|
|
|
|
|
# 1*k |
|
|
|
|
|
C_distribs_sum = torch.sum(C_distribs, dim=0, keepdim=True) |
|
|
|
|
|
# print("injam4",C_distribs_sum) |
|
|
|
|
|
# batchsize * k |
|
|
|
|
|
temp = C_distribs_sq / C_distribs_sum |
|
|
|
|
|
# print("injam5",temp) |
|
|
|
|
|
# batchsize * 1 |
|
|
|
|
|
temp_sum = torch.sum(temp, dim=1, keepdim=True) |
|
|
|
|
|
# print("injam6",temp_sum) |
|
|
|
|
|
target_distribs = temp / temp_sum |
|
|
|
|
|
# print("injam7",target_distribs) |
|
|
|
|
|
# calculate the kl loss |
|
|
|
|
|
clustering_loss = F.kl_div(C_distribs.log(), target_distribs, reduction='batchmean') |
|
|
|
|
|
# print("injam8",clustering_loss) |
|
|
|
|
|
return clustering_loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
|
args = parse_args() |
|
|
args = parse_args() |
|
|
print(args) |
|
|
print(args) |
|
|
|
|
|
|
|
|
if config['use_cuda']: |
|
|
if config['use_cuda']: |
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) |
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) |
|
|
master_path = "/media/external_10TB/10TB/maheri/new_data_dir3" |
|
|
|
|
|
|
|
|
master_path = "/media/external_10TB/10TB/maheri/define_task_melu_data2" |
|
|
config['master_path'] = master_path |
|
|
config['master_path'] = master_path |
|
|
|
|
|
|
|
|
# DATA GENERATION |
|
|
# DATA GENERATION |
|
|
|
|
|
|
|
|
fc2_out_dim = config['second_fc_hidden_dim'] |
|
|
fc2_out_dim = config['second_fc_hidden_dim'] |
|
|
use_cuda = config['use_cuda'] |
|
|
use_cuda = config['use_cuda'] |
|
|
|
|
|
|
|
|
fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
|
|
fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
|
|
linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
|
|
head = torch.nn.Sequential(fc1, fc2, linear_out) |
|
|
|
|
|
|
|
|
# fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
|
|
# fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
|
|
# linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
|
|
# head = torch.nn.Sequential(fc1, fc2, linear_out) |
|
|
|
|
|
|
|
|
if use_cuda: |
|
|
if use_cuda: |
|
|
emb = EmbeddingModule(config).cuda() |
|
|
emb = EmbeddingModule(config).cuda() |
|
|
|
|
|
|
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
|
|
|
|
|
|
trainer = Trainer(config) |
|
|
trainer = Trainer(config) |
|
|
|
|
|
tr = trainer |
|
|
|
|
|
|
|
|
# define meta algorithm |
|
|
# define meta algorithm |
|
|
if args.meta_algo == "maml": |
|
|
if args.meta_algo == "maml": |
|
|
trainer = l2l.algorithms.MAML(trainer, lr=args.lr_inner, first_order=args.first_order) |
|
|
trainer = l2l.algorithms.MAML(trainer, lr=args.lr_inner, first_order=args.first_order) |
|
|
elif args.meta_algo == 'metasgd': |
|
|
elif args.meta_algo == 'metasgd': |
|
|
trainer = l2l.algorithms.MetaSGD(trainer, lr=config['local_lr'], first_order=args.first_order) |
|
|
|
|
|
|
|
|
trainer = l2l.algorithms.MetaSGD(trainer, lr=args.lr_inner, first_order=args.first_order) |
|
|
elif args.meta_algo == 'gbml': |
|
|
elif args.meta_algo == 'gbml': |
|
|
trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=config['local_lr'], |
|
|
|
|
|
|
|
|
trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=args.lr_inner, |
|
|
adapt_transform=args.adapt_transform, |
|
|
adapt_transform=args.adapt_transform, |
|
|
first_order=args.first_order) |
|
|
first_order=args.first_order) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n\n\n") |
|
|
print("\n\n\n") |
|
|
|
|
|
|
|
|
for iteration in range(args.epochs): |
|
|
|
|
|
|
|
|
for iteration in range(config['num_epoch']): |
|
|
|
|
|
|
|
|
|
|
|
if iteration == 0: |
|
|
|
|
|
print("changing cluster centroids started ...") |
|
|
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
|
|
supp_xs, supp_ys, query_xs, query_ys = [], [], [], [] |
|
|
|
|
|
for idx in range(0, 2500): |
|
|
|
|
|
supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
|
|
supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
|
|
query_xs.append( |
|
|
|
|
|
pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
|
|
query_ys.append( |
|
|
|
|
|
pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb"))) |
|
|
|
|
|
batch_sz = len(supp_xs) |
|
|
|
|
|
|
|
|
|
|
|
user_embeddings = [] |
|
|
|
|
|
|
|
|
|
|
|
for task in range(batch_sz): |
|
|
|
|
|
# Compute meta-training loss |
|
|
|
|
|
supp_xs[task] = supp_xs[task].cuda() |
|
|
|
|
|
supp_ys[task] = supp_ys[task].cuda() |
|
|
|
|
|
# query_xs[task] = query_xs[task].cuda() |
|
|
|
|
|
# query_ys[task] = query_ys[task].cuda() |
|
|
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
|
|
|
# temp_qxs = emb(query_xs[task]) |
|
|
|
|
|
y = supp_ys[task].view(-1, 1) |
|
|
|
|
|
input_pairs = torch.cat((temp_sxs, y), dim=1) |
|
|
|
|
|
task_embed = tr.cluster_module.input_to_hidden(input_pairs) |
|
|
|
|
|
|
|
|
|
|
|
# todo : may be useless |
|
|
|
|
|
mean_task = tr.cluster_module.aggregate(task_embed) |
|
|
|
|
|
user_embeddings.append(mean_task.detach().cpu().numpy()) |
|
|
|
|
|
|
|
|
|
|
|
supp_xs[task] = supp_xs[task].cpu() |
|
|
|
|
|
supp_ys[task] = supp_ys[task].cpu() |
|
|
|
|
|
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
|
|
random.shuffle(indexes) |
|
|
|
|
|
|
|
|
from sklearn.cluster import KMeans |
|
|
|
|
|
|
|
|
|
|
|
user_embeddings = np.array(user_embeddings) |
|
|
|
|
|
kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings) |
|
|
|
|
|
tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda() |
|
|
|
|
|
|
|
|
|
|
|
if iteration > 0: |
|
|
|
|
|
# indexes = data_batching(indexes, C_distribs, batch_size, training_set_size, config['cluster_k']) |
|
|
|
|
|
# random.shuffle(indexes) |
|
|
|
|
|
C_distribs = [] |
|
|
|
|
|
else: |
|
|
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
|
|
random.shuffle(indexes) |
|
|
|
|
|
|
|
|
for i in range(num_batch): |
|
|
for i in range(num_batch): |
|
|
meta_train_error = 0.0 |
|
|
meta_train_error = 0.0 |
|
|
|
|
|
meta_cluster_error = 0.0 |
|
|
optimizer.zero_grad() |
|
|
optimizer.zero_grad() |
|
|
print("EPOCH: ", iteration, " BATCH: ", i) |
|
|
print("EPOCH: ", iteration, " BATCH: ", i) |
|
|
supp_xs, supp_ys, query_xs, query_ys = [], [], [], [] |
|
|
supp_xs, supp_ys, query_xs, query_ys = [], [], [], [] |
|
|
|
|
|
|
|
|
query_xs[j] = query_xs[j].cuda() |
|
|
query_xs[j] = query_xs[j].cuda() |
|
|
query_ys[j] = query_ys[j].cuda() |
|
|
query_ys[j] = query_ys[j].cuda() |
|
|
|
|
|
|
|
|
|
|
|
C_distribs = [] |
|
|
for task in range(batch_sz): |
|
|
for task in range(batch_sz): |
|
|
# Compute meta-training loss |
|
|
# Compute meta-training loss |
|
|
# sxs = supp_xs[task].cuda() |
|
|
# sxs = supp_xs[task].cuda() |
|
|
|
|
|
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
temp_qxs = emb(query_xs[task]) |
|
|
temp_qxs = emb(query_xs[task]) |
|
|
|
|
|
|
|
|
evaluation_error = fast_adapt(learner, |
|
|
|
|
|
temp_sxs, |
|
|
|
|
|
temp_qxs, |
|
|
|
|
|
supp_ys[task], |
|
|
|
|
|
query_ys[task], |
|
|
|
|
|
config['inner']) |
|
|
|
|
|
|
|
|
evaluation_error, c, k_loss = fast_adapt(learner, |
|
|
|
|
|
temp_sxs, |
|
|
|
|
|
temp_qxs, |
|
|
|
|
|
supp_ys[task], |
|
|
|
|
|
query_ys[task], |
|
|
|
|
|
config['inner'], |
|
|
|
|
|
epoch=iteration) |
|
|
|
|
|
|
|
|
evaluation_error.backward() |
|
|
|
|
|
|
|
|
# C_distribs.append(c) |
|
|
|
|
|
evaluation_error.backward(retain_graph=True) |
|
|
meta_train_error += evaluation_error.item() |
|
|
meta_train_error += evaluation_error.item() |
|
|
|
|
|
meta_cluster_error += k_loss |
|
|
|
|
|
|
|
|
# supp_xs[task].cpu() |
|
|
# supp_xs[task].cpu() |
|
|
# query_xs[task].cpu() |
|
|
# query_xs[task].cpu() |
|
|
|
|
|
|
|
|
# Print some metrics |
|
|
# Print some metrics |
|
|
print('Iteration', iteration) |
|
|
print('Iteration', iteration) |
|
|
print('Meta Train Error', meta_train_error / batch_sz) |
|
|
print('Meta Train Error', meta_train_error / batch_sz) |
|
|
|
|
|
print('KL Train Error', meta_cluster_error / batch_sz) |
|
|
|
|
|
|
|
|
|
|
|
# clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs) |
|
|
|
|
|
# clustering_loss.backward() |
|
|
|
|
|
# print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy()) |
|
|
|
|
|
|
|
|
|
|
|
# if i != (num_batch - 1): |
|
|
|
|
|
# C_distribs = [] |
|
|
|
|
|
|
|
|
# Average the accumulated gradients and optimize |
|
|
# Average the accumulated gradients and optimize |
|
|
for p in all_parameters: |
|
|
for p in all_parameters: |
|
|
|
|
|
|
|
|
gc.collect() |
|
|
gc.collect() |
|
|
print("===============================================\n") |
|
|
print("===============================================\n") |
|
|
|
|
|
|
|
|
if iteration % 2 == 0 or iteration>0: |
|
|
|
|
|
# testing |
|
|
|
|
|
print("start of test phase") |
|
|
|
|
|
trainer.eval() |
|
|
|
|
|
|
|
|
|
|
|
with open("results2.txt", "a") as f: |
|
|
|
|
|
f.write("epoch:{}\n".format(iteration)) |
|
|
|
|
|
|
|
|
|
|
|
for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
|
|
test_dataset = None |
|
|
|
|
|
test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
|
|
supp_xs_s = [] |
|
|
|
|
|
supp_ys_s = [] |
|
|
|
|
|
query_xs_s = [] |
|
|
|
|
|
query_ys_s = [] |
|
|
|
|
|
gc.collect() |
|
|
|
|
|
|
|
|
|
|
|
print("===================== " + test_state + " =====================") |
|
|
|
|
|
mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],num_epoch=config['num_epoch'],test_state=test_state,args=args) |
|
|
|
|
|
with open("results2.txt", "a") as f: |
|
|
|
|
|
f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
|
|
|
print("===================================================") |
|
|
|
|
|
del (test_dataset) |
|
|
|
|
|
gc.collect() |
|
|
|
|
|
|
|
|
|
|
|
trainer.train() |
|
|
|
|
|
with open("results2.txt", "a") as f: |
|
|
|
|
|
f.write("\n") |
|
|
|
|
|
print("\n\n\n") |
|
|
|
|
|
|
|
|
# if iteration % 2 == 0 and iteration != 0: |
|
|
|
|
|
# # testing |
|
|
|
|
|
# print("start of test phase") |
|
|
|
|
|
# trainer.eval() |
|
|
|
|
|
# |
|
|
|
|
|
# with open("results2.txt", "a") as f: |
|
|
|
|
|
# f.write("epoch:{}\n".format(iteration)) |
|
|
|
|
|
# |
|
|
|
|
|
# for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']: |
|
|
|
|
|
# test_dataset = None |
|
|
|
|
|
# test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4) |
|
|
|
|
|
# supp_xs_s = [] |
|
|
|
|
|
# supp_ys_s = [] |
|
|
|
|
|
# query_xs_s = [] |
|
|
|
|
|
# query_ys_s = [] |
|
|
|
|
|
# gc.collect() |
|
|
|
|
|
# |
|
|
|
|
|
# print("===================== " + test_state + " =====================") |
|
|
|
|
|
# mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'], |
|
|
|
|
|
# num_epoch=config['num_epoch'], test_state=test_state, args=args) |
|
|
|
|
|
# with open("results2.txt", "a") as f: |
|
|
|
|
|
# f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3)) |
|
|
|
|
|
# print("===================================================") |
|
|
|
|
|
# del (test_dataset) |
|
|
|
|
|
# gc.collect() |
|
|
|
|
|
# |
|
|
|
|
|
# trainer.train() |
|
|
|
|
|
# with open("results2.txt", "a") as f: |
|
|
|
|
|
# f.write("\n") |
|
|
|
|
|
# print("\n\n\n") |
|
|
|
|
|
|
|
|
# save model |
|
|
# save model |
|
|
# final_model = torch.nn.Sequential(emb, head) |
|
|
# final_model = torch.nn.Sequential(emb, head) |