
cluster loss (attempt at hard assignment)

define_task
mohamad maheri committed 2 years ago
commit d425ba0c97
4 changed files with 89 additions and 23 deletions
  1. clustering.py (+2, -2)
  2. fast_adapt.py (+29, -7)
  3. learnToLearn.py (+54, -11)
  4. learnToLearnTest.py (+4, -3)

clustering.py (+2, -2)

@@ -82,7 +82,7 @@ class Trainer(torch.nn.Module):
     def aggregate(self, z_i):
         return torch.mean(z_i, dim=0)

-    def forward(self, task_embed, y, training,adaptation_data=None,adaptation_labels=None):
+    def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
         if training:
             C, clustered_task_embed = self.cluster_module(task_embed, y)
             # hidden layers
@@ -122,4 +122,4 @@ class Trainer(torch.nn.Module):

         y_pred = self.linear_out(hidden_3)

-        return y_pred
+        return y_pred, C
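`forward` now returns the cluster-assignment matrix `C` next to the prediction, and the updated fast_adapt.py below unpacks two values from both the `training=True` and the `training=False` call, so the non-training branch of `forward` also has to leave `C` defined. A minimal sketch of the two call sites as the new code uses them (taken from fast_adapt.py; tensor shapes are not specified by the diff):

# Inner-loop (adaptation) call: hits the training branch of Trainer.forward.
temp, c = learn(adaptation_data, adaptation_labels, training=True)

# Meta-evaluation call: the non-training branch must also return an assignment
# matrix, otherwise the two-value unpacking here breaks.
predictions, c = learn(evaluation_data, None, training=False,
                       adaptation_data=adaptation_data,
                       adaptation_labels=adaptation_labels)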

fast_adapt.py (+29, -7)

@@ -1,5 +1,19 @@
 import torch
 import pickle
+from options import config
+import random
+
+
+def cl_loss(c):
+    alpha = config['alpha']
+    beta = config['beta']
+    d = config['d']
+    a = torch.div(1, torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c), beta))))))
+    # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
+    b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
+    # b = 1 * a * (1 - a) * (1 - 2 * a)
+    loss = torch.sum(b)
+    return loss


 def fast_adapt(
@@ -9,16 +23,24 @@ def fast_adapt(
         adaptation_labels,
         evaluation_labels,
         adaptation_steps,
-        get_predictions=False):
+        get_predictions=False,
+        epoch=None):
     for step in range(adaptation_steps):
-        temp = learn(adaptation_data, adaptation_labels, training=True)
+        temp, c = learn(adaptation_data, adaptation_labels, training=True)
         train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
-        learn.adapt(train_error)
+        cluster_loss = cl_loss(c)
+        total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
+        learn.adapt(total_loss)

-    predictions = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
-                        adaptation_labels=adaptation_labels)
+    predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
+                           adaptation_labels=adaptation_labels)
     valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
+    cluster_loss = cl_loss(c)
+    total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
+
+    if random.random() < 0.05:
+        print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

     if get_predictions:
-        return valid_error, predictions
-    return valid_error
+        return total_loss, predictions
+    return total_loss
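The nested `torch.mul`/`torch.sub` chain in `cl_loss` is just the logistic expression spelled out in the two commented-out lines. For reference, a sketch of an equivalent, more readable formulation (not part of the commit; `alpha`, `beta` and `d` are taken as arguments here instead of being read from `options.config`):

import torch

def cl_loss_readable(c, alpha, beta, d):
    # a = 1 / (1 + exp(-alpha * (d * c - beta))): logistic squashing of the scaled,
    # shifted cluster-assignment scores in c.
    a = torch.sigmoid(alpha * (d * c - beta))
    # b = a * (1 - a) * (1 - 2 * a): per the commit message, the term intended to
    # push the soft assignments towards hard ones.
    b = a * (1 - a) * (1 - 2 * a)
    return torch.sum(b)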

learnToLearn.py (+54, -11)

@@ -32,9 +32,9 @@ def parse_args():
                         help='outer-loop learning rate (used with Adam optimiser)')
     # parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate')

-    parser.add_argument('--inner', type=int, default=2,
+    parser.add_argument('--inner', type=int, default=1,
                         help='number of gradient steps in inner loop (during training)')
-    parser.add_argument('--inner_eval', type=int, default=2,
+    parser.add_argument('--inner_eval', type=int, default=1,
                         help='number of gradient updates at test time (for evaluation)')

     parser.add_argument('--first_order', action='store_true', default=False,
@@ -102,10 +102,10 @@ if __name__ == '__main__':
     fc2_out_dim = config['second_fc_hidden_dim']
     use_cuda = config['use_cuda']

-    fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
-    fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
-    linear_out = torch.nn.Linear(fc2_out_dim, 1)
-    head = torch.nn.Sequential(fc1, fc2, linear_out)
+    # fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
+    # fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
+    # linear_out = torch.nn.Linear(fc2_out_dim, 1)
+    # head = torch.nn.Sequential(fc1, fc2, linear_out)

     if use_cuda:
         emb = EmbeddingModule(config).cuda()
@@ -123,6 +123,7 @@ if __name__ == '__main__':
     transform = l2l.optim.ModuleTransform(torch.nn.Linear)

     trainer = Trainer(config)
+    tr = trainer

     # define meta algorithm
     if args.meta_algo == "maml":
@@ -140,7 +141,7 @@ if __name__ == '__main__':
     # Setup optimization
     print("SETUP OPTIMIZATION PHASE")
     all_parameters = list(emb.parameters()) + list(trainer.parameters())
-    optimizer = torch.optim.Adam(all_parameters, lr=args.lr_meta)
+    optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
     # loss = torch.nn.MSELoss(reduction='mean')

     # Load training dataset.
@@ -156,7 +157,47 @@ if __name__ == '__main__':

     print("\n\n\n")

-    for iteration in range(args.epochs):
+    for iteration in range(config['num_epoch']):
+
+        if iteration == 1:
+            print("changing cluster centroids started ...")
+            indexes = list(np.arange(training_set_size))
+            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
+            for idx in range(0, 2500):
+                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
+                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
+                query_xs.append(
+                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
+                query_ys.append(
+                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
+            batch_sz = len(supp_xs)
+
+            user_embeddings = []
+
+            for task in range(batch_sz):
+                # Compute meta-training loss
+                supp_xs[task] = supp_xs[task].cuda()
+                supp_ys[task] = supp_ys[task].cuda()
+                # query_xs[task] = query_xs[task].cuda()
+                # query_ys[task] = query_ys[task].cuda()
+                temp_sxs = emb(supp_xs[task])
+                # temp_qxs = emb(query_xs[task])
+                y = supp_ys[task].view(-1, 1)
+                input_pairs = torch.cat((temp_sxs, y), dim=1)
+                task_embed = tr.cluster_module.input_to_hidden(input_pairs)
+
+                # todo : may be useless
+                mean_task = tr.cluster_module.aggregate(task_embed)
+                user_embeddings.append(mean_task.detach().cpu().numpy())
+
+                supp_xs[task] = supp_xs[task].cpu()
+                supp_ys[task] = supp_ys[task].cpu()
+
+            from sklearn.cluster import KMeans
+
+            user_embeddings = np.array(user_embeddings)
+            kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings)
+            tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()

         num_batch = int(training_set_size / batch_size)
         indexes = list(np.arange(training_set_size))
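The last added line of this hunk overwrites `cluster_module.array` in place, which only works if `array` is a tensor attribute (for example an `nn.Parameter`) shaped `(cluster_k, task-embedding dim)`. The cluster module itself is not touched by this commit, so the following is only a sketch of the minimal interface the new code relies on; every internal detail here is an assumption:

import torch

class ClusterModuleSketch(torch.nn.Module):
    """Sketch of the interface used above; not the repository's actual cluster module."""

    def __init__(self, k, input_dim, hidden_dim):
        super().__init__()
        # Learnable centroid matrix; the KMeans centres are copied into .data after epoch 0.
        self.array = torch.nn.Parameter(torch.randn(k, hidden_dim))
        self.input_to_hidden_layer = torch.nn.Linear(input_dim, hidden_dim)

    def input_to_hidden(self, input_pairs):
        # Map concatenated (interaction embedding, rating) rows to per-interaction features.
        return torch.relu(self.input_to_hidden_layer(input_pairs))

    def aggregate(self, task_embed):
        # Mean-pool the per-interaction features into one task-level embedding.
        return torch.mean(task_embed, dim=0)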
@@ -199,7 +240,8 @@ if __name__ == '__main__':
                                               temp_qxs,
                                               supp_ys[task],
                                               query_ys[task],
-                                              args.inner)
+                                              config['inner'],
+                                              epoch=iteration)

                 evaluation_error.backward()
                 meta_train_error += evaluation_error.item()
@@ -223,7 +265,7 @@ if __name__ == '__main__':
         gc.collect()
         print("===============================================\n")

-        if iteration % 2 == 0:
+        if iteration % 2 == 0 and iteration != 0:
             # testing
             print("start of test phase")
             trainer.eval()
@@ -241,7 +283,8 @@ if __name__ == '__main__':
                 gc.collect()

                 print("===================== " + test_state + " =====================")
-                mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],num_epoch=config['num_epoch'],test_state=test_state,args=args)
+                mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
+                                       num_epoch=config['num_epoch'], test_state=test_state, args=args)
                 with open("results2.txt", "a") as f:
                     f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
                 print("===================================================")

learnToLearnTest.py (+4, -3)

@@ -45,9 +45,10 @@ def test(embedding, head, total_dataset, batch_size, num_epoch, test_state=None,
                                         temp_qxs,
                                         supp_ys,
                                         query_ys,
-                                        # config['inner'],
-                                        args.inner_eval,
-                                        get_predictions=True)
+                                        config['inner'],
+                                        # args.inner_eval,
+                                        get_predictions=True,
+                                        epoch=0)

         l1 = L1Loss(reduction='mean')
         loss_q = l1(predictions.view(-1), query_ys)
