
TaNP loss (KL loss)

define_task
mohamad maheri, 2 years ago
commit acacb68d33
3 changed files with 98 additions and 20 deletions

  clustering.py     +2   -1
  fast_adapt.py     +9   -7
  learnToLearn.py   +87  -12

clustering.py (+2, -1)

  # 1*k, k*d, 1*d
  value = torch.mm(C, self.array)
  # simple add operation
- new_task_embed = value + mean_task
+ # new_task_embed = value + mean_task
+ new_task_embed = value

  return C, new_task_embed
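
For orientation, a minimal sketch of the cluster module this hunk edits, reconstructed from the shape comments; the class name, constructor, and softmax scoring are assumptions (only self.array, mean_task, and the returned pair appear in the diff). After this commit, the task embedding is the centroid mixture alone, without the mean_task residual:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ClusterModule(nn.Module):
    def __init__(self, k, d):
        super().__init__()
        # k learnable centroids of dimension d; learnToLearn.py later
        # re-seeds self.array.data from k-means centers
        self.array = nn.Parameter(torch.randn(k, d))

    def forward(self, mean_task):
        # mean_task: 1*d pooled task representation
        # soft assignment over the k centroids: 1*k
        C = F.softmax(torch.mm(mean_task, self.array.t()), dim=1)
        # 1*k @ k*d -> 1*d attention-weighted mixture of centroids
        value = torch.mm(C, self.array)
        # this commit drops the residual: new_task_embed = value + mean_task
        new_task_embed = value
        return C, new_task_embed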



fast_adapt.py (+9, -7)

  for step in range(adaptation_steps):
      temp, c = learn(adaptation_data, adaptation_labels, training=True)
      train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
-     cluster_loss = cl_loss(c)
-     total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
+     # cluster_loss = cl_loss(c)
+     # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
+     total_loss = train_error
      learn.adapt(total_loss)

  predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
                         adaptation_labels=adaptation_labels)
  valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
- cluster_loss = cl_loss(c)
- total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
+ # cluster_loss = cl_loss(c)
+ # total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
+ total_loss = valid_error

- if random.random() < 0.05:
-     print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())
+ # if random.random() < 0.05:
+ #     print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

  if get_predictions:
      return total_loss, predictions
- return total_loss
+ return total_loss, c
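
The inner loop follows the learn2learn MAML pattern (learn.adapt(loss) takes a gradient step on a cloned learner's fast weights). A toy, self-contained sketch of that pattern, assuming the learn2learn package; the model, data, and step count are illustrative, not the project's:

import torch
import learn2learn as l2l

# toy regression model and task data; shapes are illustrative only
model = torch.nn.Linear(4, 1)
maml = l2l.algorithms.MAML(model, lr=0.01)

learner = maml.clone()                      # per-task fast weights
adaptation_data = torch.randn(8, 4)
adaptation_labels = torch.randn(8)

adaptation_steps = 5
for step in range(adaptation_steps):
    pred = learner(adaptation_data)
    loss = torch.nn.functional.mse_loss(pred.view(-1), adaptation_labels)
    learner.adapt(loss)                     # inner-loop gradient step on the clone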

learnToLearn.py (+87, -12)

  from torch.nn import functional as F




+ def data_batching(indexes, C_distribs, batch_size, training_set_size, num_clusters):
+     probs = np.squeeze(C_distribs)
+     cs = [np.random.choice(num_clusters, p=i) for i in probs]
+     num_batch = int(training_set_size / batch_size)
+     res = [[] for i in range(num_batch)]
+     clas = [[] for i in range(num_clusters)]
+
+     for idx, c in zip(indexes, cs):
+         clas[c].append(idx)
+
+     t = np.array([len(i) for i in clas])
+     t = t / t.sum()
+
+     dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)]))
+     cnt = 0
+
+     for i in range(len(res)):
+         for j in range(batch_size):
+             temp = np.random.choice(num_clusters, p=t)
+             if len(clas[temp]) > 0:
+                 res[i].append(clas[temp].pop(0))
+             else:
+                 # res[i].append(indexes[training_set_size-1-cnt])
+                 res[i].append(random.choice(dif))
+                 cnt = cnt + 1
+
+     res = np.random.permutation(res)
+     final_result = np.array(res).flatten()
+     return final_result


  def parse_args():
      print("==============")
      parser = argparse.ArgumentParser([], description='Fast Context Adaptation via Meta-Learning (CAVIA),'
      return args




+ from torch.nn import functional as F


+ def kl_loss(C_distribs):
+     # batchsize * k
+     C_distribs = torch.stack(C_distribs).squeeze()
+     # print("injam:", len(C_distribs))
+     # print(C_distribs[0].shape)
+     # batchsize * k
+     # print("injam2", C_distribs)
+     C_distribs_sq = torch.pow(C_distribs, 2)
+     # print("injam3", C_distribs_sq)
+     # 1*k
+     C_distribs_sum = torch.sum(C_distribs, dim=0, keepdim=True)
+     # print("injam4", C_distribs_sum)
+     # batchsize * k
+     temp = C_distribs_sq / C_distribs_sum
+     # print("injam5", temp)
+     # batchsize * 1
+     temp_sum = torch.sum(temp, dim=1, keepdim=True)
+     # print("injam6", temp_sum)
+     target_distribs = temp / temp_sum
+     # print("injam7", target_distribs)
+     # calculate the KL loss
+     clustering_loss = F.kl_div(C_distribs.log(), target_distribs, reduction='batchmean')
+     # print("injam8", clustering_loss)
+     return clustering_loss
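
The target built here matches the sharpened auxiliary distribution of Deep Embedded Clustering (Xie et al., 2016): squaring the soft assignments and renormalizing by the soft cluster frequencies pulls each task toward its dominant cluster. With q_{ij} the soft assignment of task i to cluster j (a row of C_distribs) and B tasks in the meta-batch:

    f_j = \sum_i q_{ij}, \qquad
    p_{ij} = \frac{q_{ij}^2 / f_j}{\sum_{j'} q_{ij'}^2 / f_{j'}}, \qquad
    \mathcal{L}_{\mathrm{KL}} = \mathrm{KL}(P \,\|\, Q)
      = \frac{1}{B} \sum_{i=1}^{B} \sum_{j=1}^{k} p_{ij} \log \frac{p_{ij}}{q_{ij}}

F.kl_div takes log-probabilities as its first argument and the target as its second, so C_distribs.log() with reduction='batchmean' yields KL(P || Q) averaged over the B tasks. It assumes every entry of C_distribs is strictly positive, since .log() of a zero assignment produces -inf.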


  if __name__ == '__main__':
      args = parse_args()
      print(args)


  for iteration in range(config['num_epoch']):

-     if iteration == 1:
+     if iteration == 0:
          print("changing cluster centroids started ...")
          indexes = list(np.arange(training_set_size))
          supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
          kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings)
          tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()


-     num_batch = int(training_set_size / batch_size)
-     indexes = list(np.arange(training_set_size))
-     random.shuffle(indexes)
+     if iteration > (0):
+         # indexes = data_batching(indexes, C_distribs, batch_size, training_set_size, config['cluster_k'])
+         # random.shuffle(indexes)
+         C_distribs = []
+     else:
+         num_batch = int(training_set_size / batch_size)
+         indexes = list(np.arange(training_set_size))
+         random.shuffle(indexes)


      for i in range(num_batch):
          meta_train_error = 0.0
              query_xs[j] = query_xs[j].cuda()
              query_ys[j] = query_ys[j].cuda()

+         C_distribs = []
          for task in range(batch_sz):
              # Compute meta-training loss
              # sxs = supp_xs[task].cuda()
              temp_sxs = emb(supp_xs[task])
              temp_qxs = emb(query_xs[task])

-             evaluation_error = fast_adapt(learner,
-                                           temp_sxs,
-                                           temp_qxs,
-                                           supp_ys[task],
-                                           query_ys[task],
-                                           config['inner'],
-                                           epoch=iteration)
+             evaluation_error, c = fast_adapt(learner,
+                                              temp_sxs,
+                                              temp_qxs,
+                                              supp_ys[task],
+                                              query_ys[task],
+                                              config['inner'],
+                                              epoch=iteration)

-             evaluation_error.backward()
+             C_distribs.append(c)
+             evaluation_error.backward(retain_graph=True)
              meta_train_error += evaluation_error.item()


              # supp_xs[task].cpu()
          print('Iteration', iteration)
          print('Meta Train Error', meta_train_error / batch_sz)

+         clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs)
+         clustering_loss.backward()
+         print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy())
+
+         # if i != (num_batch - 1):
+         #     C_distribs = []

          # Average the accumulated gradients and optimize
          for p in all_parameters:
              p.grad.data.mul_(1.0 / batch_sz)
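
Each per-task query loss is now backpropagated with retain_graph=True so that the graph behind every collected c survives until the single batch-level KL term is backpropagated; the 1/batch_sz rescaling then averages everything accumulated in .grad. A self-contained toy of that accumulation pattern (the softmax, losses, and sizes are stand-ins, not the project's code):

import torch

# toy: several per-task losses plus one shared batch-level loss over the
# collected intermediates, mirroring the backward pattern above
w = torch.nn.Parameter(torch.ones(3))
batch_sz = 4
C_distribs = []
for task in range(batch_sz):
    c = torch.softmax(w * (task + 1), dim=0)   # stands in for the soft assignment
    task_loss = (c * c).sum()                  # stands in for the per-task query loss
    task_loss.backward(retain_graph=True)      # keep the graph: `c` is reused below
    C_distribs.append(c)

batch_loss = torch.stack(C_distribs).mean()    # stands in for kl_loss(C_distribs)
batch_loss.backward()                          # adds the batch-level gradient to w.grad

w.grad.mul_(1.0 / batch_sz)                    # average, as in the loop above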
