DeepTraCDR: Prediction of Cancer Drug Response using multimodal deep learning with Transformers

main_new.py 4.4KB

import argparse
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
from data_loader import load_data
from new_model import run_DeepTraCDR


def parse_arguments() -> argparse.Namespace:
    """
    Parses command-line arguments for running the DeepTraCDR model.

    Returns:
        argparse.Namespace: Parsed arguments containing configuration parameters.
    """
    parser = argparse.ArgumentParser(description="Run DeepTraCDR model for cell-drug interaction prediction.")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
                        help="Device to run the model on (e.g., 'cuda:0' or 'cpu').")
    parser.add_argument("--data", type=str, default="gdsc",
                        help="Dataset to use (e.g., 'gdsc').")
    parser.add_argument("--wd", type=float, default=1e-5,
                        help="Weight decay for the optimizer.")
    parser.add_argument("--layer_size", nargs="+", type=int, default=[512],
                        help="List of layer sizes for the model.")
    parser.add_argument("--gamma", type=float, default=15.0,
                        help="Gamma parameter for decoder scaling.")
    parser.add_argument("--epochs", type=int, default=1000,
                        help="Number of training epochs.")
    parser.add_argument("--test_freq", type=int, default=50,
                        help="Frequency of evaluation during training.")
    parser.add_argument("--lr", type=float, default=0.0001,
                        help="Learning rate for the optimizer.")
    return parser.parse_args()


def main():
    """
    Main function to run the DeepTraCDR model, including data loading, model training,
    and evaluation across specified targets.
    """
    # Parse command-line arguments
    args = parse_arguments()

    # Load dataset
    adj_matrix, drug_fingerprints, cell_expressions, null_mask, pos_num, args = load_data(args)

    # Compute interaction sums for filtering
    cell_interaction_sums = np.sum(adj_matrix, axis=1)  # Sum of interactions per cell line
    drug_interaction_sums = np.sum(adj_matrix, axis=0)  # Sum of interactions per drug

    # Log dataset and configuration details
    print(f"\n{'='*40}")
    print(f"Dataset: {args.data} | Adjacency Matrix Shape: {adj_matrix.shape}")
    print(f"Device: {args.device} | Layer Sizes: {args.layer_size} | Learning Rate: {args.lr}")
    print(f"{'='*40}\n")

    # Define target dimensions (0 for cell lines, 1 for drugs); only cell lines are processed here
    target_dimensions = [0]
    adj_matrix = torch.from_numpy(adj_matrix).float() if isinstance(adj_matrix, np.ndarray) else adj_matrix
    num_folds = 1  # Number of cross-validation folds per target
    total_targets_processed = 0
    auc_scores = []
    auprc_scores = []

    # Process each target dimension
    for dim in target_dimensions:
        dim_name = "Cell Lines" if dim == 0 else "Drugs"
        # Keep only targets with at least 10 known interactions
        valid_indices = [
            idx for idx in range(adj_matrix.shape[dim])
            if (drug_interaction_sums[idx] >= 10 if dim == 1 else cell_interaction_sums[idx] >= 10)
        ]
        print(f"Processing {dim_name} ({len(valid_indices)} valid targets):")

        # Run the model for each valid target
        for target_index in valid_indices:
            total_targets_processed += 1
            print(f" Target {target_index} - ", end="", flush=True)
            # Run model for each fold
            for fold in range(num_folds):
                best_auc, best_auprc = run_DeepTraCDR(
                    cell_exprs=cell_expressions,
                    drug_finger=drug_fingerprints,
                    res_mat=adj_matrix,
                    null_mask=null_mask,
                    target_dim=dim,
                    target_index=target_index,
                    evaluate_fun=roc_auc_score,
                    args=args
                )
                auc_scores.append(best_auc)
                auprc_scores.append(best_auprc)
                print(f"AUC: {best_auc:.4f}, AUPRC: {best_auprc:.4f}")

    # Compute and display average metrics across all processed targets
    mean_auc = np.mean(auc_scores)
    std_auc = np.std(auc_scores)
    mean_auprc = np.mean(auprc_scores)
    std_auprc = np.std(auprc_scores)
    print(f"\n{'='*40}")
    print(f"Average AUC: {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"Average AUPRC: {mean_auprc:.4f} ± {std_auprc:.4f}")
    print(f"{'='*40}\n")


if __name__ == "__main__":
    main()
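
The script is driven entirely by the flags defined in parse_arguments(). Below is a minimal usage sketch, not part of the repository: it assumes main_new.py, data_loader.py, new_model.py, and the GDSC data files expected by load_data are importable and present, and it simply overrides sys.argv as a programmatic stand-in for a shell run with the parser's default values.

# Minimal usage sketch (assumes main_new.py and its data dependencies are available).
# Equivalent shell invocation: python main_new.py --data gdsc --device cuda:0 --layer_size 512
import sys

from main_new import main

# Override sys.argv so parse_arguments() picks up the desired flags.
sys.argv = [
    "main_new.py",
    "--data", "gdsc",
    "--device", "cuda:0",
    "--layer_size", "512",
    "--epochs", "1000",
    "--lr", "0.0001",
]
main()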