make other meta-learning algorithms implemented in l2l.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

learnToLearn.py 6.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  import gc
  import os
  import pickle
  import random

  import learn2learn as l2l
  import numpy as np
  import torch

  from MeLU import MeLU
  from options import config
  from model_training import training
  from data_generation import generate
  from evidence_candidate import selection
  # NOTE: `test` is imported twice; the later import (learnToLearnTest) is the
  # one bound to the name `test` for the rest of the script.
  from model_test import test
  from embedding_module import EmbeddingModule
  from embeddings import item, user
  from learnToLearnTest import test
  from fast_adapt import fast_adapt
  # Environment setup and dataset generation.
  # NOTE(review): nesting below was reconstructed from a listing whose
  # indentation was lost; confirm against the original file.

  # When CUDA is enabled, expose only device 0 (ordered by PCI bus id).
  if config['use_cuda']:
      os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
      os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  # Root directory holding the generated dataset and, later, the saved model.
  master_path = "/media/external_10TB/10TB/maheri/melu_data5"

  # DATA GENERATION
  print("DATA GENERATION PHASE")
  if not os.path.exists("{}/".format(master_path)):
      os.mkdir("{}/".format(master_path))
      # preparing dataset. It needs about 22GB of your hard disk space.
      # NOTE(review): generation is assumed to run only when the directory is
      # first created (so an existing dataset is reused) — confirm.
      generate(master_path)
  # TRAINING
  # Build the embedding module and the prediction head, then wrap the head in
  # learn2learn's GBML meta-learner.
  print("TRAINING PHASE")
  embedding_dim = config['embedding_dim']
  fc1_in_dim = config['embedding_dim'] * 8
  fc2_in_dim = config['first_fc_hidden_dim']
  fc2_out_dim = config['second_fc_hidden_dim']
  use_cuda = config['use_cuda']

  # Prediction head: three linear layers ending in a single output unit.
  # NOTE(review): there is no activation between the layers, so the stack is
  # an affine map overall — confirm this is intended.
  fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
  fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
  linear_out = torch.nn.Linear(fc2_out_dim, 1)
  head = torch.nn.Sequential(fc1, fc2, linear_out)

  # The embedding module is trained by the outer optimizer only; it is not
  # wrapped by the meta-learner below.
  emb = EmbeddingModule(config)
  if use_cuda:
      emb = emb.cuda()

  # META LEARNING
  print("META LEARNING PHASE")
  # Alternative meta-learner kept for reference:
  # head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=True)
  transform = l2l.optim.ModuleTransform(torch.nn.Linear)
  head = l2l.algorithms.GBML(
      head,
      transform=transform,
      lr=config['local_lr'],
      adapt_transform=True,
      first_order=False,
  )
  if use_cuda:
      head.cuda()
  # Setup optimization
  # A single Adam optimizer updates both the embedding module and the
  # meta-learner wrapped head.
  print("SETUP OPTIMIZATION PHASE")
  all_parameters = list(emb.parameters()) + list(head.parameters())
  optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])

  # Load training dataset.
  print("LOAD DATASET PHASE")
  # Each task is stored as four pickle files (supp_x, supp_y, query_x,
  # query_y), hence the division by 4.
  training_set_size = len(os.listdir("{}/warm_state".format(master_path))) // 4
  supp_xs_s = []
  supp_ys_s = []
  query_xs_s = []
  query_ys_s = []
  # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
  # point master_path at data produced by this project's own generator.
  for idx in range(training_set_size):
      for prefix, bucket in (
          ("supp_x", supp_xs_s),
          ("supp_y", supp_ys_s),
          ("query_x", query_xs_s),
          ("query_y", query_ys_s),
      ):
          # `with` closes each file promptly instead of leaking the handle.
          with open("{}/warm_state/{}_{}.pkl".format(master_path, prefix, idx), "rb") as fh:
              bucket.append(pickle.load(fh))
  total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
  del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)

  training_set_size = len(total_dataset)
  batch_size = config['batch_size']
  random.shuffle(total_dataset)
  # Integer division: a trailing partial batch is dropped.
  num_batch = training_set_size // batch_size
  # Transpose into four parallel tuples: support x/y and query x/y per task.
  a, b, c, d = zip(*total_dataset)
  print("\n\n\n")
  # Meta-training: for every batch of tasks, adapt a clone of the head per
  # task, back-propagate the returned error, average the accumulated gradients
  # over the batch and take one optimizer step.
  # NOTE(review): nesting was reconstructed from a listing whose indentation
  # was lost; in particular the separator print is assumed to run once per
  # batch — confirm against the original file.
  for iteration in range(config['num_epoch']):
      for i in range(num_batch):
          optimizer.zero_grad()
          meta_train_error = 0.0
          print("EPOCH: ", iteration, " BATCH: ", i)

          batch = slice(batch_size * i, batch_size * (i + 1))
          supp_xs = list(a[batch])
          supp_ys = list(b[batch])
          query_xs = list(c[batch])
          query_ys = list(d[batch])
          batch_sz = len(supp_xs)

          if use_cuda:
              # Iterate over the actual batch length so this loop agrees with
              # the task loop and the averaging below.
              for j in range(batch_sz):
                  supp_xs[j] = supp_xs[j].cuda()
                  supp_ys[j] = supp_ys[j].cuda()
                  query_xs[j] = query_xs[j].cuda()
                  query_ys[j] = query_ys[j].cuda()

          for task in range(batch_sz):
              # Compute meta-training loss
              learner = head.clone()
              temp_sxs = emb(supp_xs[task])
              temp_qxs = emb(query_xs[task])
              # presumably the post-adaptation error on the query set — see
              # fast_adapt for the exact definition.
              evaluation_error = fast_adapt(
                  learner,
                  temp_sxs,
                  temp_qxs,
                  supp_ys[task],
                  query_ys[task],
                  config['inner'],
              )
              # Gradients accumulate in the shared parameters across tasks.
              evaluation_error.backward()
              meta_train_error += evaluation_error.item()

          # Print some metrics
          print('Iteration', iteration)
          print('Meta Train Error', meta_train_error / batch_sz)

          # Average the accumulated gradients and optimize
          for p in all_parameters:
              # A parameter that did not take part in any backward pass has
              # no gradient; skip it instead of failing on None.
              if p.grad is not None:
                  p.grad.data.mul_(1.0 / batch_sz)
          optimizer.step()

          # Drop the batch tensors before loading the next batch.
          del (supp_xs, supp_ys, query_xs, query_ys)
          gc.collect()
          print("===============================================\n")
  # save model
  # The embedding module and the meta-learned head are saved together as one
  # state dict.
  final_model = torch.nn.Sequential(emb, head)
  torch.save(final_model.state_dict(), master_path + "/models_gbml.pkl")

  # testing
  # NOTE(review): nesting was reconstructed from a listing whose indentation
  # was lost; evaluation is assumed to run once per test state — confirm.
  print("start of test phase")
  for test_state in ['warm_state', 'user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
      # Each task is stored as four pickle files, hence the division by 4.
      test_set_size = len(os.listdir("{}/{}".format(master_path, test_state))) // 4
      supp_xs_s = []
      supp_ys_s = []
      query_xs_s = []
      query_ys_s = []
      # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
      # read data produced by this project's own generator.
      for idx in range(test_set_size):
          for prefix, bucket in (
              ("supp_x", supp_xs_s),
              ("supp_y", supp_ys_s),
              ("query_x", query_xs_s),
              ("query_y", query_ys_s),
          ):
              # `with` closes each file promptly instead of leaking the handle.
              with open("{}/{}/{}_{}.pkl".format(master_path, test_state, prefix, idx), "rb") as fh:
                  bucket.append(pickle.load(fh))
      test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
      del (supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)

      print("===================== " + test_state + " =====================")
      test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'])
      print("===================================================\n\n\n")