DeepTraCDR: Prediction of Cancer Drug Response using multimodal deep learning with Transformers
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_random.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import argparse
  2. import numpy as np
  3. import torch
  4. from sklearn.model_selection import KFold
  5. from Regression.DeepTraCDR_model import DeepTraCDR, Optimizer
  6. from data_sampler import RegressionSampler
  7. from data_loader import load_data
  8. def parse_arguments() -> argparse.Namespace:
  9. """
  10. Parses command-line arguments for the DeepTraCDR regression task.
  11. Returns:
  12. Parsed arguments as a Namespace object.
  13. """
  14. parser = argparse.ArgumentParser(description="DeepTraCDR Regression Task")
  15. parser.add_argument('-device', type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
  16. help="Device to run the model on (e.g., 'cuda:0' or 'cpu')")
  17. parser.add_argument('-data', type=str, default='gdsc', help="Dataset to use (default: gdsc)")
  18. parser.add_argument('--wd', type=float, default=1e-5, help="Weight decay for optimizer")
  19. parser.add_argument('--layer_size', nargs='+', type=int, default=[512], help="Layer sizes for the model")
  20. parser.add_argument('--gamma', type=float, default=15, help="Gamma parameter for decoder")
  21. parser.add_argument('--epochs', type=int, default=1000, help="Number of training epochs")
  22. parser.add_argument('--test_freq', type=int, default=50, help="Frequency of evaluation during training")
  23. parser.add_argument('--lr', type=float, default=0.0001, help="Learning rate for optimizer")
  24. parser.add_argument('--patience', type=int, default=20, help="Patience for early stopping")
  25. return parser.parse_args()
  26. def normalize_adj_matrix(adj_matrix: np.ndarray) -> torch.Tensor:
  27. """
  28. Normalizes the adjacency matrix using min-shift normalization and converts it to a torch tensor.
  29. Args:
  30. adj_matrix: Input adjacency matrix as a NumPy array.
  31. Returns:
  32. Normalized adjacency matrix as a torch tensor.
  33. """
  34. adj_matrix = adj_matrix - np.min(adj_matrix)
  35. if isinstance(adj_matrix, np.ndarray):
  36. adj_matrix = torch.from_numpy(adj_matrix).float()
  37. return adj_matrix
  38. def main():
  39. """
  40. Main function to run the DeepTraCDR regression task with k-fold cross-validation.
  41. """
  42. # Set precision for matrix multiplication
  43. torch.set_float32_matmul_precision('high')
  44. # Parse command-line arguments
  45. args = parse_arguments()
  46. # Load dataset
  47. full_adj, drug_fingerprints, exprs, null_mask, pos_num, args = load_data(args)
  48. print(f"Original full_adj shape: {full_adj.shape}")
  49. print(f"Normalized full_adj shape: {full_adj.shape}")
  50. print("\n--- Data Shapes ---")
  51. print(f"Expression data shape: {exprs.shape}")
  52. print(f"Null mask shape: {null_mask.shape}")
  53. # Normalize adjacency matrix
  54. full_adj = normalize_adj_matrix(full_adj)
  55. # Initialize k-fold cross-validation parameters
  56. k = 5
  57. n_kfolds = 5
  58. all_metrics = {'rmse': [], 'pcc': [], 'scc': []}
  59. # Perform k-fold cross-validation
  60. for n_kfold in range(n_kfolds):
  61. kfold = KFold(n_splits=k, shuffle=True, random_state=n_kfold)
  62. for fold, (train_idx, test_idx) in enumerate(kfold.split(np.arange(pos_num))):
  63. # Initialize data sampler
  64. sampler = RegressionSampler(full_adj, train_idx, test_idx, null_mask)
  65. # Initialize model
  66. model = DeepTraCDR(
  67. adj_mat=full_adj,
  68. cell_exprs=exprs,
  69. drug_finger=drug_fingerprints,
  70. layer_size=args.layer_size,
  71. gamma=args.gamma,
  72. device=args.device
  73. )
  74. # Initialize optimizer
  75. opt = Optimizer(
  76. model=model,
  77. train_data=sampler.train_data,
  78. test_data=sampler.test_data,
  79. test_mask=sampler.test_mask,
  80. train_mask=sampler.train_mask,
  81. adj_matrix=full_adj,
  82. lr=args.lr,
  83. wd=args.wd,
  84. epochs=args.epochs,
  85. test_freq=args.test_freq,
  86. device=args.device,
  87. patience=args.patience
  88. )
  89. # Train model and collect metrics
  90. true, pred, best_rmse, best_pcc, best_scc = opt.train()
  91. all_metrics['rmse'].append(best_rmse)
  92. all_metrics['pcc'].append(best_pcc)
  93. all_metrics['scc'].append(best_scc)
  94. print(f"Fold {n_kfold * k + fold + 1}: RMSE={best_rmse:.4f}, PCC={best_pcc:.4f}, SCC={best_scc:.4f}")
  95. # Compute and display final average metrics
  96. print("\nFinal Average Metrics:")
  97. for metric, values in all_metrics.items():
  98. mean = np.mean(values)
  99. std = np.std(values)
  100. print(f"{metric.upper()}: {mean:.4f} ± {std:.4f}")
  101. if __name__ == "__main__":
  102. main()