Browse Source

deep kmeans loss

define_task
mohamad maheri 2 years ago
parent
commit
b4d130e28e
3 changed files with 104 additions and 54 deletions
  1. 42
    6
      clustering.py
  2. 18
    7
      fast_adapt.py
  3. 44
    41
      learnToLearn.py

+ 42
- 6
clustering.py View File

@@ -51,9 +51,45 @@ class ClustringModule(torch.nn.Module):
value = torch.mm(C, self.array)
# simple add operation
# new_task_embed = value + mean_task
new_task_embed = value

return C, new_task_embed
# new_task_embed = value
new_task_embed = mean_task

# print("injam1:", new_task_embed)
# print("injam2:", self.array)
list_dist = []
# list_dist = torch.norm(new_task_embed - self.array, p=2, dim=1,keepdim=True)
list_dist = torch.sum(torch.pow(new_task_embed - self.array,2),dim=1)
stack_dist = list_dist

# print("injam3:", stack_dist)

## Second, find the minimum squared distance for softmax normalization
min_dist = min(list_dist)
# print("injam4:", min_dist)

## Third, compute exponentials shifted with min_dist to avoid underflow (0/0) issues in softmaxes
alpha = config['kmeans_alpha'] # Placeholder tensor for alpha
list_exp = []
for i in range(self.clusters_k):
exp = torch.exp(-alpha * (stack_dist[i] - min_dist))
list_exp.append(exp)
stack_exp = torch.stack(list_exp)
sum_exponentials = torch.sum(stack_exp)

# print("injam5:", stack_exp, sum_exponentials)

## Fourth, compute softmaxes and the embedding/representative distances weighted by softmax
list_softmax = []
list_weighted_dist = []
for j in range(self.clusters_k):
softmax = stack_exp[j] / sum_exponentials
weighted_dist = stack_dist[j] * softmax
list_softmax.append(softmax)
list_weighted_dist.append(weighted_dist)
stack_weighted_dist = torch.stack(list_weighted_dist)

kmeans_loss = torch.sum(stack_weighted_dist, dim=0)
return C, new_task_embed, kmeans_loss


class Trainer(torch.nn.Module):
@@ -85,7 +121,7 @@ class Trainer(torch.nn.Module):

def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
if training:
C, clustered_task_embed = self.cluster_module(task_embed, y)
C, clustered_task_embed, k_loss = self.cluster_module(task_embed, y)
# hidden layers
# todo : adding activation function or remove it
hidden_1 = self.fc1(task_embed)
@@ -105,7 +141,7 @@ class Trainer(torch.nn.Module):
y_pred = self.linear_out(hidden_3)

else:
C, clustered_task_embed = self.cluster_module(adaptation_data, adaptation_labels)
C, clustered_task_embed, k_loss = self.cluster_module(adaptation_data, adaptation_labels)
beta_1 = torch.tanh(self.film_layer_1_beta(clustered_task_embed))
gamma_1 = torch.tanh(self.film_layer_1_gamma(clustered_task_embed))
beta_2 = torch.tanh(self.film_layer_2_beta(clustered_task_embed))
@@ -123,4 +159,4 @@ class Trainer(torch.nn.Module):

y_pred = self.linear_out(hidden_3)

return y_pred, C
return y_pred, C, k_loss

+ 18
- 7
fast_adapt.py View File

@@ -8,7 +8,8 @@ def cl_loss(c):
alpha = config['alpha']
beta = config['beta']
d = config['d']
a = torch.div(1, torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
a = torch.div(1,
torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c.squeeze()), beta))))))
# a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
# b = 1 * a * (1 - a) * (1 - 2 * a)
@@ -25,24 +26,34 @@ def fast_adapt(
adaptation_steps,
get_predictions=False,
epoch=None):
is_print = random.random() < 0.05

for step in range(adaptation_steps):
temp, c = learn(adaptation_data, adaptation_labels, training=True)
temp, c, k_loss = learn(adaptation_data, adaptation_labels, training=True)
train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
# cluster_loss = cl_loss(c)
# total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
total_loss = train_error
total_loss = train_error + config['kmeans_loss_weight'] * k_loss
learn.adapt(total_loss)
if is_print:
# print("in support:\t", round(k_loss.item(),4))
pass

predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
adaptation_labels=adaptation_labels)
predictions, c, k_loss = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
adaptation_labels=adaptation_labels)
valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
# cluster_loss = cl_loss(c)
# total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
total_loss = valid_error
total_loss = valid_error + config['kmeans_loss_weight'] * k_loss

if is_print:

# print("in query:\t", round(k_loss.item(),4))
print(c[0].detach().cpu().numpy(),"\t",round(k_loss.item(),3),"\n")

# if random.random() < 0.05:
# print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

if get_predictions:
return total_loss, predictions
return total_loss,c
return total_loss, c, k_loss.item()

+ 44
- 41
learnToLearn.py View File

@@ -271,6 +271,7 @@ if __name__ == '__main__':

for i in range(num_batch):
meta_train_error = 0.0
meta_cluster_error = 0.0
optimizer.zero_grad()
print("EPOCH: ", iteration, " BATCH: ", i)
supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
@@ -302,17 +303,18 @@ if __name__ == '__main__':
temp_sxs = emb(supp_xs[task])
temp_qxs = emb(query_xs[task])

evaluation_error, c = fast_adapt(learner,
temp_sxs,
temp_qxs,
supp_ys[task],
query_ys[task],
config['inner'],
epoch=iteration)
evaluation_error, c, k_loss = fast_adapt(learner,
temp_sxs,
temp_qxs,
supp_ys[task],
query_ys[task],
config['inner'],
epoch=iteration)

C_distribs.append(c)
# C_distribs.append(c)
evaluation_error.backward(retain_graph=True)
meta_train_error += evaluation_error.item()
meta_cluster_error += k_loss

# supp_xs[task].cpu()
# query_xs[task].cpu()
@@ -322,10 +324,11 @@ if __name__ == '__main__':
# Print some metrics
print('Iteration', iteration)
print('Meta Train Error', meta_train_error / batch_sz)
print('KL Train Error', meta_cluster_error / batch_sz)

clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs)
clustering_loss.backward()
print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy())
# clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs)
# clustering_loss.backward()
# print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy())

# if i != (num_batch - 1):
# C_distribs = []
@@ -340,36 +343,36 @@ if __name__ == '__main__':
gc.collect()
print("===============================================\n")

if iteration % 2 == 0 and iteration != 0:
# testing
print("start of test phase")
trainer.eval()
with open("results2.txt", "a") as f:
f.write("epoch:{}\n".format(iteration))
for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
test_dataset = None
test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
supp_xs_s = []
supp_ys_s = []
query_xs_s = []
query_ys_s = []
gc.collect()
print("===================== " + test_state + " =====================")
mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
num_epoch=config['num_epoch'], test_state=test_state, args=args)
with open("results2.txt", "a") as f:
f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
print("===================================================")
del (test_dataset)
gc.collect()
trainer.train()
with open("results2.txt", "a") as f:
f.write("\n")
print("\n\n\n")
# if iteration % 2 == 0 and iteration != 0:
# # testing
# print("start of test phase")
# trainer.eval()
#
# with open("results2.txt", "a") as f:
# f.write("epoch:{}\n".format(iteration))
#
# for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
# test_dataset = None
# test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
# supp_xs_s = []
# supp_ys_s = []
# query_xs_s = []
# query_ys_s = []
# gc.collect()
#
# print("===================== " + test_state + " =====================")
# mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
# num_epoch=config['num_epoch'], test_state=test_state, args=args)
# with open("results2.txt", "a") as f:
# f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
# print("===================================================")
# del (test_dataset)
# gc.collect()
#
# trainer.train()
# with open("results2.txt", "a") as f:
# f.write("\n")
# print("\n\n\n")

# save model
# final_model = torch.nn.Sequential(emb, head)

Loading…
Cancel
Save