import os import time from typing import Tuple, List, Union, Optional import torch import numpy as np import pandas as pd import torch.nn as nn import seaborn as sns import pubchempy as pcp import scipy.sparse as sp from sklearn.metrics import roc_auc_score, average_precision_score import itertools as it import torch.nn.functional as F # ---------------------------------------------------------------------------- # Model Evaluation Functions # ---------------------------------------------------------------------------- def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float: """Calculate ROC-AUC score for binary classification.""" assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1" return roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy()) def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float: """Calculate Average Precision (area under Precision-Recall curve).""" assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1" return average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy()) def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]: """Calculate F1 score and the optimal threshold for binary classification.""" thresholds = torch.unique(predict_data) n_samples = true_data.size(0) ones = torch.ones((thresholds.size(0), n_samples), device=true_data.device) zeros = torch.zeros((thresholds.size(0), n_samples), device=true_data.device) predict_value = torch.where(predict_data.view(1, -1) >= thresholds.view(-1, 1), ones, zeros) tpn = torch.sum(torch.where(predict_value == true_data.view(1, -1), ones, zeros), dim=1) tp = torch.sum(predict_value * true_data.view(1, -1), dim=1) scores = (2 * tp) / (n_samples + 2 * tp - tpn) max_f1_score = torch.max(scores) threshold = thresholds[torch.argmax(scores)] return max_f1_score.item(), threshold.item() def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float: """Calculate accuracy using the specified threshold.""" predict_value = torch.where(predict_data >= threshold, 1.0, 0.0) correct = torch.sum(predict_value == true_data).float() return (correct / true_data.size(0)).item() def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float: """Calculate precision using the specified threshold.""" predict_value = torch.where(predict_data >= threshold, 1.0, 0.0) tp = torch.sum(true_data * predict_value) fp = torch.sum((1 - true_data) * predict_value) return (tp / (tp + fp + 1e-8)).item() def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float: """Calculate recall using the specified threshold.""" predict_value = torch.where(predict_data >= threshold, 1.0, 0.0) tp = torch.sum(true_data * predict_value) fn = torch.sum(true_data * (1 - predict_value)) return (tp / (tp + fn + 1e-8)).item() def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float: """Calculate Matthews Correlation Coefficient (MCC) using the specified threshold.""" predict_value = torch.where(predict_data >= threshold, 1.0, 0.0) true_neg = 1 - true_data predict_neg = 1 - predict_value tp = torch.sum(true_data * predict_value) tn = torch.sum(true_neg * predict_neg) fp = torch.sum(true_neg * predict_value) fn = torch.sum(true_data * predict_neg) denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8) return ((tp * tn - fp * fn) / denominator).item() def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]: """Evaluate multiple metrics: ROC-AUC, AP, accuracy, F1, and MCC.""" assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1" auc = roc_auc(true_data, predict_data) ap = ap_score(true_data, predict_data) f1, threshold = f1_score_binary(true_data, predict_data) acc = accuracy_binary(true_data, predict_data, threshold) precision = precision_binary(true_data, predict_data, threshold) recall = recall_binary(true_data, predict_data, threshold) mcc = mcc_binary(true_data, predict_data, threshold) return auc, ap, acc, f1, mcc, precision, recall def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]: """Calculate ROC-AUC and Average Precision.""" assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1" auc = roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy()) ap = average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy()) return auc, ap # ---------------------------------------------------------------------------- # Loss Functions # ---------------------------------------------------------------------------- def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor: """Calculate masked binary cross-entropy loss.""" masked = masked.to(torch.bool) true_data = torch.masked_select(true_data, masked) pred_data = torch.masked_select(predict_data, masked) return nn.BCELoss()(pred_data, true_data) def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor: """Calculate masked mean squared error loss.""" true_data = true_data * masked predict_data = predict_data * masked return nn.MSELoss()(predict_data, true_data) def prototypical_loss( cell_emb: torch.Tensor, drug_emb: torch.Tensor, adj_matrix: Union[torch.Tensor, np.ndarray], margin: float = 2.0 ) -> torch.Tensor: """Calculate prototypical loss for positive and negative pairs.""" if isinstance(adj_matrix, torch.Tensor): adj_matrix = sp.coo_matrix(adj_matrix.detach().cpu().numpy()) pos_pairs = torch.sum(cell_emb[adj_matrix.row] * drug_emb[adj_matrix.col], dim=1) n_pos = len(adj_matrix.row) cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device) drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device) neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1) labels = torch.ones_like(pos_pairs, device=cell_emb.device) return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin) # ---------------------------------------------------------------------------- # Correlation and Normalization Functions # ---------------------------------------------------------------------------- def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: """Apply z-normalization along the specified dimension.""" mean = tensor.mean(dim=1 - dim, keepdim=True) std = tensor.std(dim=1 - dim, keepdim=True) + 1e-8 return (tensor - mean) / std def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Compute correlation matrix between row vectors of two matrices.""" x_center = x - x.mean(dim=1, keepdim=True) y_center = y - y.mean(dim=1, keepdim=True) x_std = x.std(dim=1, keepdim=True) + 1e-8 y_std = y.std(dim=1, keepdim=True) + 1e-8 x_norm = x_center / x_std y_norm = y_center / y_std corr_matrix = x_norm @ y_norm.t() / (x.size(1) - 1) return corr_matrix # ---------------------------------------------------------------------------- # Distance and Similarity Functions # ---------------------------------------------------------------------------- def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: """Calculate Euclidean distance between rows or columns of a tensor.""" tensor_mul = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t()) diag = torch.diag(tensor_mul) n_diag = diag.size(0) tensor_diag = diag.repeat(n_diag, 1) diag = diag.view(n_diag, -1) dist = torch.sqrt(tensor_diag + diag - 2 * tensor_mul) return dist def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor: """Calculate exponential similarity based on Euclidean distance.""" if normalize: tensor = torch_z_normalized(tensor, dim=1) tensor_dist = torch_euclidean_dist(tensor, dim=0) return torch.exp(-tensor_dist / (2 * sigma.pow(2))) def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor: """Calculate full kernel matrix from exponential similarity.""" n = exp_dist.shape[0] ones = torch.ones(n, n, device=exp_dist.device) diag = torch.diag(ones) mask_diag = (ones - diag) * exp_dist mask_diag_sum = mask_diag.sum(dim=1, keepdim=True) mask_diag = mask_diag / (2 * mask_diag_sum) + 0.5 * diag return mask_diag def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor: """Calculate sparse kernel using k-nearest neighbors.""" n = exp_dist.shape[0] maxk = torch.topk(exp_dist, k, dim=1) mink_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices exp_dist[torch.arange(n, device=exp_dist.device).view(n, -1), mink_indices] = 0 knn_sum = maxk.values.sum(dim=1, keepdim=True) return exp_dist / knn_sum def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor: """Apply scaled sigmoid transformation.""" alpha = torch.tensor(alpha, dtype=torch.float32, device=tensor.device) return torch.sigmoid(alpha * tensor) # ---------------------------------------------------------------------------- # Data Processing and Helper Functions # ---------------------------------------------------------------------------- def init_seeds(seed: int = 0) -> None: """Initialize random seeds for reproducibility.""" np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def distribute_compute( lr_list: List[float], wd_list: List[float], scale_list: List[float], layer_size: List[int], sigma_list: List[float], beta_list: List[float], workers: int, id: int ) -> np.ndarray: """Distribute hyperparameter combinations across workers.""" all_combinations = [ [lr, wd, sc, la, sg, bt] for lr, wd, sc, la, sg, bt in it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list) ] return np.array_split(all_combinations, workers)[id] def get_fingerprint(cid: int) -> np.ndarray: """Retrieve PubChem fingerprint for a given compound CID.""" compound = pcp.Compound.from_cid(cid) fingerprint = "".join(f"{int(bit, 16):04b}" for bit in compound.fingerprint) return np.array([int(b) for b in fingerprint], dtype=np.int32) def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None: """Save fingerprints for a list of compound CIDs to disk.""" start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1 if last_cid > 0 else 0 for cid in cid_list[start_idx:]: fingerprint = get_fingerprint(cid) np.save(os.path.join(fpath, str(cid)), fingerprint) print(f"CID {cid} processed successfully.") time.sleep(1) if start_idx >= len(cid_list): print("All compounds have been processed!") def read_fingerprint_cid(path: str) -> Tuple[np.ndarray, List[int]]: """Read fingerprints from .npy files in the specified directory.""" fingerprint = [] cids = [] for file_name in sorted(os.listdir(path)): if file_name.endswith(".npy"): cid = int(file_name.split(".")[0]) fing = np.load(os.path.join(path, file_name)) fingerprint.append(fing) cids.append(cid) fingerprint = np.array(fingerprint).reshape(-1, 920) return fingerprint, cids def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray: """Find indices of elements in data_for_index that exist in data_for_cmp.""" return np.where(np.isin(data_for_index, data_for_cmp))[0] def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix: """Convert input matrix to scipy.sparse.coo_matrix format.""" if not sp.isspmatrix_coo(adj_mat): adj_mat = sp.coo_matrix(adj_mat) return adj_mat def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor: """Create a mask combining positive and negative edges.""" row = np.hstack((positive.row, negative.row)) col = np.hstack((positive.col, negative.col)) data = np.ones_like(row) masked = sp.coo_matrix((data, (row, col)), shape=positive.shape).toarray().astype(dtype) return torch.from_numpy(masked) def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor: """Convert sparse matrix to torch.Tensor, optionally adding identity matrix.""" data = positive + sp.identity(positive.shape[0]) if identity else positive return torch.from_numpy(data.toarray()).float() def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray: """Remove specified values from a NumPy array.""" indices = [np.where(arr == x)[0][0] for x in obj if x in arr] return np.delete(arr, indices) def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame: """Convert tensor or array to a pandas DataFrame.""" if isinstance(tensor, torch.Tensor): tensor = tensor.detach().cpu().numpy() return pd.DataFrame(tensor.reshape(1, -1)) def calculate_train_test_index( response: np.ndarray, pos_train_index: np.ndarray, pos_test_index: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """Calculate train and test indices combining positive and negative samples.""" neg_response_index = np.where(response == 0)[0] neg_test_index = np.random.choice(neg_response_index, pos_test_index.shape[0], replace=False) neg_train_index = np_delete_value(neg_response_index, neg_test_index) test_index = np.hstack((pos_test_index, neg_test_index)) train_index = np.hstack((pos_train_index, neg_train_index)) return train_index, test_index def dir_path(k: int = 1) -> str: """Get directory path by traversing k levels up from current file.""" fpath = os.path.realpath(__file__) dir_name = os.path.dirname(fpath).replace("\\", "/") for _ in range(k): dir_name = os.path.dirname(dir_name) return dir_name def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray: """Extract non-NaN data from a specific row of a DataFrame.""" target = np.array(data.iloc[row], dtype=np.float32) return target[~np.isnan(target)] def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame: """Add a label column to a DataFrame.""" data = data.copy() data["label"] = label return data def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame: """Concatenate multiple DataFrames vertically.""" return pd.concat(data, ignore_index=True) def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]: """Calculate min and max values of a key across multiple DataFrames.""" temp = pd.concat(data, ignore_index=True) return temp[key].min() - 0.1, temp[key].max() + 0.1 def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str: """Remove all occurrences of a substring and join with specified string.""" parts = string.split(sub) parts = [p for p in parts if p] return join_str.join(parts) def get_best_index(fname: str) -> int: """Find the index of the AUC closest to the average AUC from a results file.""" with open(fname, "r") as file: content = file.read().replace("\n", "") auc_str = content.split("accs")[0].split(":")[1] auc_str = delete_all_sub_str(auc_str, " ", ",").replace(",]", "]") aucs = np.array(eval(auc_str)) avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0]) return np.argmin(np.abs(aucs - avg_auc)) def gather_color_code(*string: str) -> List[Tuple[float, float, float]]: """Map color names to seaborn color palette codes.""" color_str = ["bluea0"] = ["blue", "orange", "green", "red", "purple", "brown", "pink", "grey", "yellow", "cyan"] palette = sns.color_palette() color_map = dict(zip(color_str, palette)) return [color_map[color] for color in string]