|
|
@@ -16,6 +16,37 @@ import numpy as np |
|
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
def data_batching(indexes, C_distribs, batch_size, training_set_size, num_clusters): |
|
|
|
probs = np.squeeze(C_distribs) |
|
|
|
cs = [np.random.choice(num_clusters, p=i) for i in probs] |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
res = [[] for i in range(num_batch)] |
|
|
|
clas = [[] for i in range(num_clusters)] |
|
|
|
|
|
|
|
for idx, c in zip(indexes, cs): |
|
|
|
clas[c].append(idx) |
|
|
|
|
|
|
|
t = np.array([len(i) for i in clas]) |
|
|
|
t = t / t.sum() |
|
|
|
|
|
|
|
dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)])) |
|
|
|
cnt = 0 |
|
|
|
|
|
|
|
for i in range(len(res)): |
|
|
|
for j in range(batch_size): |
|
|
|
temp = np.random.choice(num_clusters, p=t) |
|
|
|
if len(clas[temp]) > 0: |
|
|
|
res[i].append(clas[temp].pop(0)) |
|
|
|
else: |
|
|
|
# res[i].append(indexes[training_set_size-1-cnt]) |
|
|
|
res[i].append(random.choice(dif)) |
|
|
|
cnt = cnt + 1 |
|
|
|
|
|
|
|
res = np.random.permutation(res) |
|
|
|
final_result = np.array(res).flatten() |
|
|
|
return final_result |
|
|
|
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
|
print("==============") |
|
|
|
parser = argparse.ArgumentParser([], description='Fast Context Adaptation via Meta-Learning (CAVIA),' |
|
|
@@ -77,6 +108,36 @@ def parse_args(): |
|
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
def kl_loss(C_distribs): |
|
|
|
# batchsize * k |
|
|
|
C_distribs = torch.stack(C_distribs).squeeze() |
|
|
|
|
|
|
|
# print("injam:",len(C_distribs)) |
|
|
|
# print(C_distribs[0].shape) |
|
|
|
# batchsize * k |
|
|
|
# print("injam2",C_distribs) |
|
|
|
C_distribs_sq = torch.pow(C_distribs, 2) |
|
|
|
# print("injam3",C_distribs_sq) |
|
|
|
# 1*k |
|
|
|
C_distribs_sum = torch.sum(C_distribs, dim=0, keepdim=True) |
|
|
|
# print("injam4",C_distribs_sum) |
|
|
|
# batchsize * k |
|
|
|
temp = C_distribs_sq / C_distribs_sum |
|
|
|
# print("injam5",temp) |
|
|
|
# batchsize * 1 |
|
|
|
temp_sum = torch.sum(temp, dim=1, keepdim=True) |
|
|
|
# print("injam6",temp_sum) |
|
|
|
target_distribs = temp / temp_sum |
|
|
|
# print("injam7",target_distribs) |
|
|
|
# calculate the kl loss |
|
|
|
clustering_loss = F.kl_div(C_distribs.log(), target_distribs, reduction='batchmean') |
|
|
|
# print("injam8",clustering_loss) |
|
|
|
return clustering_loss |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
args = parse_args() |
|
|
|
print(args) |
|
|
@@ -159,7 +220,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
for iteration in range(config['num_epoch']): |
|
|
|
|
|
|
|
if iteration == 1: |
|
|
|
if iteration == 0: |
|
|
|
print("changing cluster centroids started ...") |
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
supp_xs, supp_ys, query_xs, query_ys = [], [], [], [] |
|
|
@@ -199,9 +260,14 @@ if __name__ == '__main__': |
|
|
|
kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings) |
|
|
|
tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda() |
|
|
|
|
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
random.shuffle(indexes) |
|
|
|
if iteration > (0): |
|
|
|
# indexes = data_batching(indexes, C_distribs, batch_size, training_set_size, config['cluster_k']) |
|
|
|
# random.shuffle(indexes) |
|
|
|
C_distribs = [] |
|
|
|
else: |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
indexes = list(np.arange(training_set_size)) |
|
|
|
random.shuffle(indexes) |
|
|
|
|
|
|
|
for i in range(num_batch): |
|
|
|
meta_train_error = 0.0 |
|
|
@@ -224,6 +290,7 @@ if __name__ == '__main__': |
|
|
|
query_xs[j] = query_xs[j].cuda() |
|
|
|
query_ys[j] = query_ys[j].cuda() |
|
|
|
|
|
|
|
C_distribs = [] |
|
|
|
for task in range(batch_sz): |
|
|
|
# Compute meta-training loss |
|
|
|
# sxs = supp_xs[task].cuda() |
|
|
@@ -235,15 +302,16 @@ if __name__ == '__main__': |
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
|
temp_qxs = emb(query_xs[task]) |
|
|
|
|
|
|
|
evaluation_error = fast_adapt(learner, |
|
|
|
temp_sxs, |
|
|
|
temp_qxs, |
|
|
|
supp_ys[task], |
|
|
|
query_ys[task], |
|
|
|
config['inner'], |
|
|
|
epoch=iteration) |
|
|
|
evaluation_error, c = fast_adapt(learner, |
|
|
|
temp_sxs, |
|
|
|
temp_qxs, |
|
|
|
supp_ys[task], |
|
|
|
query_ys[task], |
|
|
|
config['inner'], |
|
|
|
epoch=iteration) |
|
|
|
|
|
|
|
evaluation_error.backward() |
|
|
|
C_distribs.append(c) |
|
|
|
evaluation_error.backward(retain_graph=True) |
|
|
|
meta_train_error += evaluation_error.item() |
|
|
|
|
|
|
|
# supp_xs[task].cpu() |
|
|
@@ -255,6 +323,13 @@ if __name__ == '__main__': |
|
|
|
print('Iteration', iteration) |
|
|
|
print('Meta Train Error', meta_train_error / batch_sz) |
|
|
|
|
|
|
|
clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs) |
|
|
|
clustering_loss.backward() |
|
|
|
print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy()) |
|
|
|
|
|
|
|
# if i != (num_batch - 1): |
|
|
|
# C_distribs = [] |
|
|
|
|
|
|
|
# Average the accumulated gradients and optimize |
|
|
|
for p in all_parameters: |
|
|
|
p.grad.data.mul_(1.0 / batch_sz) |