
cluster loss (attempt at hard assignment)

define_task
mohamad maheri · 2 years ago
commit d425ba0c97
4 changed files with 89 additions and 23 deletions:

1. clustering.py (+2, -2)
2. fast_adapt.py (+29, -7)
3. learnToLearn.py (+54, -11)
4. learnToLearnTest.py (+4, -3)

clustering.py (+2, -2)

      def aggregate(self, z_i):
          return torch.mean(z_i, dim=0)

-     def forward(self, task_embed, y, training,adaptation_data=None,adaptation_labels=None):
+     def forward(self, task_embed, y, training, adaptation_data=None, adaptation_labels=None):
          if training:
              C, clustered_task_embed = self.cluster_module(task_embed, y)
              # hidden layers

              y_pred = self.linear_out(hidden_3)

-             return y_pred
+             return y_pred, C
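For orientation (not part of the diff): the cluster module referenced here evidently holds a centroid matrix `array` plus helpers `input_to_hidden` and `aggregate`, all used in learnToLearn.py below. A minimal sketch of what a soft-assignment forward step could look like under those assumptions — the L2 distance and softmax normalization are guesses, not this repository's code:

    import torch

    # Hypothetical sketch only. Assumes the cluster module stores a (k x emb_dim)
    # centroid matrix `array`; softmax over negative L2 distances is an assumption.
    def soft_assign(task_embed: torch.Tensor, centroids: torch.Tensor):
        dists = torch.cdist(task_embed, centroids)  # (n, k) pairwise L2 distances
        C = torch.softmax(-dists, dim=1)            # soft cluster assignments, rows sum to 1
        clustered = C @ centroids                   # (n, emb_dim) centroid-weighted embedding
        return C, clustered

Whatever the exact computation, returning C alongside y_pred is what lets the caller (fast_adapt.py below) attach a sharpness penalty to the assignments.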

fast_adapt.py (+29, -7)

  import torch
  import pickle
+ from options import config
+ import random


+ def cl_loss(c):
+     alpha = config['alpha']
+     beta = config['beta']
+     d = config['d']
+     a = torch.div(1, torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c), beta))))))
+     # a = 1 / (1 + torch.exp((-1) * alpha * (d * c - beta)))
+     b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))
+     # b = 1 * a * (1 - a) * (1 - 2 * a)
+     loss = torch.sum(b)
+     return loss


  def fast_adapt(
      adaptation_labels,
      evaluation_labels,
      adaptation_steps,
-     get_predictions=False):
+     get_predictions=False,
+     epoch=None):
      for step in range(adaptation_steps):
-         temp = learn(adaptation_data, adaptation_labels, training=True)
+         temp, c = learn(adaptation_data, adaptation_labels, training=True)
          train_error = torch.nn.functional.mse_loss(temp.view(-1), adaptation_labels)
-         learn.adapt(train_error)
+         cluster_loss = cl_loss(c)
+         total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
+         learn.adapt(total_loss)

-     predictions = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
-                         adaptation_labels=adaptation_labels)
+     predictions, c = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
+                            adaptation_labels=adaptation_labels)
      valid_error = torch.nn.functional.mse_loss(predictions.view(-1), evaluation_labels)
+     cluster_loss = cl_loss(c)
+     total_loss = valid_error + config['cluster_loss_weight'] * cluster_loss
+
+     if random.random() < 0.05:
+         print("cl:", round(cluster_loss.item()), "\t c:", c[0].cpu().data.numpy())

      if get_predictions:
-         return valid_error, predictions
-     return valid_error
+         return total_loss, predictions
+     return total_loss
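The commented-out closed forms are easier to read than the nested torch calls: a = sigmoid(alpha * (d*c - beta)) squashes each assignment score into (0, 1), and b = a * (1 - a) * (1 - 2a) vanishes exactly at a in {0, 0.5, 1}, so summing b penalizes in-between (soft) assignments — consistent with the commit message's move toward hard assignment. A quick standalone check that the vectorized expression from the diff matches the closed form; the alpha/beta/d values here are placeholders, the real ones live in options.config:

    import torch

    alpha, beta, d = 4.0, 0.5, 1.0            # placeholder config values
    c = torch.rand(8, 3)                      # fake assignment scores, shape (tasks, clusters)

    # Vectorized form used in the commit:
    a = torch.div(1, torch.add(1, torch.exp(torch.mul(-1, torch.mul(alpha, torch.sub(torch.mul(d, c), beta))))))
    b = torch.mul(a, torch.mul(torch.sub(1, a), torch.sub(1, torch.mul(2, a))))

    # Closed form from the diff's comments:
    a_ref = 1 / (1 + torch.exp(-alpha * (d * c - beta)))   # sigmoid(alpha * (d*c - beta))
    b_ref = a_ref * (1 - a_ref) * (1 - 2 * a_ref)          # zero at a in {0, 0.5, 1}

    assert torch.allclose(a, a_ref) and torch.allclose(b, b_ref)
    print(torch.sum(b))                                    # the cluster loss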

learnToLearn.py (+54, -11)

      help='outer-loop learning rate (used with Adam optimiser)')
  # parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate')

- parser.add_argument('--inner', type=int, default=2,
+ parser.add_argument('--inner', type=int, default=1,
                      help='number of gradient steps in inner loop (during training)')
- parser.add_argument('--inner_eval', type=int, default=2,
+ parser.add_argument('--inner_eval', type=int, default=1,
                      help='number of gradient updates at test time (for evaluation)')

  parser.add_argument('--first_order', action='store_true', default=False,

  fc2_out_dim = config['second_fc_hidden_dim']
  use_cuda = config['use_cuda']

- fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
- fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
- linear_out = torch.nn.Linear(fc2_out_dim, 1)
- head = torch.nn.Sequential(fc1, fc2, linear_out)
+ # fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
+ # fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
+ # linear_out = torch.nn.Linear(fc2_out_dim, 1)
+ # head = torch.nn.Sequential(fc1, fc2, linear_out)

  if use_cuda:
      emb = EmbeddingModule(config).cuda()
      transform = l2l.optim.ModuleTransform(torch.nn.Linear)

  trainer = Trainer(config)
+ tr = trainer

  # define meta algorithm
  if args.meta_algo == "maml":
      # Setup optimization
      print("SETUP OPTIMIZATION PHASE")
      all_parameters = list(emb.parameters()) + list(trainer.parameters())
-     optimizer = torch.optim.Adam(all_parameters, lr=args.lr_meta)
+     optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
      # loss = torch.nn.MSELoss(reduction='mean')

  # Load training dataset.

  print("\n\n\n")

- for iteration in range(args.epochs):
+ for iteration in range(config['num_epoch']):
+
+     if iteration == 1:
+         print("changing cluster centroids started ...")
+         indexes = list(np.arange(training_set_size))
+         supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
+         for idx in range(0, 2500):
+             supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
+             supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
+             query_xs.append(
+                 pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
+             query_ys.append(
+                 pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
+         batch_sz = len(supp_xs)
+
+         user_embeddings = []
+
+         for task in range(batch_sz):
+             # Compute meta-training loss
+             supp_xs[task] = supp_xs[task].cuda()
+             supp_ys[task] = supp_ys[task].cuda()
+             # query_xs[task] = query_xs[task].cuda()
+             # query_ys[task] = query_ys[task].cuda()
+             temp_sxs = emb(supp_xs[task])
+             # temp_qxs = emb(query_xs[task])
+             y = supp_ys[task].view(-1, 1)
+             input_pairs = torch.cat((temp_sxs, y), dim=1)
+             task_embed = tr.cluster_module.input_to_hidden(input_pairs)
+
+             # todo: may be useless
+             mean_task = tr.cluster_module.aggregate(task_embed)
+             user_embeddings.append(mean_task.detach().cpu().numpy())
+
+             supp_xs[task] = supp_xs[task].cpu()
+             supp_ys[task] = supp_ys[task].cpu()
+
+         from sklearn.cluster import KMeans
+
+         user_embeddings = np.array(user_embeddings)
+         kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings)
+         tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()

      num_batch = int(training_set_size / batch_size)
      indexes = list(np.arange(training_set_size))

          temp_qxs,
          supp_ys[task],
          query_ys[task],
-         args.inner)
+         config['inner'],
+         epoch=iteration)

      evaluation_error.backward()
      meta_train_error += evaluation_error.item()
      gc.collect()
      print("===============================================\n")

-     if iteration % 2 == 0:
+     if iteration % 2 == 0 and iteration != 0:
          # testing
          print("start of test phase")
          trainer.eval()
          gc.collect()

          print("===================== " + test_state + " =====================")
-         mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],num_epoch=config['num_epoch'],test_state=test_state,args=args)
+         mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
+                                num_epoch=config['num_epoch'], test_state=test_state, args=args)
          with open("results2.txt", "a") as f:
              f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
          print("===================================================")

learnToLearnTest.py (+4, -3)

          temp_qxs,
          supp_ys,
          query_ys,
-         # config['inner'],
-         args.inner_eval,
-         get_predictions=True)
+         config['inner'],
+         # args.inner_eval,
+         get_predictions=True,
+         epoch=0)

      l1 = L1Loss(reduction='mean')
      loss_q = l1(predictions.view(-1), query_ys)
