
only consider positive interaction for task embedding

Branch: define_task
Author: mohamad maheri, 2 years ago
Commit: c701e2c338
3 changed files with 50 additions and 60 deletions:

  1. clustering.py   (+12, −2)
  2. fast_adapt.py   (+0, −4)
  3. learnToLearn.py (+38, −54)
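
In short: the task embedding is now computed only from positively rated support interactions. Instead of concatenating the rating vector onto the embeddings, ratings above 3 select which embedding rows feed the clustering module, with a constant placeholder for tasks that have no positive rating. The default --meta_algo also switches from gbml to metasgd, batching at iteration > 0 now resets C_distribs and reshuffles indexes, and the periodic test phase in learnToLearn.py is re-enabled. Hedged sketches of the key pieces follow each file's diff below.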

clustering.py (+12, −2)

@@ -16,7 +16,9 @@ class ClustringModule(torch.nn.Module):
         self.final_dim = config_param['cluster_final_dim']
         self.dropout_rate = config_param['cluster_dropout_rate']
 
-        layers = [nn.Linear(config_param['embedding_dim'] * 8 + 1, self.h1_dim),
+        layers = [
+            # nn.Linear(config_param['embedding_dim'] * 8 + 1, self.h1_dim),
+            nn.Linear(config_param['embedding_dim'] * 8, self.h1_dim),
             torch.nn.Dropout(self.dropout_rate),
             nn.ReLU(inplace=True),
             # nn.BatchNorm1d(self.h1_dim),
@@ -37,7 +39,15 @@ class ClustringModule(torch.nn.Module):
 
     def forward(self, task_embed, y, training=True):
         y = y.view(-1, 1)
-        input_pairs = torch.cat((task_embed, y), dim=1)
+        high_idx = y > 3
+        high_idx = high_idx.squeeze()
+        if high_idx.sum() > 0:
+            input_pairs = task_embed.detach()[high_idx]
+        else:
+            input_pairs = torch.ones(size=(1, 8 * config['embedding_dim'])).cuda()
+            print("found")
+
+        # input_pairs = torch.cat((task_embed, y), dim=1)
         task_embed = self.input_to_hidden(input_pairs)
 
         # todo : may be useless
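
The core of the commit is the filter above: ratings are no longer part of the clustering input, they only gate it. Rows of the support-set embedding whose rating exceeds 3 are kept (detached from the graph), and a constant placeholder row stands in when a task has no positive rating. (Note the fallback branch reads a global config while the rest of the module uses config_param.) Below is a minimal, CPU-only sketch of that rule; the helper name, the default embedding_dim, and the threshold argument are illustrative, not part of the repo.

    import torch

    def positive_task_inputs(task_embed: torch.Tensor,
                             y: torch.Tensor,
                             embedding_dim: int = 32,
                             threshold: float = 3.0) -> torch.Tensor:
        """Keep only embeddings of positively rated interactions (rating > threshold)."""
        high_idx = (y.view(-1) > threshold)
        if high_idx.sum() > 0:
            # detach: the filter feeds the clustering module without gradients
            return task_embed.detach()[high_idx]
        # degenerate task with no positive rating: constant placeholder row
        return torch.ones(1, 8 * embedding_dim)

    # toy check: 4 interactions, two of them rated above 3
    emb = torch.randn(4, 8 * 32)
    ratings = torch.tensor([5.0, 2.0, 4.0, 1.0])
    print(positive_task_inputs(emb, ratings).shape)  # torch.Size([2, 256])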

fast_adapt.py (+0, −4)

@@ -35,9 +35,6 @@ def fast_adapt(
         # total_loss = train_error + config['cluster_loss_weight'] * cluster_loss
         total_loss = train_error + config['kmeans_loss_weight'] * k_loss
         learn.adapt(total_loss)
-        if is_print:
-            # print("in support:\t", round(k_loss.item(), 4))
-            pass
 
     predictions, c, k_loss = learn(evaluation_data, None, training=False, adaptation_data=adaptation_data,
                                    adaptation_labels=adaptation_labels)
@@ -47,7 +44,6 @@ def fast_adapt(
     total_loss = valid_error + config['kmeans_loss_weight'] * k_loss
 
     if is_print:
-
         # print("in query:\t", round(k_loss.item(), 4))
         print(c[0].detach().cpu().numpy(), "\t", round(k_loss.item(), 3), "\n")
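
For context, both hunks keep the same objective: the prediction error plus a weighted k-means/clustering term, applied once on the support pass (before learn.adapt) and again on the query pass. A tiny sketch of that weighting, with toy scalars standing in for the real losses:

    import torch

    def combined_loss(pred_error: torch.Tensor, k_loss: torch.Tensor,
                      kmeans_loss_weight: float) -> torch.Tensor:
        # same weighting on the support (adaptation) and query (meta) passes;
        # kmeans_loss_weight mirrors config['kmeans_loss_weight'] in the repo
        return pred_error + kmeans_loss_weight * k_loss

    print(combined_loss(torch.tensor(0.8), torch.tensor(0.1), 0.5))  # tensor(0.8500)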


learnToLearn.py (+38, −54)

@@ -74,7 +74,7 @@ def parse_args():
                         help='run adaptation transform')
     parser.add_argument('--transformer', type=str, default="kronoker",
                         help='transformer type')
-    parser.add_argument('--meta_algo', type=str, default="gbml",
+    parser.add_argument('--meta_algo', type=str, default="metasgd",
                         help='MAML/MetaSGD/GBML')
     parser.add_argument('--gpu', type=int, default=0,
                         help='number of gpu to run the code')
@@ -163,11 +163,6 @@ if __name__ == '__main__':
     fc2_out_dim = config['second_fc_hidden_dim']
     use_cuda = config['use_cuda']
 
-    # fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
-    # fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
-    # linear_out = torch.nn.Linear(fc2_out_dim, 1)
-    # head = torch.nn.Sequential(fc1, fc2, linear_out)
-
     if use_cuda:
         emb = EmbeddingModule(config).cuda()
     else:
@@ -244,7 +239,8 @@ if __name__ == '__main__':
         temp_sxs = emb(supp_xs[task])
         # temp_qxs = emb(query_xs[task])
         y = supp_ys[task].view(-1, 1)
-        input_pairs = torch.cat((temp_sxs, y), dim=1)
+        # input_pairs = torch.cat((temp_sxs, y), dim=1)
+        input_pairs = temp_sxs
         task_embed = tr.cluster_module.input_to_hidden(input_pairs)
 
         # todo : may be useless
@@ -263,7 +259,9 @@ if __name__ == '__main__':
         if iteration > 0:
             # indexes = data_batching(indexes, C_distribs, batch_size, training_set_size, config['cluster_k'])
             # random.shuffle(indexes)
+            C_distribs = []
             num_batch = int(training_set_size / batch_size)
             indexes = list(np.arange(training_set_size))
+            random.shuffle(indexes)
         else:
             num_batch = int(training_set_size / batch_size)
             indexes = list(np.arange(training_set_size))
@@ -293,12 +291,6 @@ if __name__ == '__main__':
 
         C_distribs = []
         for task in range(batch_sz):
-            # Compute meta-training loss
-            # sxs = supp_xs[task].cuda()
-            # qxs = query_xs[task].cuda()
-            # sys = supp_ys[task].cuda()
-            # qys = query_ys[task].cuda()
-
             learner = trainer.clone()
             temp_sxs = emb(supp_xs[task])
             temp_qxs = emb(query_xs[task])
@@ -316,11 +308,6 @@ if __name__ == '__main__':
             meta_train_error += evaluation_error.item()
             meta_cluster_error += k_loss
 
-            # supp_xs[task].cpu()
-            # query_xs[task].cpu()
-            # supp_ys[task].cpu()
-            # query_ys[task].cpu()
-
         # Print some metrics
         print('Iteration', iteration)
         print('Meta Train Error', meta_train_error / batch_sz)
@@ -330,49 +317,46 @@ if __name__ == '__main__':
         # clustering_loss.backward()
         # print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy())
 
         # if i != (num_batch - 1):
         #     C_distribs = []
 
         # Average the accumulated gradients and optimize
         for p in all_parameters:
             p.grad.data.mul_(1.0 / batch_sz)
         optimizer.step()
 
         # torch.cuda.empty_cache()
-        del (supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs)
-        gc.collect()
+        # del (supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs)
+        # gc.collect()
         print("===============================================\n")
 
-        # if iteration % 2 == 0 and iteration != 0:
-        #     # testing
-        #     print("start of test phase")
-        #     trainer.eval()
-        #
-        #     with open("results2.txt", "a") as f:
-        #         f.write("epoch:{}\n".format(iteration))
-        #
-        #     for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
-        #         test_dataset = None
-        #         test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
-        #         supp_xs_s = []
-        #         supp_ys_s = []
-        #         query_xs_s = []
-        #         query_ys_s = []
-        #         gc.collect()
-        #
-        #         print("===================== " + test_state + " =====================")
-        #         mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
-        #                                num_epoch=config['num_epoch'], test_state=test_state, args=args)
-        #         with open("results2.txt", "a") as f:
-        #             f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
-        #         print("===================================================")
-        #         del (test_dataset)
-        #         gc.collect()
-        #
-        #     trainer.train()
-        #     with open("results2.txt", "a") as f:
-        #         f.write("\n")
-        #     print("\n\n\n")
+        if iteration % 2 == 0 and iteration != 0:
+            # testing
+            print("start of test phase")
+            trainer.eval()
+            with open("results2.txt", "a") as f:
+                f.write("epoch:{}\n".format(iteration))
+            for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
+                test_dataset = None
+                test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
+                supp_xs_s = []
+                supp_ys_s = []
+                query_xs_s = []
+                query_ys_s = []
+                gc.collect()
+                print("===================== " + test_state + " =====================")
+                mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
+                                       num_epoch=config['num_epoch'], test_state=test_state, args=args)
+                with open("results2.txt", "a") as f:
+                    f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
+                print("===================================================")
+                del (test_dataset)
+                gc.collect()
+            trainer.train()
+            with open("results2.txt", "a") as f:
+                f.write("\n")
+            print("\n\n\n")
 
     # save model
     # final_model = torch.nn.Sequential(emb, head)
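
The hunk above leaves the meta-update itself untouched: per-task gradients accumulate on the shared parameters, and p.grad.data.mul_(1.0 / batch_sz) averages them before a single optimizer step. A self-contained sketch of that accumulate-average-step pattern, with a toy linear model standing in for the meta-learner:

    import torch

    # Toy stand-ins: a linear model plays the role of the trainer and random
    # tensors play the role of per-task query batches.
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    batch_sz = 8

    optimizer.zero_grad()
    for _ in range(batch_sz):
        x, y = torch.randn(16, 4), torch.randn(16, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()  # per-task gradients accumulate in p.grad

    for p in model.parameters():
        p.grad.data.mul_(1.0 / batch_sz)  # average over the task batch
    optimizer.step()  # one meta-update per batch of tasks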
