Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
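hyper_tunning.py below defines train_melu as a Ray Tune trainable, but the driver that actually launches the search is not included here. The sketch that follows shows one way such a driver could look: the search-space keys mirror the conf[...] lookups inside train_melu, while the value ranges, the data_dir path, the ASHA scheduler, and the resource settings are illustrative assumptions rather than values taken from the original project.

# Hypothetical driver for hyper_tunning.py (a sketch under the assumptions above,
# not part of the original code). Keys mirror the conf[...] accesses in train_melu.
from functools import partial

from ray import tune
from ray.tune.schedulers import ASHAScheduler

from hyper_tunning import train_melu

search_space = {
    # meta-learning algorithm and its inner loop
    "meta_algo": tune.grid_search(["maml", "metasgd", "gbml"]),
    "transformer": tune.choice(["kronoker", "linear"]),  # only consumed by GBML
    "first_order": tune.choice([True, False]),
    "adapt_transform": tune.choice([True, False]),
    "inner": tune.choice([1, 3, 5]),                      # number of adaptation steps
    # learning rates
    "lr": tune.loguniform(1e-4, 1e-2),                    # outer (meta) learning rate
    "local_lr": tune.loguniform(1e-3, 1e-1),              # inner (task-level) learning rate
    # architecture (illustrative ranges)
    "embedding_dim": tune.choice([16, 32, 64]),
    "first_fc_hidden_dim": tune.choice([32, 64, 128]),
    "second_fc_hidden_dim": tune.choice([32, 64, 128]),
    # training / evaluation
    "batch_size": tune.choice([16, 32]),
    "test_state": "user_cold_state",                      # which cold-start split to validate on
}

if __name__ == "__main__":
    analysis = tune.run(
        partial(train_melu, data_dir="/path/to/melu_data"),  # placeholder data path
        config=search_space,
        metric="loss",
        mode="min",
        scheduler=ASHAScheduler(),                        # early-stops weak trials
        resources_per_trial={"cpu": 4, "gpu": 1},
        num_samples=4,                                    # samples per grid point
    )
    print("Best config:", analysis.get_best_config(metric="loss", mode="min"))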

hyper_tunning.py 6.5KB

import os
import torch
import torch.nn as nn
from ray import tune
import pickle
from options import config
from embedding_module import EmbeddingModule
import learn2learn as l2l
import random
from fast_adapt import fast_adapt
import gc
from learn2learn.optim.transforms import KroneckerTransform
from hyper_testing import hyper_test
from clustering import Trainer

# Define paths (for data)
# master_path = "/media/external_10TB/10TB/maheri/melu_data5"


def load_data(data_dir=None, test_state='warm_state'):
    # Each task is stored as four pickles (supp_x, supp_y, query_x, query_y),
    # so the number of tasks is the number of files divided by four.
    training_set_size = int(len(os.listdir("{}/warm_state".format(data_dir))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(training_set_size):
        supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(data_dir, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(data_dir, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(data_dir, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(data_dir, idx), "rb")))
    total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    trainset = total_dataset

    test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(test_set_size):
        supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)

    random.shuffle(test_dataset)
    random.shuffle(trainset)
    # Hold out 20% of the chosen test state as a validation split for hyperparameter tuning.
    val_size = int(test_set_size * 0.2)
    validationset = test_dataset[:val_size]
    testset = test_dataset[val_size:]
    return trainset, validationset, testset
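# Example (hypothetical) call, assuming the MeLU preprocessing has produced the usual
# state folders (warm_state, user_cold_state, item_cold_state, user_and_item_cold_state),
# each containing supp_x_*.pkl, supp_y_*.pkl, query_x_*.pkl and query_y_*.pkl files:
#   trainset, validationset, testset = load_data("/path/to/melu_data", test_state="user_cold_state")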
def train_melu(conf, checkpoint_dir=None, data_dir=None):
    print("checkpoint_dir:", checkpoint_dir)
    embedding_dim = conf['embedding_dim']
    fc1_in_dim = conf['embedding_dim'] * 8
    fc2_in_dim = conf['first_fc_hidden_dim']
    fc2_out_dim = conf['second_fc_hidden_dim']

    # fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
    # fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
    # linear_out = torch.nn.Linear(fc2_out_dim, 1)
    # head = torch.nn.Sequential(fc1, fc2, linear_out)

    emb = EmbeddingModule(config).cuda()

    # Optional parameter transform, only used by the GBML wrapper below.
    transform = None
    if conf['transformer'] == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(config)

    # Wrap the recommendation head in the selected meta-learning algorithm.
    if conf['meta_algo'] == "maml":
        trainer = l2l.algorithms.MAML(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        trainer = l2l.algorithms.MetaSGD(trainer, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=conf['local_lr'],
                                      adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
    trainer.cuda()

    # net = nn.Sequential(emb, head)
    criterion = nn.MSELoss()  # unused here; the loss is computed inside fast_adapt
    all_parameters = list(emb.parameters()) + list(trainer.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])

    if checkpoint_dir:
        print("in checkpoint - bug happened")
        # model_state, optimizer_state = torch.load(
        #     os.path.join(checkpoint_dir, "checkpoint"))
        # net.load_state_dict(model_state)
        # optimizer.load_state_dict(optimizer_state)

    # loading data
    train_dataset, validation_dataset, test_dataset = load_data(data_dir, test_state=conf['test_state'])
    batch_size = conf['batch_size']
    num_batch = int(len(train_dataset) / batch_size)
    a, b, c, d = zip(*train_dataset)

    for epoch in range(config['num_epoch']):  # loop over the dataset multiple times
        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0
            # print("EPOCH: ", epoch, " BATCH: ", i)
            supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
            supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
            query_xs = list(c[batch_size * i:batch_size * (i + 1)])
            query_ys = list(d[batch_size * i:batch_size * (i + 1)])
            batch_sz = len(supp_xs)

            # iterate over all tasks in the current meta-batch
            for task in range(batch_sz):
                sxs = supp_xs[task].cuda()
                qxs = query_xs[task].cuda()
                sys = supp_ys[task].cuda()
                qys = query_ys[task].cuda()

                # clone the meta-learner, adapt it on the support set and
                # accumulate the query-set gradients into the shared parameters
                learner = trainer.clone()
                temp_sxs = emb(sxs)
                temp_qxs = emb(qxs)
                evaluation_error = fast_adapt(learner,
                                              temp_sxs,
                                              temp_qxs,
                                              sys,
                                              qys,
                                              conf['inner'])
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()

                del (sxs, qxs, sys, qys)
                supp_xs[task].cpu()
                query_xs[task].cpu()
                supp_ys[task].cpu()
                query_ys[task].cpu()

            # Average the accumulated gradients and optimize (after each batch we update the params);
            # parameters that received no gradient in this batch are skipped.
            for p in all_parameters:
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

            del (supp_xs, supp_ys, query_xs, query_ys)
            gc.collect()

        # test results on the validation data
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, trainer, validation_dataset,
                                                    adaptation_step=conf['inner'])

        # with tune.checkpoint_dir(epoch) as checkpoint_dir:
        #     path = os.path.join(checkpoint_dir, "checkpoint")
        #     torch.save((net.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)

    print("Finished Training")
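train_melu relies on fast_adapt, imported from fast_adapt.py, which is not shown above. For completeness, here is a minimal sketch of what such a routine typically looks like with learn2learn, matching the call signature used in the batch loop (learner, support embeddings, query embeddings, support ratings, query ratings, adaptation steps); the body is an assumption based on the standard learn2learn adaptation pattern, not the project's actual implementation, and tensor shapes/dtypes may need adjusting.

# Hypothetical stand-in for fast_adapt (the real version lives in fast_adapt.py).
# Standard learn2learn pattern: adapt the cloned learner on the support set for a few
# gradient steps, then return the loss of the adapted learner on the query set.
import torch.nn.functional as F


def fast_adapt(learner, adaptation_data, evaluation_data,
               adaptation_labels, evaluation_labels, adaptation_steps):
    # Inner loop: take `adaptation_steps` gradient steps on the support (adaptation) set.
    for _ in range(adaptation_steps):
        support_error = F.mse_loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(support_error)

    # Outer objective: error of the adapted learner on the query (evaluation) set.
    predictions = learner(evaluation_data)
    return F.mse_loss(predictions, evaluation_labels)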