DeepTraCDR: Predicting Cancer Drug Response using multimodal deep learning with Transformers
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_random.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import argparse
  2. import numpy as np
  3. import pandas as pd
  4. import torch
  5. from sklearn.model_selection import KFold
  6. from DeepTraCDR_model import DeepTraCDR, Optimizer
  7. from utils import evaluate_auc
  8. from data_sampler import RandomSampler
  9. from data_loader import load_data
  10. def parse_arguments():
  11. """
  12. Parses command-line arguments for the DeepTraCDR model.
  13. Returns:
  14. Parsed arguments containing model and training configurations.
  15. """
  16. parser = argparse.ArgumentParser(description="DeepTraCDR: Graph-based Cell-Drug Interaction Prediction")
  17. parser.add_argument('-device', type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
  18. help="Device to run the model on (cuda:0 or cpu)")
  19. parser.add_argument('-data', type=str, default='ccle', help="Dataset to use (default: ccle)")
  20. parser.add_argument('--wd', type=float, default=1e-4, help="Weight decay for optimizer")
  21. parser.add_argument('--layer_size', nargs='+', type=int, default=[512], help="Layer sizes for the model")
  22. parser.add_argument('--gamma', type=float, default=15, help="Gamma parameter for decoder")
  23. parser.add_argument('--epochs', type=int, default=1000, help="Number of training epochs")
  24. parser.add_argument('--test_freq', type=int, default=50, help="Frequency of evaluation during training")
  25. parser.add_argument('--lr', type=float, default=0.0005, help="Learning rate for optimizer")
  26. return parser.parse_args()
  27. def main():
  28. """Main function to execute the DeepTraCDR training and evaluation pipeline."""
  29. args = parse_arguments()
  30. # Load dataset
  31. full_adj, drug_fingerprints, exprs, null_mask, pos_num, args = load_data(args)
  32. # Log data shapes for debugging
  33. print(f"Original adj_mat shape: {full_adj.shape}")
  34. print("\n--- Data Shapes ---")
  35. print(f"Expression data shape: {exprs.shape}")
  36. print(f"Null mask shape: {null_mask.shape}")
  37. # Convert adjacency matrix to torch tensor if necessary
  38. if isinstance(full_adj, np.ndarray):
  39. full_adj = torch.from_numpy(full_adj).float()
  40. print(f"Converted adj_mat shape: {full_adj.shape}")
  41. # Initialize k-fold cross-validation
  42. k = 5
  43. n_kfolds = 5
  44. all_metrics = {
  45. 'auc': [], 'auprc': [], 'precision': [], 'recall': [], 'f1_score': []
  46. }
  47. # Perform k-fold cross-validation
  48. for n_kfold in range(n_kfolds):
  49. kfold = KFold(n_splits=k, shuffle=True, random_state=n_kfold)
  50. for fold, (train_idx, test_idx) in enumerate(kfold.split(np.arange(pos_num))):
  51. # Initialize data sampler
  52. sampler = RandomSampler(full_adj, train_idx, test_idx, null_mask)
  53. # Initialize model
  54. model = DeepTraCDR(
  55. adj_mat=full_adj,
  56. cell_exprs=exprs,
  57. drug_finger=drug_fingerprints,
  58. layer_size=args.layer_size,
  59. gamma=args.gamma,
  60. device=args.device
  61. )
  62. # Initialize optimizer
  63. opt = Optimizer(
  64. model=model,
  65. train_data=sampler.train_data,
  66. test_data=sampler.test_data,
  67. test_mask=sampler.test_mask,
  68. train_mask=sampler.train_mask,
  69. adj_matrix=full_adj,
  70. evaluate_fun=evaluate_auc,
  71. lr=args.lr,
  72. wd=args.wd,
  73. epochs=args.epochs,
  74. test_freq=args.test_freq,
  75. device=args.device
  76. )
  77. # Train model and collect metrics
  78. true, pred, best_auc, best_auprc, best_precision, best_recall, best_f1 = opt.train()
  79. # Store metrics
  80. all_metrics['auc'].append(best_auc)
  81. all_metrics['auprc'].append(best_auprc)
  82. all_metrics['precision'].append(best_precision)
  83. all_metrics['recall'].append(best_recall)
  84. all_metrics['f1_score'].append(best_f1)
  85. print(f"Fold {n_kfold * k + fold + 1}: AUC={best_auc:.4f}, AUPRC={best_auprc:.4f}, "
  86. f"Precision={best_precision:.4f}, Recall={best_recall:.4f}, F1-Score={best_f1:.4f}")
  87. # Compute and display final metrics
  88. print("\nFinal Average Metrics:")
  89. for metric, values in all_metrics.items():
  90. mean = np.mean(values)
  91. std = np.std(values)
  92. print(f"{metric.upper()}: {mean:.4f} ± {std:.4f}")
  93. if __name__ == "__main__":
  94. torch.set_float32_matmul_precision('high')
  95. main()