Extend the MeLU code to run different meta-learning algorithms and hyperparameters.
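The script below (learnToLearn.py) selects the meta algorithm with the --meta_algo flag (maml, metasgd, or gbml) and the gradient transform with --transformer; both flags come from its argparse definitions. As an illustration only, a hypothetical sweep over the three algorithms could launch the script like this (using subprocess is just one way to do it):

import subprocess

# Hypothetical sweep over the meta algorithms exposed by learnToLearn.py.
# The flag names mirror the argparse definitions in the script below.
for algo in ["maml", "metasgd", "gbml"]:
    subprocess.run(["python", "learnToLearn.py",
                    "--meta_algo", algo,
                    "--transformer", "kronoker",
                    "--gpu", "0"])

Hyperparameters such as the inner-loop learning rate (local_lr), the number of adaptation steps (inner), and the outer learning rate (lr) are read from options.config inside the script.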

learnToLearn.py 12KB

import os
import torch
import pickle
from options import config
from data_generation import generate
from embedding_module import EmbeddingModule
import learn2learn as l2l
import random
from learnToLearnTest import test
from fast_adapt import fast_adapt
import gc
from learn2learn.optim.transforms import KroneckerTransform
import argparse
from clustering import ClustringModule, Trainer
import numpy as np
from torch.nn import functional as F
def parse_args():
    print("==============")
    parser = argparse.ArgumentParser([], description='Fast Context Adaptation via Meta-Learning (CAVIA), '
                                                      'Classification experiments.')
    print("==============\n")
    parser.add_argument('--seed', type=int, default=53)
    # parser.add_argument('--task', type=str, default='multi', help='problem setting: sine or celeba')
    # parser.add_argument('--tasks_per_metaupdate', type=int, default=32,
    #                     help='number of tasks in each batch per meta-update')
    #
    # parser.add_argument('--lr_inner', type=float, default=5e-6, help='inner-loop learning rate (per task)')
    # parser.add_argument('--lr_meta', type=float, default=5e-5,
    #                     help='outer-loop learning rate (used with Adam optimiser)')
    # parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate')
    #
    # parser.add_argument('--inner', type=int, default=1,
    #                     help='number of gradient steps in inner loop (during training)')
    # parser.add_argument('--inner_eval', type=int, default=1,
    #                     help='number of gradient updates at test time (for evaluation)')
    parser.add_argument('--first_order', action='store_true', default=False,
                        help='run first-order approximation of CAVIA')
    parser.add_argument('--adapt_transform', action='store_true', default=False,
                        help='run adaptation transform')
    parser.add_argument('--transformer', type=str, default="kronoker",
                        help='transform type: "kronoker" or "linear"')
    parser.add_argument('--meta_algo', type=str, default="metasgd",
                        help='meta algorithm: maml, metasgd, or gbml')
    parser.add_argument('--gpu', type=int, default=0,
                        help='index of the GPU to run the code on')
    parser.add_argument('--epochs', type=int, default=config['num_epoch'],
                        help='number of training epochs')
    # parser.add_argument('--data_root', type=str, default="./movielens/ml-1m", help='path to data root')
    # parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use')
    # parser.add_argument('--test', action='store_true', default=False, help='num of workers to use')
    # parser.add_argument('--embedding_dim', type=int, default=32, help='num of workers to use')
    # parser.add_argument('--first_fc_hidden_dim', type=int, default=64, help='num of workers to use')
    # parser.add_argument('--second_fc_hidden_dim', type=int, default=64, help='num of workers to use')
    # parser.add_argument('--num_epoch', type=int, default=30, help='num of workers to use')
    # parser.add_argument('--num_genre', type=int, default=25, help='num of workers to use')
    # parser.add_argument('--num_director', type=int, default=2186, help='num of workers to use')
    # parser.add_argument('--num_actor', type=int, default=8030, help='num of workers to use')
    # parser.add_argument('--num_rate', type=int, default=6, help='num of workers to use')
    # parser.add_argument('--num_gender', type=int, default=2, help='num of workers to use')
    # parser.add_argument('--num_age', type=int, default=7, help='num of workers to use')
    # parser.add_argument('--num_occupation', type=int, default=21, help='num of workers to use')
    # parser.add_argument('--num_zipcode', type=int, default=3402, help='num of workers to use')
    # parser.add_argument('--rerun', action='store_true', default=False,
    #                     help='Re-run experiment (will override previously saved results)')
    args = parser.parse_args()

    # use the GPU if available
    # args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # print('Running on device: {}'.format(args.device))
    return args
if __name__ == '__main__':
    args = parse_args()
    print(args)

    if config['use_cuda']:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    master_path = "/media/external_10TB/10TB/maheri/new_data_dir3"
    config['master_path'] = master_path

    # DATA GENERATION
    print("DATA GENERATION PHASE")
    if not os.path.exists("{}/".format(master_path)):
        os.mkdir("{}/".format(master_path))
        # preparing dataset. It needs about 22GB of your hard disk space.
        generate(master_path)

    # TRAINING
    print("TRAINING PHASE")
    embedding_dim = config['embedding_dim']
    fc1_in_dim = config['embedding_dim'] * 8
    fc2_in_dim = config['first_fc_hidden_dim']
    fc2_out_dim = config['second_fc_hidden_dim']
    use_cuda = config['use_cuda']

    fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
    fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
    linear_out = torch.nn.Linear(fc2_out_dim, 1)
    head = torch.nn.Sequential(fc1, fc2, linear_out)

    if use_cuda:
        emb = EmbeddingModule(config).cuda()
    else:
        emb = EmbeddingModule(config)
    # META LEARNING
    print("META LEARNING PHASE")

    # define gradient transform (used by GBML)
    transform = None
    if args.transformer == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif args.transformer == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(config)

    # define meta algorithm
    if args.meta_algo == "maml":
        # the --lr_inner flag is commented out above, so the inner-loop
        # learning rate is taken from config, as for MetaSGD and GBML
        trainer = l2l.algorithms.MAML(trainer, lr=config['local_lr'], first_order=args.first_order)
    elif args.meta_algo == 'metasgd':
        trainer = l2l.algorithms.MetaSGD(trainer, lr=config['local_lr'], first_order=args.first_order)
    elif args.meta_algo == 'gbml':
        trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=config['local_lr'],
                                      adapt_transform=args.adapt_transform,
                                      first_order=args.first_order)
    if use_cuda:
        trainer.cuda()

    # Setup optimization
    print("SETUP OPTIMIZATION PHASE")
    all_parameters = list(emb.parameters()) + list(trainer.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
    # loss = torch.nn.MSELoss(reduction='mean')
    # Load training dataset.
    print("LOAD DATASET PHASE")
    training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    batch_size = config['batch_size']
    # torch.cuda.empty_cache()

    print("\n\n\n")
    for iteration in range(args.epochs):
        num_batch = int(training_set_size / batch_size)
        indexes = list(np.arange(training_set_size))
        random.shuffle(indexes)

        for i in range(num_batch):
            meta_train_error = 0.0
            optimizer.zero_grad()
            print("EPOCH: ", iteration, " BATCH: ", i)

            # load one batch of tasks (support/query splits) from disk
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(batch_size * i, batch_size * (i + 1)):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            if use_cuda:
                for j in range(batch_sz):
                    supp_xs[j] = supp_xs[j].cuda()
                    supp_ys[j] = supp_ys[j].cuda()
                    query_xs[j] = query_xs[j].cuda()
                    query_ys[j] = query_ys[j].cuda()

            for task in range(batch_sz):
                # Compute meta-training loss
                # sxs = supp_xs[task].cuda()
                # qxs = query_xs[task].cuda()
                # sys = supp_ys[task].cuda()
                # qys = query_ys[task].cuda()
                learner = trainer.clone()
                temp_sxs = emb(supp_xs[task])
                temp_qxs = emb(query_xs[task])
                evaluation_error = fast_adapt(learner,
                                              temp_sxs,
                                              temp_qxs,
                                              supp_ys[task],
                                              query_ys[task],
                                              config['inner'])
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()
                # supp_xs[task].cpu()
                # query_xs[task].cpu()
                # supp_ys[task].cpu()
                # query_ys[task].cpu()

            # Print some metrics
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / batch_sz)

            # Average the accumulated gradients and optimize
            for p in all_parameters:
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

            # torch.cuda.empty_cache()
            del supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs
            gc.collect()

        print("===============================================\n")
        if iteration % 2 == 0 or iteration > 0:  # note: as written, this condition is True for every epoch
            # testing
            print("start of test phase")
            trainer.eval()

            with open("results2.txt", "a") as f:
                f.write("epoch:{}\n".format(iteration))

            for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
                test_dataset = None
                test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
                supp_xs_s = []
                supp_ys_s = []
                query_xs_s = []
                query_ys_s = []
                gc.collect()

                print("===================== " + test_state + " =====================")
                mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
                                       num_epoch=config['num_epoch'], test_state=test_state, args=args)
                with open("results2.txt", "a") as f:
                    f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
                print("===================================================")
                del test_dataset
                gc.collect()

            trainer.train()
            with open("results2.txt", "a") as f:
                f.write("\n")
            print("\n\n\n")
    # save model
    # final_model = torch.nn.Sequential(emb, head)
    # torch.save(final_model.state_dict(), master_path + "/models_gbml.pkl")

    # testing
    # print("start of test phase")
    # for test_state in ['warm_state', 'user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
    #     test_dataset = None
    #     test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
    #     supp_xs_s = []
    #     supp_ys_s = []
    #     query_xs_s = []
    #     query_ys_s = []
    #     for idx in range(test_set_size):
    #         supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, idx), "rb")))
    #         supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, idx), "rb")))
    #         query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, idx), "rb")))
    #         query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, idx), "rb")))
    #     test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    #     del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    #
    #     print("===================== " + test_state + " =====================")
    #     test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=args.epochs,
    #          adaptation_step=args.inner_eval)
    #     print("===================================================\n\n\n")

    # print(args)
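learnToLearn.py imports fast_adapt from fast_adapt.py, which is not included in this listing. For orientation, here is a minimal sketch of what such an inner-loop step typically looks like with learn2learn, assuming an MSE regression loss and a learner that maps embedded features to ratings; the signature mirrors the call in the training loop above, but the body is an assumption, not the repository's actual implementation:

from torch.nn import functional as F

def fast_adapt_sketch(learner, support_x, query_x, support_y, query_y, adaptation_steps):
    # Inner loop: adapt the cloned meta-learner on the support set.
    for _ in range(adaptation_steps):
        support_loss = F.mse_loss(learner(support_x), support_y)
        learner.adapt(support_loss)  # MAML/MetaSGD/GBML clones expose adapt()
    # Outer objective: loss of the adapted learner on the query set;
    # calling backward() on this value in the caller accumulates meta-gradients.
    return F.mse_loss(learner(query_x), query_y)

Because the training loop calls evaluation_error.backward() on the returned query loss, the meta-gradients flow back into both the embedding module (emb) and the meta-learner (trainer) before optimizer.step() is applied.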