You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

cross_validation.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import argparse
  2. import os
  3. import logging
  4. import time
  5. import pickle
  6. import torch
  7. import torch.nn as nn
  8. import gc
  9. from datetime import datetime
  10. time_str = str(datetime.now().strftime('%y%m%d%H%M'))
  11. from model.datasets import FastSynergyDataset, FastTensorDataLoader
  12. from model.models import MLP
  13. from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
  14. from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DTI_FEAT_FILE, FP_FEAT_FILE, DRUG2ID_FILE
  15. def eval_model(model, optimizer, loss_func, train_data, test_data,
  16. batch_size, n_epoch, patience, gpu_id, mdl_dir):
  17. tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1)
  18. train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True)
  19. valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4)
  20. test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4)
  21. train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
  22. sl=True, mdl_dir=mdl_dir)
  23. test_loss = eval_epoch(model, test_loader, loss_func, gpu_id)
  24. test_loss /= len(test_data)
  25. test_accuracy, test_recall = eval_acc(model, test_loader, loss_func, gpu_id, len(test_data))
  26. return test_loss, test_accuracy, test_recall
  27. def step_batch(model, batch, loss_func, gpu_id=None, train=True):
  28. drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat, y_true = batch
  29. if gpu_id is not None:
  30. drug1_idx = drug1_idx.cuda(gpu_id)
  31. drug2_idx = drug2_idx.cuda(gpu_id)
  32. drug1_fp = drug1_fp.cuda(gpu_id)
  33. drug2_fp = drug2_fp.cuda(gpu_id)
  34. drug1_dti = drug1_dti.cuda(gpu_id)
  35. drug2_dti = drug2_dti.cuda(gpu_id)
  36. cell_feat = cell_feat.cuda(gpu_id)
  37. y_true = y_true.cuda(gpu_id)
  38. if train:
  39. y_pred = model(drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat)
  40. else:
  41. yp1 = model(drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat)
  42. yp2 = model(drug1_idx, drug2_idx, drug2_fp, drug1_fp, drug2_dti, drug1_dti, cell_feat)
  43. y_pred = (yp1 + yp2) / 2
  44. loss = loss_func(y_pred, y_true)
  45. # weighted loss:
  46. # indices_true = []
  47. # indices_false = []
  48. # for i in range(len(y_pred)):
  49. # if (y_true[i] > 10 and y_pred[i] > 10) or (y_true[i] < 10 and y_pred[i] < 10):
  50. # indices_true.append(i)
  51. # else:
  52. # indices_false.append(i)
  53. # loss_true = loss_func(y_pred[indices_true], y_true[indices_true])
  54. # loss_false = loss_func(y_pred[indices_false], y_true[indices_false])
  55. # loss = loss_true + 5 * loss_false
  56. return loss
  57. def train_epoch(model, loader, loss_func, optimizer, gpu_id=None):
  58. model.train()
  59. epoch_loss = 0
  60. for _, batch in enumerate(loader):
  61. optimizer.zero_grad()
  62. loss = step_batch(model, batch, loss_func, gpu_id)
  63. loss.backward()
  64. optimizer.step()
  65. epoch_loss += loss.item()
  66. return epoch_loss
  67. def eval_epoch(model, loader, loss_func, gpu_id=None):
  68. model.eval()
  69. with torch.no_grad():
  70. epoch_loss = 0
  71. for batch in loader:
  72. loss = step_batch(model, batch, loss_func, gpu_id, train=False)
  73. epoch_loss += loss.item()
  74. return epoch_loss
  75. def step_acc(model, batch, loss_func, gpu_id=None, train=True):
  76. drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat, y_true = batch
  77. if gpu_id is not None:
  78. drug1_idx = drug1_idx.cuda(gpu_id)
  79. drug2_idx = drug2_idx.cuda(gpu_id)
  80. drug1_fp = drug1_fp.cuda(gpu_id)
  81. drug2_fp = drug2_fp.cuda(gpu_id)
  82. drug1_dti = drug1_dti.cuda(gpu_id)
  83. drug2_dti = drug2_dti.cuda(gpu_id)
  84. cell_feat = cell_feat.cuda(gpu_id)
  85. y_true = y_true.cuda(gpu_id)
  86. if train:
  87. y_pred = model(drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat)
  88. else:
  89. yp1 = model(drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat)
  90. yp2 = model(drug1_idx, drug2_idx, drug2_fp, drug1_fp, drug2_dti, drug1_dti, cell_feat)
  91. y_pred = (yp1 + yp2) / 2
  92. TP = 0
  93. TN = 0
  94. FN = 0
  95. loss = loss_func(y_pred, y_true)
  96. for i in range(len(y_pred)):
  97. if (y_true[i] > 10 and y_pred[i] > 10) or (y_true[i] < 10 and y_pred[i] < 10):
  98. if y_true[i] > 0:
  99. TP += 1
  100. else:
  101. TN += 1
  102. elif y_true[i] > 0:
  103. FN += 1
  104. return TP, TN, FN
  105. def eval_acc(model, loader, loss_func, gpu_id, test_len):
  106. model.eval()
  107. with torch.no_grad():
  108. TP = 0
  109. TN = 0
  110. FN = 0
  111. for batch in loader:
  112. t1, t2, t3 = step_acc(model, batch, loss_func, gpu_id, train=False)
  113. TP += t1
  114. TN += t2
  115. FN += t3
  116. acc = (TP + TN) / test_len
  117. recall = TP / (TP + FN)
  118. return acc, recall
  119. def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
  120. sl=False, mdl_dir=None):
  121. min_loss = float('inf')
  122. angry = 0
  123. for epoch in range(1, n_epoch + 1):
  124. trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id)
  125. trn_loss /= train_loader.dataset_len
  126. val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
  127. val_loss /= valid_loader.dataset_len
  128. print("loss epoch " + str(epoch) + ": " + str(trn_loss))
  129. if val_loss < min_loss:
  130. angry = 0
  131. min_loss = val_loss
  132. if sl:
  133. # TODO: save best model in case of need
  134. save_best_model(model.state_dict(), mdl_dir, epoch, keep=1)
  135. pass
  136. else:
  137. angry += 1
  138. if angry >= patience:
  139. break
  140. if sl:
  141. model.load_state_dict(torch.load(find_best_model(mdl_dir)))
  142. return min_loss
  143. def create_model(data, hidden_size, gpu_id=None):
  144. # TODO: use our own MLP model
  145. # get 256
  146. model = MLP(data.cell_feat_len() + 2 * 40 + 2 * 256 + 2 * 64, hidden_size, gpu_id)
  147. if gpu_id is not None:
  148. model = model.cuda(gpu_id)
  149. return model
  150. def cv(args, out_dir):
  151. gc.collect()
  152. save_args(args, os.path.join(out_dir, 'args.json'))
  153. test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
  154. print("cuda available: " + str(torch.cuda.is_available()))
  155. if torch.cuda.is_available() and (args.gpu is not None):
  156. gpu_id = args.gpu
  157. else:
  158. gpu_id = None
  159. n_folds = 5
  160. n_delimiter = 60
  161. loss_func = nn.MSELoss(reduction='sum')
  162. test_losses = []
  163. all_acc = []
  164. all_recall = []
  165. for test_fold in range(n_folds):
  166. outer_trn_folds = [x for x in range(n_folds) if x != test_fold]
  167. logging.info("Outer: train folds {}, test folds {}".format(outer_trn_folds, test_fold))
  168. logging.info("-" * n_delimiter)
  169. param = []
  170. losses = []
  171. for hs in args.hidden:
  172. for lr in args.lr:
  173. param.append((hs, lr))
  174. logging.info("Hidden size: {} | Learning rate: {}".format(hs, lr))
  175. ret_vals = []
  176. for valid_fold in outer_trn_folds:
  177. inner_trn_folds = [x for x in outer_trn_folds if x != valid_fold]
  178. valid_folds = [valid_fold]
  179. train_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
  180. SYNERGY_FILE, DTI_FEAT_FILE, FP_FEAT_FILE, use_folds=inner_trn_folds)
  181. valid_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
  182. SYNERGY_FILE, DTI_FEAT_FILE, FP_FEAT_FILE, use_folds=valid_folds, train=False)
  183. train_loader = FastTensorDataLoader(*train_data.tensor_samples(), batch_size=args.batch,
  184. shuffle=True)
  185. valid_loader = FastTensorDataLoader(*valid_data.tensor_samples(), batch_size=len(valid_data) // 4)
  186. model = create_model(train_data, hs, gpu_id)
  187. optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  188. logging.info(
  189. "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
  190. ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
  191. args.epoch, args.patience, gpu_id, sl=False)
  192. ret_vals.append(ret)
  193. del model
  194. inner_loss = sum(ret_vals) / len(ret_vals)
  195. logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
  196. logging.info("-" * n_delimiter)
  197. losses.append(inner_loss)
  198. torch.cuda.memory_summary(device=None, abbreviated=False)
  199. gc.collect()
  200. torch.cuda.empty_cache()
  201. time.sleep(10)
  202. min_ls, min_idx = arg_min(losses)
  203. best_hs, best_lr = param[min_idx]
  204. train_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
  205. SYNERGY_FILE, DTI_FEAT_FILE, FP_FEAT_FILE, use_folds=outer_trn_folds)
  206. test_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
  207. SYNERGY_FILE, DTI_FEAT_FILE, FP_FEAT_FILE, use_folds=[test_fold], train=False)
  208. model = create_model(train_data, best_hs, gpu_id)
  209. optimizer = torch.optim.Adam(model.parameters(), lr=best_lr)
  210. logging.info("Best hidden size: {} | Best learning rate: {}".format(best_hs, best_lr))
  211. logging.info("Start test on fold {}.".format([test_fold]))
  212. test_mdl_dir = os.path.join(out_dir, str(test_fold))
  213. if not os.path.exists(test_mdl_dir):
  214. os.makedirs(test_mdl_dir)
  215. test_loss, acc, recall = eval_model(model, optimizer, loss_func, train_data, test_data,
  216. args.batch, args.epoch, args.patience, gpu_id, test_mdl_dir)
  217. all_acc.append(acc)
  218. all_recall.append(recall)
  219. test_losses.append(test_loss)
  220. logging.info("Test loss: {:.4f}".format(test_loss))
  221. logging.info("Test accuracy: {:.4f}".format(acc))
  222. logging.info("Test recall: {:.4f}".format(recall))
  223. logging.info("*" * n_delimiter + '\n')
  224. logging.info("CV completed")
  225. with open(test_loss_file, 'wb') as f:
  226. pickle.dump(test_losses, f)
  227. mu, sigma = calc_stat(test_losses)
  228. logging.info("MSE: {:.4f} ± {:.4f}".format(mu, sigma))
  229. lo, hi = conf_inv(mu, sigma, len(test_losses))
  230. logging.info("Confidence interval: [{:.4f}, {:.4f}]".format(lo, hi))
  231. rmse_loss = [x ** 0.5 for x in test_losses]
  232. mu, sigma = calc_stat(rmse_loss)
  233. logging.info("RMSE: {:.4f} ± {:.4f}".format(mu, sigma))
  234. logging.info("average accuracy on test: {}".format(sum(all_acc) / len(all_acc)))
  235. logging.info("average recall on test: {}".format({sum(all_recall) / len(all_recall)}))
  236. def main():
  237. parser = argparse.ArgumentParser()
  238. parser.add_argument('--epoch', type=int, default=500, help="n epoch")
  239. parser.add_argument('--batch', type=int, default=256, help="batch size")
  240. parser.add_argument('--gpu', type=int, default=None, help="cuda device")
  241. parser.add_argument('--patience', type=int, default=100, help='patience for early stop')
  242. parser.add_argument('--suffix', type=str, default=time_str, help="model dir suffix")
  243. parser.add_argument('--hidden', type=int, nargs='+', default=[2048, 4096, 8192], help="hidden size")
  244. parser.add_argument('--lr', type=float, nargs='+', default=[1e-3, 1e-4, 1e-5], help="learning rate")
  245. args = parser.parse_args()
  246. out_dir = os.path.join(OUTPUT_DIR, 'cv_{}'.format(args.suffix))
  247. if not os.path.exists(out_dir):
  248. os.makedirs(out_dir)
  249. log_file = os.path.join(out_dir, 'cv.log')
  250. logging.basicConfig(filename=log_file,
  251. format='%(asctime)s %(message)s',
  252. datefmt='[%Y-%m-%d %H:%M:%S]',
  253. level=logging.INFO)
  254. cv(args, out_dir)
  255. if __name__ == '__main__':
  256. main()