Extend the MeLU code to run different meta-learning algorithms and hyperparameters.

learnToLearn.py 17KB
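The meta algorithm and the main hyperparameters are selected through command-line flags (see parse_args in the script below). A minimal usage sketch, with example values only; note that --epochs and --lr_meta are parsed but the training loop still reads the corresponding values from config:

    # GBML wrapper with the Kronecker transform (the defaults)
    python learnToLearn.py --meta_algo gbml --transformer kronoker --lr_inner 5e-6 --gpu 0

    # first-order MAML
    python learnToLearn.py --meta_algo maml --first_order --lr_inner 5e-6 --gpu 0

    # MetaSGD
    python learnToLearn.py --meta_algo metasgd --lr_inner 5e-6 --gpu 0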

import os
import torch
import pickle
from options import config
from data_generation import generate
from embedding_module import EmbeddingModule
import learn2learn as l2l
import random
from learnToLearnTest import test
from fast_adapt import fast_adapt
import gc
from learn2learn.optim.transforms import KroneckerTransform
import argparse
from clustering import ClustringModule, Trainer
import numpy as np
from torch.nn import functional as F
def data_batching(indexes, C_distribs, batch_size, training_set_size, num_clusters):
    # Re-order the task indexes so that each mini-batch draws tasks from clusters
    # sampled according to the tasks' soft cluster assignments (C_distribs).
    probs = np.squeeze(C_distribs)
    cs = [np.random.choice(num_clusters, p=i) for i in probs]
    num_batch = int(training_set_size / batch_size)
    res = [[] for i in range(num_batch)]
    clas = [[] for i in range(num_clusters)]
    for idx, c in zip(indexes, cs):
        clas[c].append(idx)
    t = np.array([len(i) for i in clas])
    t = t / t.sum()
    dif = list(set(list(np.arange(training_set_size))) - set(indexes[0:(num_batch * batch_size)]))
    cnt = 0
    for i in range(len(res)):
        for j in range(batch_size):
            temp = np.random.choice(num_clusters, p=t)
            if len(clas[temp]) > 0:
                res[i].append(clas[temp].pop(0))
            else:
                # res[i].append(indexes[training_set_size-1-cnt])
                res[i].append(random.choice(dif))
                cnt = cnt + 1
    res = np.random.permutation(res)
    final_result = np.array(res).flatten()
    return final_result
def parse_args():
    print("==============")
    parser = argparse.ArgumentParser(description='Fast Context Adaptation via Meta-Learning (CAVIA), '
                                                 'classification experiments.')
    print("==============\n")
    parser.add_argument('--seed', type=int, default=53)
    parser.add_argument('--task', type=str, default='multi', help='problem setting: sine or celeba')
    parser.add_argument('--tasks_per_metaupdate', type=int, default=32,
                        help='number of tasks in each batch per meta-update')
    parser.add_argument('--lr_inner', type=float, default=5e-6, help='inner-loop learning rate (per task)')
    parser.add_argument('--lr_meta', type=float, default=5e-5,
                        help='outer-loop learning rate (used with Adam optimiser)')
    # parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate')
    parser.add_argument('--inner', type=int, default=1,
                        help='number of gradient steps in inner loop (during training)')
    parser.add_argument('--inner_eval', type=int, default=1,
                        help='number of gradient updates at test time (for evaluation)')
    parser.add_argument('--first_order', action='store_true', default=False,
                        help='run first-order approximation of CAVIA')
    parser.add_argument('--adapt_transform', action='store_true', default=False,
                        help='run adaptation transform')
    parser.add_argument('--transformer', type=str, default="kronoker",
                        help='transform type for GBML: kronoker or linear')
    parser.add_argument('--meta_algo', type=str, default="gbml",
                        help='meta algorithm: maml, metasgd, or gbml')
    parser.add_argument('--gpu', type=int, default=0,
                        help='index of the GPU to run the code on')
    parser.add_argument('--epochs', type=int, default=config['num_epoch'],
                        help='number of training epochs')
    # parser.add_argument('--data_root', type=str, default="./movielens/ml-1m", help='path to data root')
    # parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use')
    # parser.add_argument('--test', action='store_true', default=False, help='num of workers to use')
    # parser.add_argument('--embedding_dim', type=int, default=32, help='num of workers to use')
    # parser.add_argument('--first_fc_hidden_dim', type=int, default=64, help='num of workers to use')
    # parser.add_argument('--second_fc_hidden_dim', type=int, default=64, help='num of workers to use')
    # parser.add_argument('--num_epoch', type=int, default=30, help='num of workers to use')
    # parser.add_argument('--num_genre', type=int, default=25, help='num of workers to use')
    # parser.add_argument('--num_director', type=int, default=2186, help='num of workers to use')
    # parser.add_argument('--num_actor', type=int, default=8030, help='num of workers to use')
    # parser.add_argument('--num_rate', type=int, default=6, help='num of workers to use')
    # parser.add_argument('--num_gender', type=int, default=2, help='num of workers to use')
    # parser.add_argument('--num_age', type=int, default=7, help='num of workers to use')
    # parser.add_argument('--num_occupation', type=int, default=21, help='num of workers to use')
    # parser.add_argument('--num_zipcode', type=int, default=3402, help='num of workers to use')
    # parser.add_argument('--rerun', action='store_true', default=False,
    #                     help='Re-run experiment (will override previously saved results)')
    args = parser.parse_args()

    # use the GPU if available
    # args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # print('Running on device: {}'.format(args.device))
    return args
def kl_loss(C_distribs):
    # C_distribs: list of per-task soft cluster-assignment distributions
    # stack into a (batch_size x k) matrix
    C_distribs = torch.stack(C_distribs).squeeze()
    # batch_size x k
    C_distribs_sq = torch.pow(C_distribs, 2)
    # 1 x k
    C_distribs_sum = torch.sum(C_distribs, dim=0, keepdim=True)
    # batch_size x k
    temp = C_distribs_sq / C_distribs_sum
    # batch_size x 1
    temp_sum = torch.sum(temp, dim=1, keepdim=True)
    # sharpened target distribution (squared assignments, renormalised per task)
    target_distribs = temp / temp_sum
    # KL divergence between the current soft assignments and the target distribution
    clustering_loss = F.kl_div(C_distribs.log(), target_distribs, reduction='batchmean')
    return clustering_loss
if __name__ == '__main__':
    args = parse_args()
    print(args)

    if config['use_cuda']:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    master_path = "/media/external_10TB/10TB/maheri/define_task_melu_data2"
    config['master_path'] = master_path

    # DATA GENERATION
    print("DATA GENERATION PHASE")
    if not os.path.exists("{}/".format(master_path)):
        os.mkdir("{}/".format(master_path))
        # preparing dataset. It needs about 22GB of your hard disk space.
        generate(master_path)

    # TRAINING
    print("TRAINING PHASE")
    embedding_dim = config['embedding_dim']
    fc1_in_dim = config['embedding_dim'] * 8
    fc2_in_dim = config['first_fc_hidden_dim']
    fc2_out_dim = config['second_fc_hidden_dim']
    use_cuda = config['use_cuda']

    # fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
    # fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
    # linear_out = torch.nn.Linear(fc2_out_dim, 1)
    # head = torch.nn.Sequential(fc1, fc2, linear_out)

    if use_cuda:
        emb = EmbeddingModule(config).cuda()
    else:
        emb = EmbeddingModule(config)
    # META LEARNING
    print("META LEARNING PHASE")

    # define transformer
    transform = None
    if args.transformer == "kronoker":
        transform = KroneckerTransform(l2l.nn.KroneckerLinear)
    elif args.transformer == "linear":
        transform = l2l.optim.ModuleTransform(torch.nn.Linear)

    trainer = Trainer(config)
    tr = trainer

    # define meta algorithm
    if args.meta_algo == "maml":
        trainer = l2l.algorithms.MAML(trainer, lr=args.lr_inner, first_order=args.first_order)
    elif args.meta_algo == 'metasgd':
        trainer = l2l.algorithms.MetaSGD(trainer, lr=args.lr_inner, first_order=args.first_order)
    elif args.meta_algo == 'gbml':
        trainer = l2l.algorithms.GBML(trainer, transform=transform, lr=args.lr_inner,
                                      adapt_transform=args.adapt_transform,
                                      first_order=args.first_order)
    if use_cuda:
        trainer.cuda()

    # Setup optimization
    print("SETUP OPTIMIZATION PHASE")
    all_parameters = list(emb.parameters()) + list(trainer.parameters())
    optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
    # loss = torch.nn.MSELoss(reduction='mean')

    # Load training dataset.
    print("LOAD DATASET PHASE")
    training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []

    batch_size = config['batch_size']
    # torch.cuda.empty_cache()

    print("\n\n\n")
    for iteration in range(config['num_epoch']):
        if iteration == 0:
            # initialise the cluster centroids with k-means over the task (user) embeddings
            print("changing cluster centroids started ...")
            indexes = list(np.arange(training_set_size))
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(0, 2500):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            user_embeddings = []
            for task in range(batch_sz):
                # Compute the task (user) embedding from the support set
                supp_xs[task] = supp_xs[task].cuda()
                supp_ys[task] = supp_ys[task].cuda()
                # query_xs[task] = query_xs[task].cuda()
                # query_ys[task] = query_ys[task].cuda()
                temp_sxs = emb(supp_xs[task])
                # temp_qxs = emb(query_xs[task])
                y = supp_ys[task].view(-1, 1)
                input_pairs = torch.cat((temp_sxs, y), dim=1)
                task_embed = tr.cluster_module.input_to_hidden(input_pairs)
                # todo : may be useless
                mean_task = tr.cluster_module.aggregate(task_embed)
                user_embeddings.append(mean_task.detach().cpu().numpy())
                supp_xs[task] = supp_xs[task].cpu()
                supp_ys[task] = supp_ys[task].cpu()

            from sklearn.cluster import KMeans
            user_embeddings = np.array(user_embeddings)
            kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings)
            tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()
        if iteration > 0:
            # indexes = data_batching(indexes, C_distribs, batch_size, training_set_size, config['cluster_k'])
            # random.shuffle(indexes)
            C_distribs = []
        else:
            num_batch = int(training_set_size / batch_size)
            indexes = list(np.arange(training_set_size))
            random.shuffle(indexes)
        for i in range(num_batch):
            meta_train_error = 0.0
            meta_cluster_error = 0.0
            optimizer.zero_grad()
            print("EPOCH: ", iteration, " BATCH: ", i)
            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
            for idx in range(batch_size * i, batch_size * (i + 1)):
                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_xs.append(
                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
                query_ys.append(
                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
            batch_sz = len(supp_xs)

            if use_cuda:
                for j in range(batch_size):
                    supp_xs[j] = supp_xs[j].cuda()
                    supp_ys[j] = supp_ys[j].cuda()
                    query_xs[j] = query_xs[j].cuda()
                    query_ys[j] = query_ys[j].cuda()

            C_distribs = []
            for task in range(batch_sz):
                # Compute meta-training loss
                # sxs = supp_xs[task].cuda()
                # qxs = query_xs[task].cuda()
                # sys = supp_ys[task].cuda()
                # qys = query_ys[task].cuda()
                learner = trainer.clone()
                temp_sxs = emb(supp_xs[task])
                temp_qxs = emb(query_xs[task])
                evaluation_error, c, k_loss = fast_adapt(learner,
                                                         temp_sxs,
                                                         temp_qxs,
                                                         supp_ys[task],
                                                         query_ys[task],
                                                         config['inner'],
                                                         epoch=iteration)
                # C_distribs.append(c)
                evaluation_error.backward(retain_graph=True)
                meta_train_error += evaluation_error.item()
                meta_cluster_error += k_loss
                # supp_xs[task].cpu()
                # query_xs[task].cpu()
                # supp_ys[task].cpu()
                # query_ys[task].cpu()

            # Print some metrics
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error / batch_sz)
            print('KL Train Error', meta_cluster_error / batch_sz)

            # clustering_loss = config['kl_loss_weight'] * kl_loss(C_distribs)
            # clustering_loss.backward()
            # print("kl_loss:", round(clustering_loss.item(), 8), "\t", C_distribs[0].cpu().detach().numpy())
            # if i != (num_batch - 1):
            #     C_distribs = []

            # Average the accumulated gradients and optimize
            for p in all_parameters:
                p.grad.data.mul_(1.0 / batch_sz)
            optimizer.step()

            # torch.cuda.empty_cache()
            del (supp_xs, supp_ys, query_xs, query_ys, learner, temp_sxs, temp_qxs)
            gc.collect()
            print("===============================================\n")
        # if iteration % 2 == 0 and iteration != 0:
        #     # testing
        #     print("start of test phase")
        #     trainer.eval()
        #
        #     with open("results2.txt", "a") as f:
        #         f.write("epoch:{}\n".format(iteration))
        #
        #     for test_state in ['user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
        #         test_dataset = None
        #         test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
        #         supp_xs_s = []
        #         supp_ys_s = []
        #         query_xs_s = []
        #         query_ys_s = []
        #         gc.collect()
        #
        #         print("===================== " + test_state + " =====================")
        #         mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
        #                                num_epoch=config['num_epoch'], test_state=test_state, args=args)
        #         with open("results2.txt", "a") as f:
        #             f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
        #         print("===================================================")
        #         del (test_dataset)
        #         gc.collect()
        #
        #     trainer.train()
        #     with open("results2.txt", "a") as f:
        #         f.write("\n")
        #     print("\n\n\n")

    # save model
    # final_model = torch.nn.Sequential(emb, head)
    # torch.save(final_model.state_dict(), master_path + "/models_gbml.pkl")
    # testing
    # print("start of test phase")
    # for test_state in ['warm_state', 'user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
    #     test_dataset = None
    #     test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
    #     supp_xs_s = []
    #     supp_ys_s = []
    #     query_xs_s = []
    #     query_ys_s = []
    #     for idx in range(test_set_size):
    #         supp_xs_s.append(pickle.load(open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, idx), "rb")))
    #         supp_ys_s.append(pickle.load(open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, idx), "rb")))
    #         query_xs_s.append(pickle.load(open("{}/{}/query_x_{}.pkl".format(master_path, test_state, idx), "rb")))
    #         query_ys_s.append(pickle.load(open("{}/{}/query_y_{}.pkl".format(master_path, test_state, idx), "rb")))
    #     test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    #     del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
    #
    #     print("===================== " + test_state + " =====================")
    #     test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=args.epochs,
    #          adaptation_step=args.inner_eval)
    #     print("===================================================\n\n\n")
    # print(args)