# main.py
import argparse
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, average_precision_score
from typing import Dict, List, Tuple
from model import DeepTraCDR, Optimizer
from utils import evaluate_auc
from data_sampler import ExterSampler
from data_loader import load_data
from torch.optim.lr_scheduler import OneCycleLR

def parse_arguments() -> argparse.Namespace:
    """
    Build the DeepTraCDR command-line interface and parse ``sys.argv``.

    Options (in declaration order): --device, --data, --wd, --layer_size
    (one or more ints), --gamma, --epochs, --test_freq, --lr.
    ``--device`` defaults to "cuda:0" when ``torch.cuda.is_available()`` is
    true, otherwise to "cpu".

    Returns:
        argparse.Namespace: The parsed option values.
    """
    fallback_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    cli = argparse.ArgumentParser(
        description="DeepTraCDR Advanced: Graph-based Neural Network for Drug Response Prediction"
    )
    option = cli.add_argument  # short alias; declaration order is preserved for --help

    option('--device', type=str, default=fallback_device,
           help="Device to run the model on (cuda:0 or cpu)")
    option('--data', type=str, default='tcga',
           help="Dataset to use (e.g., tcga)")
    option('--wd', type=float, default=1e-7,
           help="Weight decay for optimizer")
    option('--layer_size', nargs='+', type=int, default=[512],
           help="List of layer sizes for the GCN model")
    option('--gamma', type=float, default=20.0,
           help="Gamma parameter for model")
    option('--epochs', type=int, default=1000,
           help="Number of training epochs")
    option('--test_freq', type=int, default=50,
           help="Frequency of evaluation during training")
    option('--lr', type=float, default=0.0005,
           help="Learning rate for optimizer")

    return cli.parse_args()

def initialize_data(args: argparse.Namespace) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, argparse.Namespace]:
    """
    Load the dataset via ``load_data`` and print the shapes of the loaded arrays.

    Args:
        args (argparse.Namespace): Command-line arguments, forwarded to
            ``load_data``. ``load_data`` returns a namespace as its last
            element, and that namespace is what this function returns.

    Returns:
        Tuple of (full_adj, drug_fingerprints, exprs, null_mask, pos_num, args),
        exactly as returned by ``load_data``.
        NOTE(review): the array elements may be NumPy arrays rather than
        tensors despite the annotation (``full_adj`` is converted to a tensor
        after this call) — confirm against ``load_data``.

    Raises:
        RuntimeError: If ``load_data`` (or unpacking its result) fails. The
            original exception is attached as ``__cause__``.
    """
    # Keep the try narrow: only loading/unpacking failures are "failed to load data".
    try:
        full_adj, drug_fingerprints, exprs, null_mask, pos_num, args = load_data(args)
    except Exception as e:
        # Explicit chaining preserves load_data's traceback as the direct cause.
        raise RuntimeError(f"Failed to load data: {str(e)}") from e

    print("Data loaded successfully:")
    print(f"  - Adjacency matrix shape: {full_adj.shape}")
    print(f"  - Expression data shape: {exprs.shape}")
    print(f"  - Null mask shape: {null_mask.shape}")
    print(f"  - Drug fingerprints shape: {drug_fingerprints.shape}")
    return full_adj, drug_fingerprints, exprs, null_mask, pos_num, args

def convert_to_tensor(data: np.ndarray, device: str) -> torch.Tensor:
    """
    Convert array-like data to a float32 PyTorch tensor on the specified device.

    Accepted inputs:
      * NumPy arrays — wrapped with ``torch.from_numpy`` (no copy when the
        array is already float32 and the target is CPU).
      * Existing tensors — cast with ``.float()`` and moved with ``.to()``.
      * Anything else ``torch.as_tensor`` understands, e.g. nested Python
        lists/tuples, Python or NumPy scalars. Previously these raised
        AttributeError because they have no ``.float()`` method.

    Args:
        data: Input data (see accepted inputs above).
        device (str): Target device (e.g., 'cuda:0' or 'cpu').

    Returns:
        torch.Tensor: float32 tensor on ``device``.
    """
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).float().to(device)
    if isinstance(data, torch.Tensor):
        return data.float().to(device)
    # Generic fallback for lists, tuples and scalars.
    return torch.as_tensor(data, dtype=torch.float32, device=device)

def train_single_fold(
    fold_idx: int,
    full_adj: torch.Tensor,
    exprs: torch.Tensor,
    drug_fingerprints: torch.Tensor,
    null_mask: torch.Tensor,
    pos_num: int,
    args: argparse.Namespace
) -> Tuple[float, float]:
    """
    Run one complete DeepTraCDR training run and report its best scores.

    Indices ``0 .. pos_num - 1`` form the training index and indices
    ``pos_num .. full_adj.shape[0] - 1`` form the test index. The split
    depends only on ``pos_num`` and ``full_adj.shape[0]``, never on
    ``fold_idx``. How the sampler interprets these indices is defined in
    ``ExterSampler`` (not visible here).

    Args:
        fold_idx (int): Zero-based run number; used only in the printed
            progress line.
        full_adj (torch.Tensor): Adjacency matrix. Its first dimension bounds
            the test index; it is also handed to the sampler, model and
            optimizer.
        exprs (torch.Tensor): Expression data, forwarded to the model as
            ``cell_exprs``.
        drug_fingerprints (torch.Tensor): Forwarded to the model as
            ``drug_finger``.
        null_mask (torch.Tensor): Mask forwarded to ``ExterSampler``.
        pos_num (int): Size of the training index (see above).
        args (argparse.Namespace): CLI options supplying layer_size, gamma,
            lr, wd, epochs, test_freq and device.

    Returns:
        Tuple[float, float]: ``(best_auc, best_auprc)`` — the third and fourth
        values returned by ``Optimizer.train()``.
    """
    n_total = full_adj.shape[0]
    train_rows = np.arange(pos_num)
    test_rows = np.arange(n_total - pos_num) + pos_num

    # Construction order (sampler, then model, then optimizer) is kept as-is
    # in case any constructor draws from a shared RNG.
    sampler = ExterSampler(full_adj, null_mask, train_rows, test_rows)

    model = DeepTraCDR(
        adj_mat=full_adj,
        cell_exprs=exprs,
        drug_finger=drug_fingerprints,
        layer_size=args.layer_size,
        gamma=args.gamma,
        device=args.device,
    )

    train_hparams = {
        "lr": args.lr,
        "wd": args.wd,
        "epochs": args.epochs,
        "test_freq": args.test_freq,
        "device": args.device,
    }
    trainer = Optimizer(
        model=model,
        train_data=sampler.train_data,
        test_data=sampler.test_data,
        test_mask=sampler.test_mask,
        train_mask=sampler.train_mask,
        adj_matrix=full_adj,
        evaluate_fun=evaluate_auc,
        **train_hparams,
    )

    # train() must yield exactly four values; only the last two are reported.
    _, _, best_auc, best_auprc = trainer.train()
    fold_number = fold_idx + 1
    print(f"Fold {fold_number}: AUC={best_auc:.4f}, AUPRC={best_auprc:.4f}")
    return best_auc, best_auprc

def summarize_metrics(metrics: Dict[str, List[float]]) -> Dict[str, Tuple[float, float]]:
    """
    Print the mean and standard deviation of each metric across runs.

    Args:
        metrics (Dict[str, List[float]]): Mapping from metric name
            (e.g. 'auc', 'auprc') to its per-run values.

    Returns:
        Dict[str, Tuple[float, float]]: ``{name: (mean, std)}`` for every
        metric that has at least one value. Metrics with no values are
        printed as "n/a" and left out of the result. Callers that ignore the
        return value (as before, when this returned None) are unaffected.
    """
    print("\nFinal Average Metrics:")
    summary: Dict[str, Tuple[float, float]] = {}
    for metric, values in metrics.items():
        label = metric.upper()
        # len() rather than truthiness so NumPy arrays are accepted too.
        if len(values) == 0:
            # np.mean/np.std of an empty sequence warn and return nan; report it explicitly.
            print(f"{label}: n/a (no values)")
            continue
        mean_val = float(np.mean(values))
        std_val = float(np.std(values))  # population std (ddof=0), unchanged from before
        summary[metric] = (mean_val, std_val)
        print(f"{label}: {mean_val:.4f} ± {std_val:.4f}")
    return summary

def main(n_folds: int = 25) -> None:
    """
    Orchestrate the DeepTraCDR training and evaluation pipeline.

    Steps: parse CLI arguments, load the data once, convert the adjacency
    matrix to a tensor on ``args.device``, call ``train_single_fold``
    ``n_folds`` times, and print the mean ± std of the per-run best AUC and
    AUPRC.

    Note: ``train_single_fold`` computes its train/test indices
    deterministically from ``pos_num`` and ``full_adj.shape[0]``, so every
    iteration uses the same split. These iterations are repeated training
    runs, not k-fold cross-validation partitions. Any run-to-run variation
    comes from randomness inside the sampler/model/optimizer, which is not
    seeded here.

    Args:
        n_folds (int): Number of repeated training runs. Defaults to 25,
            the previously hard-coded value, so ``main()`` behaves as before.

    Raises:
        ValueError: If ``n_folds`` is less than 1.
    """
    # Fail fast before any expensive data loading.
    if n_folds < 1:
        raise ValueError(f"n_folds must be >= 1, got {n_folds}")

    # 'high' lets PyTorch use faster reduced-internal-precision kernels
    # (e.g. TF32) for float32 matmuls where the hardware supports them.
    torch.set_float32_matmul_precision('high')

    args = parse_arguments()

    full_adj, drug_fingerprints, exprs, null_mask, pos_num, args = initialize_data(args)

    # Only the adjacency matrix is converted here; the other arrays are
    # passed through exactly as returned by the loader.
    full_adj = convert_to_tensor(full_adj, args.device)

    metrics: Dict[str, List[float]] = {'auc': [], 'auprc': []}

    # Repeated training runs on one fixed split (see docstring), not k-fold CV.
    for fold_idx in range(n_folds):
        best_auc, best_auprc = train_single_fold(
            fold_idx, full_adj, exprs, drug_fingerprints, null_mask, pos_num, args
        )
        metrics['auc'].append(best_auc)
        metrics['auprc'].append(best_auprc)

    summarize_metrics(metrics)

if __name__ == "__main__":
    # Print a one-line summary first; the re-raise still shows the full traceback
    # and produces a non-zero exit status.
    try:
        main()
    except Exception as err:
        print(f"Error occurred: {err!s}")
        raise