import os

import numpy as np
import pandas as pd
import scipy.sparse as sp

from utils import *


def load_data(args):
    """
    Dispatch to the dataset loader selected by ``args.data``.

    Args:
        args: Configuration object. ``args.data`` picks the dataset: one of
            'gdsc', 'ccle', 'pdx' or 'tcga'. The chosen loader also overwrites
            ``args.alpha`` and ``args.layer_size``.

    Returns:
        The selected loader's 6-tuple:
        ``(response matrix, [drug fingerprint matrix], expression matrix, null mask, count, args)``.
        The ``count`` element differs by dataset:
        - For 'gdsc' and 'ccle' it is the number of nonzero response entries.
        - For 'pdx' and 'tcga' it is the number of leading GDSC (training) rows
          in the stacked matrices.

    Raises:
        NotImplementedError: If ``args.data`` is not one of the supported names.
    """
    # Each loader is wrapped in a thunk so it is only looked up and called once selected.
    candidates = (
        ('gdsc', lambda: _load_gdsc(args)),
        ('ccle', lambda: _load_ccle(args)),
        ('pdx', lambda: _load_pdx(args)),
        ('tcga', lambda: _load_tcga(args)),
    )
    for name, load in candidates:
        if args.data == name:
            return load()
    raise NotImplementedError(f"Dataset {args.data} is not supported.")


def _load_gdsc(args, *, data_root="/media/external_16TB_1/ali_kianfar/Data"):
    """
    Load the GDSC cell-line/drug response data with fingerprints, expression and null mask.

    All matrices are read with the first CSV column as the row index and converted to
    float32. Rows are taken positionally. The response, expression and null-mask CSVs are
    assumed to list cell lines in the same order (NOTE(review): not checked here, confirm).

    Args:
        args: Configuration object. ``args.alpha`` and ``args.layer_size`` are overwritten.
        data_root: Directory containing the ``GDSC`` data folder. Defaults to the original
            hard-coded location, so existing callers are unaffected.

    Returns:
        Tuple ``(res, drug_fingerprints, exprs, null_mask, pos_num, args)``:
            res: float32 matrix from ``GDSC/cell_drug_binary.csv``.
            drug_fingerprints: one-element list holding the float32 matrix from
                ``GDSC/ECFP6_fingerprints_GDSC.csv``.
            exprs: float32 matrix from ``GDSC/cell_gene/merged_file_GDSC.csv``.
            null_mask: float32 matrix from ``GDSC/null_mask.csv``.
            pos_num: number of nonzero entries in ``res`` (the positive count when ``res``
                is 0/1 binary).
            args: the same, mutated configuration object.
    """
    args.alpha = 0.25
    args.layer_size = [512, 512]

    def read_matrix(*parts):
        # Read <data_root>/<parts...> with the first column as index; return float32 values.
        return pd.read_csv(os.path.join(data_root, *parts), index_col=0).values.astype(np.float32)

    drug_fingerprints = [read_matrix("GDSC", "ECFP6_fingerprints_GDSC.csv")]
    res = read_matrix("GDSC", "cell_drug_binary.csv")
    exprs = read_matrix("GDSC", "cell_gene", "merged_file_GDSC.csv")
    null_mask = read_matrix("GDSC", "null_mask.csv")
    # Same count that sp.coo_matrix(res).data would hold, without building a sparse copy.
    pos_num = np.count_nonzero(res)

    return res, drug_fingerprints, exprs, null_mask, pos_num, args


def _load_ccle(args, *, data_root="/media/external_16TB_1/ali_kianfar/Data"):
    """
    Load the CCLE cell-line/drug response data with fingerprints and gene expression.

    Matrices are read with the first CSV column as the row index and converted to float32.
    No null-mask file is read for CCLE. An all-zero mask shaped like the response matrix
    is returned instead. Rows are taken positionally. The response and expression CSVs are
    assumed to list cell lines in the same order (NOTE(review): not checked here, confirm).

    Args:
        args: Configuration object. ``args.alpha`` and ``args.layer_size`` are overwritten.
        data_root: Directory containing the ``CCLE`` data folder. Defaults to the original
            hard-coded location, so existing callers are unaffected.

    Returns:
        Tuple ``(res, drug_fingerprints, exprs, null_mask, pos_num, args)``:
            res: float32 matrix from ``CCLE/cell_drug_binary.csv``.
            drug_fingerprints: one-element list holding the float32 matrix from
                ``CCLE/ECFP6_fingerprints.csv``.
            exprs: float32 matrix from ``CCLE/merged_file.csv``.
            null_mask: float32 zeros with the same shape as ``res``.
            pos_num: number of nonzero entries in ``res`` (the positive count when ``res``
                is 0/1 binary).
            args: the same, mutated configuration object.
    """
    args.alpha = 0.45
    args.layer_size = [512, 512]

    def read_matrix(*parts):
        # Read <data_root>/<parts...> with the first column as index; return float32 values.
        return pd.read_csv(os.path.join(data_root, *parts), index_col=0).values.astype(np.float32)

    drug_fingerprints = [read_matrix("CCLE", "ECFP6_fingerprints.csv")]
    res = read_matrix("CCLE", "cell_drug_binary.csv")
    exprs = read_matrix("CCLE", "merged_file.csv")
    null_mask = np.zeros(res.shape, dtype=np.float32)
    # Same count that sp.coo_matrix(res).data would hold, without building a sparse copy.
    pos_num = np.count_nonzero(res)

    return res, drug_fingerprints, exprs, null_mask, pos_num, args


def _load_pdx(args, *, data_root="/media/external_16TB_1/ali_kianfar/Data"):
    """
    Load GDSC cell lines stacked above PDX samples, with expression limited to shared genes.

    GDSC rows come first, so rows ``[0, train_row)`` of ``res``, ``exprs`` and ``null_mask``
    are GDSC and the remaining rows are PDX. Everything is concatenated positionally, which
    relies on two unchecked assumptions (NOTE(review): confirm against the data preparation):
    - The GDSC and PDX response and mask CSVs list drugs in the same column order.
    - Within each dataset, the response, expression and mask CSVs list samples in the same
      row order.

    Args:
        args: Configuration object. ``args.alpha`` and ``args.layer_size`` are overwritten.
        data_root: Directory containing the ``GDSC`` and ``PDX`` data folders. Defaults to the
            original hard-coded location, so existing callers are unaffected.

    Returns:
        Tuple ``(res, drug_finger, exprs, null_mask, train_row, args)``:
            res: float32 GDSC and PDX response matrices stacked row-wise.
            drug_finger: one-element list holding the float32 GDSC ECFP6 fingerprint matrix.
            exprs: float32 GDSC and PDX expression stacked row-wise, restricted to the gene
                columns present in both tables.
            null_mask: float32 GDSC and PDX null masks stacked row-wise.
            train_row: number of GDSC rows.
            args: the same, mutated configuration object.

    Raises:
        ValueError: If the GDSC and PDX expression tables share no gene columns.
    """
    args.alpha = 0.15
    args.layer_size = [1024, 1024]

    def load_frame(*parts):
        # Read <data_root>/<parts...> with the first column as the row index.
        return pd.read_csv(os.path.join(data_root, *parts), index_col=0)

    def read_matrix(*parts):
        return load_frame(*parts).values.astype(np.float32)

    # Response matrices: GDSC on top, PDX below.
    gdsc_res = read_matrix("GDSC", "cell_drug_binary.csv")
    pdx_res = read_matrix("PDX", "pdx_response.csv")
    res = np.concatenate((gdsc_res, pdx_res), axis=0)
    train_row = gdsc_res.shape[0]

    drug_finger = [read_matrix("GDSC", "ECFP6_fingerprints_GDSC.csv")]

    # Align expression by gene name. Only genes measured in both datasets are kept.
    gdsc_exprs_df = load_frame("GDSC", "cell_gene", "merged_file_GDSC.csv")
    pdx_exprs_df = load_frame("PDX", "pdx_exprs.csv")
    common_genes = gdsc_exprs_df.columns.intersection(pdx_exprs_df.columns)
    if common_genes.empty:
        # An empty overlap would silently yield a zero-feature expression matrix.
        raise ValueError("GDSC and PDX expression tables share no gene columns.")
    gdsc_exprs_filtered = gdsc_exprs_df[common_genes].values.astype(np.float32)
    pdx_exprs_filtered = pdx_exprs_df[common_genes].values.astype(np.float32)
    exprs = np.concatenate((gdsc_exprs_filtered, pdx_exprs_filtered), axis=0)

    # Null masks stacked in the same row order as the responses.
    gdsc_null_mask = read_matrix("GDSC", "null_mask.csv")
    pdx_null_mask = read_matrix("PDX", "pdx_null_mask.csv")
    null_mask = np.concatenate((gdsc_null_mask, pdx_null_mask), axis=0)

    return res, drug_finger, exprs, null_mask, train_row, args


def _load_tcga(args, *, data_root="/media/external_16TB_1/ali_kianfar/Data"):
    """
    Load GDSC cell lines stacked above TCGA patients, with expression limited to shared genes.

    GDSC rows come first, so rows ``[0, train_row)`` of ``res``, ``exprs`` and ``null_mask``
    are GDSC and the remaining rows are TCGA. Everything is concatenated positionally, which
    relies on two unchecked assumptions (NOTE(review): confirm against the data preparation):
    - The GDSC and TCGA response and mask CSVs list drugs in the same column order.
    - Within each dataset, the response, expression and mask CSVs list samples in the same
      row order.

    Args:
        args: Configuration object. ``args.alpha`` and ``args.layer_size`` are overwritten.
        data_root: Directory containing the ``GDSC`` and ``TCGA`` data folders. Defaults to
            the original hard-coded location, so existing callers are unaffected.

    Returns:
        Tuple ``(res, drug_finger, exprs, null_mask, train_row, args)``:
            res: float32 GDSC and TCGA response matrices stacked row-wise.
            drug_finger: one-element list holding the float32 GDSC ECFP6 fingerprint matrix.
            exprs: float32 GDSC and TCGA expression stacked row-wise, restricted to the gene
                columns present in both tables.
            null_mask: float32 GDSC and TCGA null masks stacked row-wise.
            train_row: number of GDSC rows.
            args: the same, mutated configuration object.

    Raises:
        ValueError: If the GDSC and TCGA expression tables share no gene columns.
    """
    args.alpha = 0.1
    args.layer_size = [1024, 1024]

    def load_frame(*parts):
        # Read <data_root>/<parts...> with the first column as the row index.
        return pd.read_csv(os.path.join(data_root, *parts), index_col=0)

    def read_matrix(*parts):
        return load_frame(*parts).values.astype(np.float32)

    # Response matrices: GDSC on top, TCGA below.
    gdsc_res = read_matrix("GDSC", "cell_drug_binary.csv")
    tcga_res = read_matrix("TCGA", "patient_drug_binary.csv")
    res = np.concatenate((gdsc_res, tcga_res), axis=0)
    train_row = gdsc_res.shape[0]

    drug_finger = [read_matrix("GDSC", "ECFP6_fingerprints_GDSC.csv")]

    # Align expression by gene name. Only genes measured in both datasets are kept.
    gdsc_exprs = load_frame("GDSC", "cell_gene", "merged_file_GDSC.csv")
    patient_gene = load_frame("TCGA", "tcga_gene_exprs.csv")
    common_genes = gdsc_exprs.columns.intersection(patient_gene.columns)
    if common_genes.empty:
        # An empty overlap would silently yield a zero-feature expression matrix.
        raise ValueError("GDSC and TCGA expression tables share no gene columns.")
    gdsc_exprs_filtered = gdsc_exprs[common_genes].values.astype(np.float32)
    tcga_exprs_filtered = patient_gene[common_genes].values.astype(np.float32)
    exprs = np.concatenate((gdsc_exprs_filtered, tcga_exprs_filtered), axis=0)

    # Null masks stacked in the same row order as the responses.
    gdsc_null_mask = read_matrix("GDSC", "null_mask.csv")
    tcga_null_mask = read_matrix("TCGA", "null_mask.csv")
    null_mask = np.concatenate((gdsc_null_mask, tcga_null_mask), axis=0)

    return res, drug_finger, exprs, null_mask, train_row, args