You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cross_validation.py 9.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import argparse
  2. import os
  3. import logging
  4. import time
  5. import pickle
  6. import torch
  7. import torch.nn as nn
  8. from datetime import datetime
  9. time_str = str(datetime.now().strftime('%y%m%d%H%M'))
  10. from model.datasets import FastSynergyDataset, FastTensorDataLoader
  11. from model.models import MLP
  12. from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
  13. from const import SYNERGY_FILE, DRUG2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR
  14. def eval_model(model, optimizer, loss_func, train_data, test_data,
  15. batch_size, n_epoch, patience, gpu_id, mdl_dir):
  16. tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1)
  17. train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True)
  18. valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4)
  19. test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4)
  20. train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
  21. sl=True, mdl_dir=mdl_dir)
  22. test_loss = eval_epoch(model, test_loader, loss_func, gpu_id)
  23. test_loss /= len(test_data)
  24. return test_loss
  25. def step_batch(model, batch, loss_func, gpu_id=None, train=True):
  26. drug1_feats, drug2_feats, cell_feats, y_true = batch
  27. if gpu_id is not None:
  28. drug1_feats, drug2_feats, cell_feats, y_true = drug1_feats.cuda(gpu_id), drug2_feats.cuda(gpu_id), \
  29. cell_feats.cuda(gpu_id), y_true.cuda(gpu_id)
  30. if train:
  31. y_pred = model(drug1_feats, drug2_feats, cell_feats)
  32. else:
  33. yp1 = model(drug1_feats, drug2_feats, cell_feats)
  34. yp2 = model(drug2_feats, drug1_feats, cell_feats)
  35. y_pred = (yp1 + yp2) / 2
  36. loss = loss_func(y_pred, y_true)
  37. return loss
  38. def train_epoch(model, loader, loss_func, optimizer, gpu_id=None):
  39. model.train()
  40. epoch_loss = 0
  41. for _, batch in enumerate(loader):
  42. optimizer.zero_grad()
  43. loss = step_batch(model, batch, loss_func, gpu_id)
  44. loss.backward()
  45. optimizer.step()
  46. epoch_loss += loss.item()
  47. return epoch_loss
  48. def eval_epoch(model, loader, loss_func, gpu_id=None):
  49. model.eval()
  50. with torch.no_grad():
  51. epoch_loss = 0
  52. for batch in loader:
  53. loss = step_batch(model, batch, loss_func, gpu_id, train=False)
  54. epoch_loss += loss.item()
  55. return epoch_loss
  56. def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
  57. sl=False, mdl_dir=None):
  58. min_loss = float('inf')
  59. angry = 0
  60. for epoch in range(1, n_epoch + 1):
  61. trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id)
  62. trn_loss /= train_loader.dataset_len
  63. val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
  64. val_loss /= valid_loader.dataset_len
  65. if val_loss < min_loss:
  66. angry = 0
  67. min_loss = val_loss
  68. if sl:
  69. # TODO: save best model in case of need
  70. save_best_model(model.state_dict(), mdl_dir, epoch, keep=1)
  71. pass
  72. else:
  73. angry += 1
  74. if angry >= patience:
  75. break
  76. if sl:
  77. model.load_state_dict(torch.load(find_best_model(mdl_dir)))
  78. return min_loss
  79. def create_model(data, hidden_size, gpu_id=None):
  80. # TODO: use our own MLP model
  81. model = MLP(data.cell_feat_len() + 2 * data.drug_feat_len(), hidden_size)
  82. if gpu_id is not None:
  83. model = model.cuda(gpu_id)
  84. return model
  85. def cv(args, out_dir):
  86. save_args(args, os.path.join(out_dir, 'args.json'))
  87. test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
  88. if torch.cuda.is_available() and (args.gpu is not None):
  89. gpu_id = args.gpu
  90. else:
  91. gpu_id = None
  92. n_folds = 5
  93. n_delimiter = 60
  94. loss_func = nn.MSELoss(reduction='sum')
  95. test_losses = []
  96. for test_fold in range(n_folds):
  97. outer_trn_folds = [x for x in range(n_folds) if x != test_fold]
  98. logging.info("Outer: train folds {}, test folds {}".format(outer_trn_folds, test_fold))
  99. logging.info("-" * n_delimiter)
  100. param = []
  101. losses = []
  102. for hs in args.hidden:
  103. for lr in args.lr:
  104. param.append((hs, lr))
  105. logging.info("Hidden size: {} | Learning rate: {}".format(hs, lr))
  106. ret_vals = []
  107. for valid_fold in outer_trn_folds:
  108. inner_trn_folds = [x for x in outer_trn_folds if x != valid_fold]
  109. valid_folds = [valid_fold]
  110. train_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE,
  111. SYNERGY_FILE, use_folds=inner_trn_folds)
  112. valid_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE,
  113. SYNERGY_FILE, use_folds=valid_folds, train=False)
  114. train_loader = FastTensorDataLoader(*train_data.tensor_samples(), batch_size=args.batch,
  115. shuffle=True)
  116. valid_loader = FastTensorDataLoader(*valid_data.tensor_samples(), batch_size=len(valid_data) // 4)
  117. model = create_model(train_data, hs, gpu_id)
  118. optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  119. logging.info(
  120. "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
  121. ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
  122. args.epoch, args.patience, gpu_id, sl=False)
  123. ret_vals.append(ret)
  124. del model
  125. inner_loss = sum(ret_vals) / len(ret_vals)
  126. logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
  127. logging.info("-" * n_delimiter)
  128. losses.append(inner_loss)
  129. torch.cuda.empty_cache()
  130. time.sleep(10)
  131. min_ls, min_idx = arg_min(losses)
  132. best_hs, best_lr = param[min_idx]
  133. train_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE,
  134. SYNERGY_FILE, use_folds=outer_trn_folds)
  135. test_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE,
  136. SYNERGY_FILE, use_folds=[test_fold], train=False)
  137. model = create_model(train_data, best_hs, gpu_id)
  138. optimizer = torch.optim.Adam(model.parameters(), lr=best_lr)
  139. logging.info("Best hidden size: {} | Best learning rate: {}".format(best_hs, best_lr))
  140. logging.info("Start test on fold {}.".format([test_fold]))
  141. test_mdl_dir = os.path.join(out_dir, str(test_fold))
  142. if not os.path.exists(test_mdl_dir):
  143. os.makedirs(test_mdl_dir)
  144. test_loss = eval_model(model, optimizer, loss_func, train_data, test_data,
  145. args.batch, args.epoch, args.patience, gpu_id, test_mdl_dir)
  146. test_losses.append(test_loss)
  147. logging.info("Test loss: {:.4f}".format(test_loss))
  148. logging.info("*" * n_delimiter + '\n')
  149. logging.info("CV completed")
  150. with open(test_loss_file, 'wb') as f:
  151. pickle.dump(test_losses, f)
  152. mu, sigma = calc_stat(test_losses)
  153. logging.info("MSE: {:.4f} ± {:.4f}".format(mu, sigma))
  154. lo, hi = conf_inv(mu, sigma, len(test_losses))
  155. logging.info("Confidence interval: [{:.4f}, {:.4f}]".format(lo, hi))
  156. rmse_loss = [x ** 0.5 for x in test_losses]
  157. mu, sigma = calc_stat(rmse_loss)
  158. logging.info("RMSE: {:.4f} ± {:.4f}".format(mu, sigma))
  159. def main():
  160. parser = argparse.ArgumentParser()
  161. parser.add_argument('--epoch', type=int, default=500, help="n epoch")
  162. parser.add_argument('--batch', type=int, default=256, help="batch size")
  163. parser.add_argument('--gpu', type=int, default=None, help="cuda device")
  164. parser.add_argument('--patience', type=int, default=100, help='patience for early stop')
  165. parser.add_argument('--suffix', type=str, default=time_str, help="model dir suffix")
  166. parser.add_argument('--hidden', type=int, nargs='+', default=[2048, 4096, 8192], help="hidden size")
  167. parser.add_argument('--lr', type=float, nargs='+', default=[1e-3, 1e-4, 1e-5], help="learning rate")
  168. args = parser.parse_args()
  169. out_dir = os.path.join(OUTPUT_DIR, 'cv_{}'.format(args.suffix))
  170. if not os.path.exists(out_dir):
  171. os.makedirs(out_dir)
  172. log_file = os.path.join(out_dir, 'cv.log')
  173. logging.basicConfig(filename=log_file,
  174. format='%(asctime)s %(message)s',
  175. datefmt='[%Y-%m-%d %H:%M:%S]',
  176. level=logging.INFO)
  177. cv(args, out_dir)
  178. if __name__ == '__main__':
  179. main()