Extend the MeLU code to support different meta-learning algorithms and hyperparameter-tuning configurations.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

hyper_tunning.py 6.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. from ray import tune
  5. import pickle
  6. # from options import config
  7. from embedding_module import EmbeddingModule
  8. import learn2learn as l2l
  9. import random
  10. from fast_adapt import fast_adapt
  11. import gc
  12. from learn2learn.optim.transforms import KroneckerTransform
  13. from hyper_testing import hyper_test
  14. from clustering import Trainer
  15. # Define paths (for data)
  16. # master_path= "/media/external_10TB/10TB/maheri/melu_data5"
  17. def load_data(data_dir=None, test_state='warm_state'):
  18. training_set_size = int(len(os.listdir("{}/warm_state".format(data_dir))) / 4)
  19. supp_xs_s = []
  20. supp_ys_s = []
  21. query_xs_s = []
  22. query_ys_s = []
  23. for idx in range(training_set_size):
  24. supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(data_dir, idx), "rb")))
  25. supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(data_dir, idx), "rb")))
  26. query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(data_dir, idx), "rb")))
  27. query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(data_dir, idx), "rb")))
  28. total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  29. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  30. trainset = total_dataset
  31. test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
  32. supp_xs_s = []
  33. supp_ys_s = []
  34. query_xs_s = []
  35. query_ys_s = []
  36. for idx in range(test_set_size):
  37. supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
  38. supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
  39. query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
  40. query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
  41. test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  42. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  43. random.shuffle(test_dataset)
  44. random.shuffle(trainset)
  45. val_size = int(test_set_size * 0.3)
  46. validationset = test_dataset[:val_size]
  47. testset = test_dataset[val_size:]
  48. return trainset, validationset, testset
  49. def train_melu(conf, checkpoint_dir=None, data_dir=None):
  50. embedding_dim = conf['embedding_dim']
  51. fc1_in_dim = conf['embedding_dim'] * 8
  52. fc2_in_dim = conf['first_fc_hidden_dim']
  53. fc2_out_dim = conf['second_fc_hidden_dim']
  54. emb = EmbeddingModule(conf).cuda()
  55. transform = None
  56. if conf['transformer'] == "kronoker":
  57. transform = KroneckerTransform(l2l.nn.KroneckerLinear)
  58. elif conf['transformer'] == "linear":
  59. transform = l2l.optim.ModuleTransform(torch.nn.Linear)
  60. trainer = Trainer(conf)
  61. # define meta algorithm
  62. if conf['meta_algo'] == "maml":
  63. trainer = l2l.algorithms.MAML(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
  64. elif conf['meta_algo'] == 'metasgd':
  65. trainer = l2l.algorithms.MetaSGD(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
  66. elif conf['meta_algo'] == 'gbml':
  67. trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=conf['local_lr'],
  68. adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
  69. trainer.cuda()
  70. all_parameters = list(emb.parameters()) + list(trainer.parameters())
  71. optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])
  72. if checkpoint_dir:
  73. print("in checkpoint - bug happened")
  74. # model_state, optimizer_state = torch.load(
  75. # os.path.join(checkpoint_dir, "checkpoint"))
  76. # net.load_state_dict(model_state)
  77. # optimizer.load_state_dict(optimizer_state)
  78. # loading data
  79. train_dataset, validation_dataset, test_dataset = load_data(data_dir, test_state=conf['test_state'])
  80. batch_size = conf['batch_size']
  81. num_batch = int(len(train_dataset) / batch_size)
  82. a, b, c, d = zip(*train_dataset)
  83. for epoch in range(conf['num_epoch']): # loop over the dataset multiple times
  84. for i in range(num_batch):
  85. optimizer.zero_grad()
  86. meta_train_error = 0.0
  87. # print("EPOCH: ", epoch, " BATCH: ", i)
  88. supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
  89. supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
  90. query_xs = list(c[batch_size * i:batch_size * (i + 1)])
  91. query_ys = list(d[batch_size * i:batch_size * (i + 1)])
  92. batch_sz = len(supp_xs)
  93. # iterate over all tasks
  94. for task in range(batch_sz):
  95. sxs = supp_xs[task].cuda()
  96. qxs = query_xs[task].cuda()
  97. sys = supp_ys[task].cuda()
  98. qys = query_ys[task].cuda()
  99. learner = trainer.clone()
  100. temp_sxs = emb(sxs)
  101. temp_qxs = emb(qxs)
  102. evaluation_error = fast_adapt(learner,
  103. temp_sxs,
  104. temp_qxs,
  105. sys,
  106. qys,
  107. conf['inner'])
  108. evaluation_error.backward()
  109. meta_train_error += evaluation_error.item()
  110. del (sxs, qxs, sys, qys)
  111. supp_xs[task].cpu()
  112. query_xs[task].cpu()
  113. supp_ys[task].cpu()
  114. query_ys[task].cpu()
  115. # Average the accumulated gradients and optimize (After each batch we will update params)
  116. for p in all_parameters:
  117. p.grad.data.mul_(1.0 / batch_sz)
  118. optimizer.step()
  119. del (supp_xs, supp_ys, query_xs, query_ys)
  120. gc.collect()
  121. # test results on the validation data
  122. val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, trainer, validation_dataset, adaptation_step=conf['inner'])
  123. # with tune.checkpoint_dir(epoch) as checkpoint_dir:
  124. # path = os.path.join(checkpoint_dir, "checkpoint")
  125. # torch.save((net.state_dict(), optimizer.state_dict()), path)
  126. tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)
  127. print("Finished Training")