# cross_validation.py
# (Repository web-view banner and rendered line-number gutter removed as
# extraction artifacts; the code below is the file's actual content.)
  1. import argparse
  2. import os
  3. import logging
  4. import time
  5. import pickle
  6. import torch
  7. import torch.nn as nn
  8. import gc
  9. from datetime import datetime
  10. time_str = str(datetime.now().strftime('%y%m%d%H%M'))
  11. from model.datasets import FastSynergyDataset, FastTensorDataLoader
  12. from model.models import MLP
  13. from drug.datasets import DDInteractionDataset
  14. from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
  15. from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE
  16. from torch_geometric.loader import NeighborLoader
  17. def find_subgraph_mask(ddi_graph, batch):
  18. # print('----------------------------------')
  19. drug1_id, drug2_id, cell_feat, y_true = batch
  20. drug1_id = torch.flatten(drug1_id)
  21. drug2_id = torch.flatten(drug2_id)
  22. drug_ids = torch.cat((drug1_id, drug2_id), 0)
  23. drug_ids, count = drug_ids.unique(return_counts=True)
  24. # Removing negative ids
  25. drug_id_range = (drug_ids >= 0).nonzero(as_tuple=True)[0]
  26. drug_ids = torch.index_select(drug_ids, 0, drug_id_range)
  27. return drug_ids
  28. def eval_model(model, optimizer, loss_func, train_data, test_data,
  29. batch_size, n_epoch, patience, ddi_graph, gpu_id, mdl_dir):
  30. tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1)
  31. train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True)
  32. valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4)
  33. test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4)
  34. train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, ddi_graph,
  35. sl=True, mdl_dir=mdl_dir)
  36. test_loss = eval_epoch(model, test_loader, loss_func, ddi_graph, gpu_id)
  37. test_loss /= len(test_data)
  38. return test_loss
  39. def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True):
  40. drug1_id, drug2_id, cell_feat, y_true = batch
  41. if gpu_id is not None:
  42. drug1_id = drug1_id.cuda(gpu_id)
  43. drug2_id = drug2_id.cuda(gpu_id)
  44. cell_feat = cell_feat.cuda(gpu_id)
  45. y_true = y_true.cuda(gpu_id)
  46. if train:
  47. y_pred = model(drug1_id, drug2_id, cell_feat, subgraph)
  48. else:
  49. yp1 = model(drug1_id, drug2_id, cell_feat, subgraph)
  50. yp2 = model(drug2_id, drug1_id, cell_feat, subgraph)
  51. y_pred = (yp1 + yp2) / 2
  52. loss = loss_func(y_pred, y_true)
  53. return loss
  54. def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
  55. # print('this is in train epoch: ', ddi_graph)
  56. model.train()
  57. epoch_loss = 0
  58. ddi_graph.n_id = torch.arange(ddi_graph.num_nodes)
  59. start = time.time()
  60. # print("len of batch loader: ",len(loader))
  61. for i, batch in enumerate(loader):
  62. # print(f"this is batch {i} ------------------")
  63. # TODO: for batch ... desired subgraph
  64. subgraph_mask = find_subgraph_mask(ddi_graph, batch)
  65. # print("this is subgraph mask: ---------------")
  66. # print(subgraph_mask)
  67. # print("subgraph_mask: ",subgraph_mask) ---> subgraph_mask is OK
  68. # print("This is ddi_graph:")
  69. # print(ddi_graph) ---> ddi_graph is OK
  70. subgraph_batch_size = len(subgraph_mask)
  71. neighbor_loader = NeighborLoader(ddi_graph,
  72. num_neighbors=[3,2], batch_size=subgraph_batch_size,
  73. input_nodes=subgraph_mask,
  74. directed=False, shuffle=True)
  75. # sampled_data = next(iter(neighbor_loader))
  76. # print(sampled_data)
  77. # print("---------------")
  78. # print(type(sampled_data))
  79. # print("Sampled_data: ")
  80. # print(sampled_data.batch_size)
  81. start_inner = time.time()
  82. for subgraph in neighbor_loader:
  83. # print("this is subgraph in cross_validation:")
  84. # print(subgraph)
  85. optimizer.zero_grad()
  86. loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
  87. loss.backward()
  88. optimizer.step()
  89. epoch_loss += loss.item()
  90. print("for on subgraphs time: ", time.time() - start_inner)
  91. # print("epoch_loss: ", epoch_loss)
  92. print("train epoch time: ", time.time() - start)
  93. return epoch_loss
  94. def eval_epoch(model, loader, loss_func, ddi_graph, gpu_id=None):
  95. model.eval()
  96. with torch.no_grad():
  97. epoch_loss = 0
  98. for batch in loader:
  99. subgraph_mask = find_subgraph_mask(ddi_graph, batch)
  100. subgraph_batch_size = len(subgraph_mask)
  101. neighbor_loader = NeighborLoader(ddi_graph,
  102. num_neighbors=[3,2], batch_size=subgraph_batch_size,
  103. input_nodes=subgraph_mask,
  104. directed=False, shuffle=True)
  105. for subgraph in neighbor_loader:
  106. # print("this is subgraph in cross_validation:")
  107. # print(subgraph)
  108. loss = step_batch(model, batch, loss_func, subgraph, gpu_id, train=False)
  109. epoch_loss += loss.item()
  110. return epoch_loss
  111. def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
  112. ddi_graph,
  113. sl=False, mdl_dir=None):
  114. # print('this is in train: ', ddi_graph)
  115. min_loss = float('inf')
  116. angry = 0
  117. for epoch in range(1, n_epoch + 1):
  118. start_time = time.time()
  119. print("epoch: ",str(epoch))
  120. trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
  121. trn_loss /= train_loader.dataset_len
  122. # print("train loss: ", trn_loss)
  123. val_loss = eval_epoch(model, valid_loader, loss_func, ddi_graph, gpu_id)
  124. val_loss /= valid_loader.dataset_len
  125. # print("val loss: ", val_loss)
  126. print("loss epoch " + str(epoch) + ": " + str(trn_loss))
  127. if val_loss < min_loss:
  128. angry = 0
  129. min_loss = val_loss
  130. if sl:
  131. # TODO: save best model in case of need
  132. save_best_model(model.state_dict(), mdl_dir, epoch, keep=1)
  133. pass
  134. else:
  135. angry += 1
  136. if angry >= patience:
  137. break
  138. end_time = time.time()
  139. print("epoch time: ", (end_time - start_time))
  140. if sl:
  141. model.load_state_dict(torch.load(find_best_model(mdl_dir)))
  142. return min_loss
  143. def create_model(data, hidden_size, gpu_id=None):
  144. # get 256
  145. model = MLP(data.cell_feat_len() + 2 * 256, hidden_size, gpu_id)
  146. if gpu_id is not None:
  147. model = model.cuda(gpu_id)
  148. return model
def cv(args, out_dir):
    """Nested 5-fold cross-validation.

    Outer loop: each fold is held out once as the test set. Inner loop:
    grid search over (hidden size, learning rate), scored by mean
    validation loss across the remaining folds. The best setting is then
    retrained on all outer-train folds and evaluated on the test fold.
    Per-fold test losses are pickled to ``out_dir`` and MSE/RMSE summary
    statistics (with a confidence interval) are logged.
    """
    # NOTE(review): memory fraction is pinned to device 0 even though
    # args.gpu may select another device — confirm intended.
    torch.cuda.set_per_process_memory_fraction(0.6, 0)
    # Clear any cached memory
    gc.collect()
    torch.cuda.empty_cache()
    save_args(args, os.path.join(out_dir, 'args.json'))
    test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
    print("cuda available: " + str(torch.cuda.is_available()))
    # Fall back to CPU when CUDA is unavailable or no device was requested.
    if torch.cuda.is_available() and (args.gpu is not None):
        gpu_id = args.gpu
    else:
        gpu_id = None
    # Drug-drug interaction graph shared across all folds.
    ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
    ddi_graph = ddiDataset.get()
    n_folds = 5
    n_delimiter = 60
    # Sum-reduced MSE; callers divide by the dataset length for a mean.
    loss_func = nn.MSELoss(reduction='sum')
    test_losses = []
    for test_fold in range(n_folds):
        outer_trn_folds = [x for x in range(n_folds) if x != test_fold]
        logging.info("Outer: train folds {}, test folds {}".format(outer_trn_folds, test_fold))
        logging.info("-" * n_delimiter)
        # param[i] pairs with losses[i] for the grid-search argmin below.
        param = []
        losses = []
        for hs in args.hidden:
            for lr in args.lr:
                param.append((hs, lr))
                logging.info("Hidden size: {} | Learning rate: {}".format(hs, lr))
                ret_vals = []
                # Inner CV: rotate a validation fold through the outer-train folds.
                for valid_fold in outer_trn_folds:
                    inner_trn_folds = [x for x in outer_trn_folds if x != valid_fold]
                    valid_folds = [valid_fold]
                    train_data = FastSynergyDataset(DRUGNAME_2_DRUGBANKID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
                                                    SYNERGY_FILE, use_folds=inner_trn_folds)
                    valid_data = FastSynergyDataset(DRUGNAME_2_DRUGBANKID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
                                                    SYNERGY_FILE, use_folds=valid_folds, train=False)
                    train_loader = FastTensorDataLoader(*train_data.tensor_samples(), batch_size=args.batch,
                                                        shuffle=True)
                    valid_loader = FastTensorDataLoader(*valid_data.tensor_samples(), batch_size=len(valid_data) // 4)
                    model = create_model(train_data, hs, gpu_id)
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                    logging.info(
                        "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
                    # sl=False: no checkpointing during the hyper-parameter search.
                    ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
                                      args.epoch, args.patience, gpu_id, ddi_graph = ddi_graph, sl=False, )
                    ret_vals.append(ret)
                    # Free the model before the next inner fold to limit GPU memory.
                    del model
                inner_loss = sum(ret_vals) / len(ret_vals)
                logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
                logging.info("-" * n_delimiter)
                losses.append(inner_loss)
                torch.cuda.memory_summary(device=None, abbreviated=False)
                gc.collect()
                torch.cuda.empty_cache()
                # Brief pause, presumably to let CUDA memory settle — TODO confirm needed.
                time.sleep(10)
        # Retrain the best (hidden size, lr) on all outer-train folds and test it.
        min_ls, min_idx = arg_min(losses)
        best_hs, best_lr = param[min_idx]
        train_data = FastSynergyDataset(DRUGNAME_2_DRUGBANKID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
                                        SYNERGY_FILE, use_folds=outer_trn_folds)
        test_data = FastSynergyDataset(DRUGNAME_2_DRUGBANKID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
                                       SYNERGY_FILE, use_folds=[test_fold], train=False)
        model = create_model(train_data, best_hs, gpu_id)
        optimizer = torch.optim.Adam(model.parameters(), lr=best_lr)
        logging.info("Best hidden size: {} | Best learning rate: {}".format(best_hs, best_lr))
        logging.info("Start test on fold {}.".format([test_fold]))
        test_mdl_dir = os.path.join(out_dir, str(test_fold))
        if not os.path.exists(test_mdl_dir):
            os.makedirs(test_mdl_dir)
        test_loss = eval_model(model, optimizer, loss_func, train_data, test_data,
                               args.batch, args.epoch, args.patience, ddi_graph, gpu_id, test_mdl_dir)
        test_losses.append(test_loss)
        logging.info("Test loss: {:.4f}".format(test_loss))
        logging.info("*" * n_delimiter + '\n')
    logging.info("CV completed")
    with open(test_loss_file, 'wb') as f:
        pickle.dump(test_losses, f)
    mu, sigma = calc_stat(test_losses)
    logging.info("MSE: {:.4f} ± {:.4f}".format(mu, sigma))
    lo, hi = conf_inv(mu, sigma, len(test_losses))
    logging.info("Confidence interval: [{:.4f}, {:.4f}]".format(lo, hi))
    # RMSE per fold is just the square root of each fold's MSE.
    rmse_loss = [x ** 0.5 for x in test_losses]
    mu, sigma = calc_stat(rmse_loss)
    logging.info("RMSE: {:.4f} ± {:.4f}".format(mu, sigma))
  232. def main():
  233. parser = argparse.ArgumentParser()
  234. parser.add_argument('--epoch', type=int, default=500, help="n epoch")
  235. parser.add_argument('--batch', type=int, default=256, help="batch size")
  236. parser.add_argument('--gpu', type=int, default=None, help="cuda device")
  237. parser.add_argument('--patience', type=int, default=100, help='patience for early stop')
  238. parser.add_argument('--suffix', type=str, default=time_str, help="model dir suffix")
  239. parser.add_argument('--hidden', type=int, nargs='+', default=[2048, 4096, 8192], help="hidden size")
  240. parser.add_argument('--lr', type=float, nargs='+', default=[1e-3, 1e-4, 1e-5], help="learning rate")
  241. args = parser.parse_args()
  242. out_dir = os.path.join(OUTPUT_DIR, 'cv_{}'.format(args.suffix))
  243. if not os.path.exists(out_dir):
  244. os.makedirs(out_dir)
  245. log_file = os.path.join(out_dir, 'cv.log')
  246. logging.basicConfig(filename=log_file,
  247. format='%(asctime)s %(message)s',
  248. datefmt='[%Y-%m-%d %H:%M:%S]',
  249. level=logging.INFO)
  250. cv(args, out_dir)
  251. if __name__ == '__main__':
  252. main()