# Entry-point script: trains and evaluates the DeepTraCDR model for cell-drug interaction prediction.
- import argparse
- import torch
- import numpy as np
- from sklearn.metrics import roc_auc_score
- from data_loader import load_data
- from new_model import run_DeepTraCDR
-
-
def parse_arguments(argv=None) -> argparse.Namespace:
    """
    Parses command-line arguments for running the DeepTraCDR model.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When None (the default), argparse falls
            back to sys.argv[1:], which is what this function always did
            before the parameter existed.

    Returns:
        argparse.Namespace: Parsed arguments containing configuration parameters
        (device, data, wd, layer_size, gamma, epochs, test_freq, lr).
    """
    parser = argparse.ArgumentParser(
        description="Run DeepTraCDR model for cell-drug interaction prediction."
    )
    # Prefer the first CUDA device when one is visible, otherwise run on the CPU.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Device to run the model on (e.g., 'cuda:0' or 'cpu').",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="gdsc",
        help="Dataset to use (e.g., 'gdsc').",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=1e-5,
        help="Weight decay for the optimizer.",
    )
    # nargs="+" collects one or more integers into a list, e.g. --layer_size 512 256.
    parser.add_argument(
        "--layer_size",
        nargs="+",
        type=int,
        default=[512],
        help="List of layer sizes for the model.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=15.0,
        help="Gamma parameter for decoder scaling.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1000,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--test_freq",
        type=int,
        default=50,
        help="Frequency of evaluation during training.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        help="Learning rate for the optimizer.",
    )
    # Passing argv through lets callers (and tests) supply arguments explicitly.
    return parser.parse_args(argv)
-
def main():
    """
    Main function to run the DeepTraCDR model, including data loading, model training,
    and evaluation across specified targets.

    Each index along the target dimension (dimension 0, cell lines) whose
    interaction sum is at least ``min_interactions`` is used in turn as the
    target of ``run_DeepTraCDR``. The best AUC and AUPRC returned by every run
    are collected, and their mean and standard deviation are printed at the end.
    If no index passes the threshold, a message is printed and no summary
    statistics are computed.
    """
    # Parse command-line arguments
    args = parse_arguments()

    # Load dataset. load_data hands back the argument namespace as well; pos_num is
    # part of its return tuple but is not needed in this script.
    adj_matrix, drug_fingerprints, cell_expressions, null_mask, pos_num, args = load_data(args)

    # The tensor conversion further down allows for load_data returning something
    # other than an ndarray. If it is already a torch.Tensor, reduce over a NumPy
    # copy instead of handing the tensor itself to np.sum.
    if isinstance(adj_matrix, torch.Tensor):
        adj_for_sums = adj_matrix.detach().cpu().numpy()
    else:
        adj_for_sums = adj_matrix

    # Compute interaction sums for filtering
    cell_interaction_sums = np.sum(adj_for_sums, axis=1)  # Sum of interactions per cell line
    drug_interaction_sums = np.sum(adj_for_sums, axis=0)  # Sum of interactions per drug

    # Log dataset and configuration details
    print(f"\n{'='*40}")
    print(f"Dataset: {args.data} | Adjacency Matrix Shape: {adj_matrix.shape}")
    print(f"Device: {args.device} | Layer Sizes: {args.layer_size} | Learning Rate: {args.lr}")
    print(f"{'='*40}\n")

    # Define target dimension (0 for cell lines)
    target_dimensions = [0]  # Only process cell lines
    if isinstance(adj_matrix, np.ndarray):
        adj_matrix = torch.from_numpy(adj_matrix).float()

    num_folds = 1  # Number of cross-validation folds
    min_interactions = 10  # A target needs at least this many interactions to be evaluated
    auc_scores = []
    auprc_scores = []

    # Process each target dimension
    for dim in target_dimensions:
        dim_name = "Cell Lines" if dim == 0 else "Drugs"
        interaction_sums = drug_interaction_sums if dim == 1 else cell_interaction_sums

        # Filter valid targets with at least `min_interactions` interactions
        valid_indices = [
            idx for idx in range(adj_matrix.shape[dim])
            if interaction_sums[idx] >= min_interactions
        ]

        print(f"Processing {dim_name} ({len(valid_indices)} valid targets):")

        # Process each valid target
        for target_index in valid_indices:
            # No newline here: the metrics printed below complete the line.
            print(f"  Target {target_index} - ", end="", flush=True)

            # Run model for each fold
            for _ in range(num_folds):
                best_auc, best_auprc = run_DeepTraCDR(
                    cell_exprs=cell_expressions,
                    drug_finger=drug_fingerprints,
                    res_mat=adj_matrix,
                    null_mask=null_mask,
                    target_dim=dim,
                    target_index=target_index,
                    evaluate_fun=roc_auc_score,
                    args=args,
                )
                auc_scores.append(best_auc)
                auprc_scores.append(best_auprc)

            # Reports the metrics of the most recent fold for this target.
            print(f"AUC: {best_auc:.4f}, AUPRC: {best_auprc:.4f}")

    # np.mean / np.std of an empty list yield nan plus a RuntimeWarning, so bail out
    # with an explicit message when nothing was evaluated.
    if not auc_scores:
        print(f"No targets with at least {min_interactions} interactions; nothing to summarize.")
        return

    # Compute and display average metrics
    mean_auc = np.mean(auc_scores)
    std_auc = np.std(auc_scores)
    mean_auprc = np.mean(auprc_scores)
    std_auprc = np.std(auprc_scores)

    print(f"\n{'='*40}")
    print(f"Average AUC: {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"Average AUPRC: {mean_auprc:.4f} ± {std_auprc:.4f}")
    print(f"{'='*40}\n")
-
-
# Run the full pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|