Extend the MeLU code to support different meta-learning algorithms and hyperparameters.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hyper_tunning.py 6.4KB

(Line-number gutter of the original listing: lines 1–157.)
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. from ray import tune
  5. import pickle
  6. from options import config
  7. from embedding_module import EmbeddingModule
  8. import learn2learn as l2l
  9. import random
  10. from fast_adapt import fast_adapt
  11. import gc
  12. from learn2learn.optim.transforms import KroneckerTransform
  13. from hyper_testing import hyper_test
  14. # Define paths (for data)
  15. master_path= "/media/external_10TB/10TB/maheri/melu_data5"
  16. def load_data(data_dir=master_path,test_state='warm_state'):
  17. training_set_size = int(len(os.listdir("{}/warm_state".format(data_dir))) / 4)
  18. supp_xs_s = []
  19. supp_ys_s = []
  20. query_xs_s = []
  21. query_ys_s = []
  22. for idx in range(training_set_size):
  23. supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(data_dir, idx), "rb")))
  24. supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(data_dir, idx), "rb")))
  25. query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(data_dir, idx), "rb")))
  26. query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(data_dir, idx), "rb")))
  27. total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  28. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  29. trainset = total_dataset
  30. test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
  31. supp_xs_s = []
  32. supp_ys_s = []
  33. query_xs_s = []
  34. query_ys_s = []
  35. for idx in range(test_set_size):
  36. supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
  37. supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
  38. query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
  39. query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
  40. test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  41. del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
  42. random.shuffle(test_dataset)
  43. val_size = int(test_set_size * 0.2)
  44. validationset = test_dataset[:val_size]
  45. testset = test_dataset[val_size:]
  46. return trainset, validationset,testset
  47. def train_melu(conf, checkpoint_dir=None, data_dir=None):
  48. embedding_dim = config['embedding_dim']
  49. print("inajm1:",checkpoint_dir)
  50. fc1_in_dim = config['embedding_dim'] * 8
  51. fc2_in_dim = config['first_fc_hidden_dim']
  52. fc2_out_dim = config['second_fc_hidden_dim']
  53. fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
  54. fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
  55. linear_out = torch.nn.Linear(fc2_out_dim, 1)
  56. head = torch.nn.Sequential(fc1, fc2, linear_out)
  57. emb = EmbeddingModule(config).cuda()
  58. transform = None
  59. if conf['transformer'] == "kronoker":
  60. transform = KroneckerTransform(l2l.nn.KroneckerLinear)
  61. elif conf['transformer'] == "linear":
  62. transform = l2l.optim.ModuleTransform(torch.nn.Linear)
  63. # define meta algorithm
  64. if conf['meta_algo'] == "maml":
  65. head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
  66. elif conf['meta_algo'] == 'metasgd':
  67. head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
  68. elif conf['meta_algo'] == 'gbml':
  69. head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
  70. adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
  71. head.cuda()
  72. net = nn.Sequential(emb,head)
  73. criterion = nn.MSELoss()
  74. all_parameters = list(emb.parameters()) + list(head.parameters())
  75. optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])
  76. if checkpoint_dir:
  77. model_state, optimizer_state = torch.load(
  78. os.path.join(checkpoint_dir, "checkpoint"))
  79. net.load_state_dict(model_state)
  80. optimizer.load_state_dict(optimizer_state)
  81. # loading data
  82. train_dataset,validation_dataset,test_dataset = load_data(data_dir,test_state=conf['test_state'])
  83. print(conf['test_state'])
  84. batch_size = conf['batch_size']
  85. num_batch = int(len(train_dataset) / batch_size)
  86. a, b, c, d = zip(*train_dataset)
  87. for epoch in range(config['num_epoch']): # loop over the dataset multiple times
  88. for i in range(num_batch):
  89. optimizer.zero_grad()
  90. meta_train_error = 0.0
  91. # print("EPOCH: ", epoch, " BATCH: ", i)
  92. supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
  93. supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
  94. query_xs = list(c[batch_size * i:batch_size * (i + 1)])
  95. query_ys = list(d[batch_size * i:batch_size * (i + 1)])
  96. batch_sz = len(supp_xs)
  97. # iterate over all tasks
  98. for task in range(batch_sz):
  99. sxs = supp_xs[task].cuda()
  100. qxs = query_xs[task].cuda()
  101. sys = supp_ys[task].cuda()
  102. qys = query_ys[task].cuda()
  103. learner = head.clone()
  104. temp_sxs = emb(sxs)
  105. temp_qxs = emb(qxs)
  106. evaluation_error = fast_adapt(learner,
  107. temp_sxs,
  108. temp_qxs,
  109. sys,
  110. qys,
  111. conf['inner'])
  112. evaluation_error.backward()
  113. meta_train_error += evaluation_error.item()
  114. del(sxs,qxs,sys,qys)
  115. supp_xs[task].cpu()
  116. query_xs[task].cpu()
  117. supp_ys[task].cpu()
  118. query_ys[task].cpu()
  119. # Average the accumulated gradients and optimize (After each batch we will update params)
  120. for p in all_parameters:
  121. p.grad.data.mul_(1.0 / batch_sz)
  122. optimizer.step()
  123. del (supp_xs, supp_ys, query_xs, query_ys)
  124. gc.collect()
  125. # test results on the validation data
  126. val_loss,val_ndcg1,val_ndcg3 = hyper_test(emb,head,validation_dataset,adaptation_step=conf['inner'])
  127. with tune.checkpoint_dir(epoch) as checkpoint_dir:
  128. path = os.path.join(checkpoint_dir, "checkpoint")
  129. torch.save((net.state_dict(), optimizer.state_dict()), path)
  130. tune.report(loss=val_loss, ndcg1=val_ndcg1,ndcg3=val_ndcg3)
  131. print("Finished Training")