Extend the MeLU code to support different meta-learning algorithms and hyperparameter configurations.

hyper_tunning.py (12 KB):

import os
import torch
import torch.nn as nn
from ray import tune
import pickle
from embedding_module import EmbeddingModule
import learn2learn as l2l
import random
from fast_adapt import fast_adapt
import gc
from learn2learn.optim.transforms import KroneckerTransform
from hyper_testing import hyper_test
from clustering import Trainer
from Head import Head
import numpy as np

# Define paths (for data)
# master_path = "/media/external_10TB/10TB/maheri/melu_data5"
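
# load_data: reads the pickled support/query pairs for `test_state` from
# `data_dir`, zips them into (supp_x, supp_y, query_x, query_y) tasks, shuffles
# them, and returns 30% of them as a validation set. The train and test slots
# of the returned tuple are left as None because only validation data is
# needed for hyperparameter tuning.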
def load_data(data_dir=None, test_state='warm_state'):
    # training_set_size = int(len(os.listdir("{}/warm_state".format(data_dir))) / 4)
    # supp_xs_s = []
    # supp_ys_s = []
    # query_xs_s = []
    # query_ys_s = []
    # for idx in range(training_set_size):
    #     supp_xs_s.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(data_dir, idx), "rb")))
    #     supp_ys_s.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(data_dir, idx), "rb")))
    #     query_xs_s.append(pickle.load(open("{}/warm_state/query_x_{}.pkl".format(data_dir, idx), "rb")))
    #     query_ys_s.append(pickle.load(open("{}/warm_state/query_y_{}.pkl".format(data_dir, idx), "rb")))
    # total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    # del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    # trainset = total_dataset

    test_set_size = int(len(os.listdir("{}/{}".format(data_dir, test_state))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(test_set_size):
        supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(data_dir, test_state, idx), "rb")))
        query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(data_dir, test_state, idx), "rb")))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    random.shuffle(test_dataset)
    # random.shuffle(trainset)
    val_size = int(test_set_size * 0.3)
    validationset = test_dataset[:val_size]
    # testset = test_dataset[val_size:]
    return None, validationset, None
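
# data_batching_new: cluster-aware batching. Each training task is assigned to
# one of `num_clusters` clusters by sampling from its (power-sharpened) soft
# cluster distribution in C_distribs; batches are then filled by drawing
# clusters in proportion to their (power-scaled) sizes, falling back to unused
# or already-selected indices when a cluster runs dry. Returns a permuted,
# flattened array of task indices covering num_batch * batch_size slots.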
def data_batching_new(indexes, C_distribs, batch_size, training_set_size, num_clusters, config):
    probs = np.squeeze(C_distribs)
    probs = np.array(probs) ** config['distribution_power'] / np.sum(
        np.array(probs) ** config['distribution_power'], axis=1, keepdims=True)
    cs = [np.random.choice(num_clusters, p=i) for i in probs]
    num_batch = int(training_set_size / batch_size)
    res = [[] for i in range(num_batch)]
    clas = [[] for i in range(num_clusters)]
    clas_temp = [[] for i in range(num_clusters)]
    for idx, c in zip(indexes, cs):
        clas[c].append(idx)
    for i in range(num_clusters):
        random.shuffle(clas[i])
    # t = np.array([len(i) for i in clas])
    t = np.array([len(i) ** config['data_selection_pow'] for i in clas])
    t = t / t.sum()
    dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)]))
    cnt = 0
    for i in range(len(res)):
        for j in range(batch_size):
            temp = np.random.choice(num_clusters, p=t)
            if len(clas[temp]) > 0:
                selected = clas[temp].pop(0)
                res[i].append(selected)
                clas_temp[temp].append(selected)
            else:
                # res[i].append(indexes[training_set_size-1-cnt])
                if len(dif) > 0:
                    if random.random() < 0.5 or len(clas_temp[temp]) == 0:
                        res[i].append(dif.pop(0))
                    else:
                        selected = clas_temp[temp].pop(0)
                        clas_temp[temp].append(selected)
                        res[i].append(selected)
                else:
                    selected = clas_temp[temp].pop(0)
                    res[i].append(selected)
                cnt = cnt + 1
    print("data_batching : ", cnt)
    res = np.random.permutation(res)
    final_result = np.array(res).flatten()
    return final_result
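
# train_melu: Ray Tune trainable. Builds the embedding module, the clustering
# Trainer, and the Head wrapped in the chosen meta-algorithm (MAML, MetaSGD, or
# GBML), then meta-trains on the warm_state tasks. Epoch 0 initialises the
# cluster centroids with k-means over embeddings of the first 2500 tasks and
# uses random batching; later epochs build cluster-aware batches from the
# previous epoch's cluster distributions. Validation metrics are reported to
# Tune after every epoch.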
def train_melu(conf, checkpoint_dir=None, data_dir=None):
    config = conf
    master_path = data_dir
    emb = EmbeddingModule(conf).cuda()

    transform = None
    if conf['transformer'] == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif conf['transformer'] == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(conf)
    trainer.cuda()
    head = Head(config)

    # define meta algorithm
    if conf['meta_algo'] == "maml":
        head = l2l.algorithms.MAML(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'metasgd':
        head = l2l.algorithms.MetaSGD(head, lr=conf['local_lr'], first_order=conf['first_order'])
    elif conf['meta_algo'] == 'gbml':
        head = l2l.algorithms.GBML(head, transform=transform, lr=conf['local_lr'],
                                   adapt_transform=conf['adapt_transform'], first_order=conf['first_order'])
    head.cuda()

    criterion = nn.MSELoss()
    all_parameters = list(emb.parameters()) + list(trainer.parameters()) + list(head.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=conf['lr'])
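
    # Note: emb, trainer, and head are optimised jointly by a single Adam
    # optimizer; the meta-learning wrapper only governs the inner-loop
    # adaptation of head inside fast_adapt.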
    # Load training dataset.
    print("LOAD DATASET PHASE")
    training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []

    if checkpoint_dir:
        print("checkpoint_dir was given, but checkpoint loading is not implemented")
        # model_state, optimizer_state = torch.load(
        #     os.path.join(checkpoint_dir, "checkpoint"))
        # net.load_state_dict(model_state)
        # optimizer.load_state_dict(optimizer_state)

    # loading data
    # _, validation_dataset, _ = load_data(data_dir, test_state=conf['test_state'])
    batch_size = conf['batch_size']
    # num_batch = int(len(train_dataset) / batch_size)
    # a, b, c, d = zip(*train_dataset)

    C_distribs = []
    indexes = list(np.arange(training_set_size))
    all_test_users = []
    for iteration in range(conf['num_epoch']):  # loop over the dataset multiple times
        print("iteration:", iteration)
        num_batch = int(training_set_size / batch_size)

        if iteration == 0:
            print("changing cluster centroids started ...")
            indexes = list(np.arange(training_set_size))
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(0, 2500):
                supp_xs.append(
                    pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(
                    pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            user_embeddings = []
            for task in range(batch_sz):
                # Compute the task (user) embedding from the support set
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                temp_sxs = emb(supp_xs[task])
                y = supp_ys[task].view(-1, 1)
                input_pairs = torch.cat((temp_sxs, y), dim=1)
                _, mean_task, _ = trainer.cluster_module(temp_sxs, y)
                user_embeddings.append(mean_task.detach().cpu().numpy())
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()

            from sklearn.cluster import KMeans
            user_embeddings = np.array(user_embeddings)
            kmeans_model = KMeans(n_clusters=conf['cluster_k'], init="k-means++").fit(user_embeddings)
            trainer.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()

        if iteration > 0:
            indexes = data_batching_new(indexes, C_distribs, batch_size, training_set_size, conf['cluster_k'], conf)
        else:
            random.shuffle(indexes)
        C_distribs = []
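
        # One outer-loop update per batch of tasks: per-task query losses are
        # backpropagated onto the shared parameters, the accumulated gradients
        # are averaged over the batch, and a single Adam step is taken.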
        for i in range(num_batch):
            optimizer.zero_grad()
            meta_train_error = 0.0
            meta_cluster_error = 0.0
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(batch_size * i, batch_size * (i + 1)):
                supp_xs.append(
                    pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(
                    pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            for task in range(batch_sz):
                # Compute meta-training loss
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                query_xs[task] = query_xs[task].cuda()
                query_ys[task] = query_ys[task].cuda()
                learner = head.clone()
                temp_sxs = emb(supp_xs[task])
                temp_qxs = emb(query_xs[task])
                evaluation_error, c, K_LOSS = fast_adapt(learner,
                                                         temp_sxs,
                                                         temp_qxs,
                                                         supp_ys[task],
                                                         query_ys[task],
                                                         conf['inner'],
                                                         trainer=trainer,
                                                         test=False,
                                                         iteration=iteration)
                C_distribs.append(c.detach().cpu().numpy())
                meta_cluster_error += K_LOSS
                evaluation_error.backward(retain_graph=True)
                meta_train_error += evaluation_error.item()
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()
                query_xs[task] = query_xs[task].cpu()
                query_ys[task] = query_ys[task].cpu()
            ################################################
            # Print some metrics
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / batch_sz)
            print('KL Train Error', round(meta_cluster_error / batch_sz, 4), "\t", C_distribs[-1])

            # Average the accumulated gradients and optimize
            for p in all_parameters:
                # if p.grad != None:
                p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

        # test results on the validation data
        val_loss, val_ndcg1, val_ndcg3 = hyper_test(emb, head, trainer, batch_size, master_path, conf['test_state'],
                                                    adaptation_step=conf['inner'], num_epoch=iteration)
        # with tune.checkpoint_dir(epoch) as checkpoint_dir:
        #     path = os.path.join(checkpoint_dir, "checkpoint")
        #     torch.save((net.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=val_loss, ndcg1=val_ndcg1, ndcg3=val_ndcg3)

    print("Finished Training")
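
The listing above does not show how train_melu is launched. Below is a minimal sketch of a Ray Tune driver, assuming the older function-based Tune API that matches the tune.report call in the code. The search-space keys mirror the conf[...] lookups in hyper_tunning.py; the concrete value ranges, the ASHA scheduler, the resource settings, and the data_dir path are illustrative assumptions, and the real EmbeddingModule, Head, and Trainer modules may require additional config keys that are not visible in this file.

# run_hyper_tunning.py -- hypothetical driver, not part of the original file.
# All concrete values below (search ranges, scheduler, resources, data_dir)
# are assumptions for illustration.
from functools import partial

import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler

from hyper_tunning import train_melu

if __name__ == "__main__":
    ray.init()
    search_space = {
        # meta-algorithm and its wrapper options (keys read by train_melu)
        "meta_algo": tune.grid_search(["maml", "metasgd", "gbml"]),
        "transformer": tune.choice(["kronoker", "linear"]),
        "first_order": tune.choice([True, False]),
        "adapt_transform": tune.choice([True, False]),
        # optimisation hyperparameters
        "lr": tune.loguniform(1e-4, 1e-2),          # outer (meta) learning rate
        "local_lr": tune.loguniform(1e-3, 1e-1),    # inner-loop learning rate
        "inner": tune.choice([1, 3, 5]),            # adaptation steps
        "batch_size": tune.choice([16, 32]),
        "num_epoch": 30,
        # clustering / batching hyperparameters
        "cluster_k": tune.choice([3, 5, 7]),
        "distribution_power": tune.choice([1, 2]),
        "data_selection_pow": tune.choice([0.5, 1]),
        "test_state": "warm_state",
    }
    analysis = tune.run(
        partial(train_melu, data_dir="/path/to/melu_data"),
        config=search_space,
        resources_per_trial={"cpu": 4, "gpu": 1},
        num_samples=1,
        scheduler=ASHAScheduler(metric="ndcg3", mode="max"),
    )
    print("Best config:", analysis.get_best_config(metric="ndcg3", mode="max"))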