123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- import os
- import time
- from typing import Tuple, List, Union, Optional
- import torch
- import numpy as np
- import pandas as pd
- import torch.nn as nn
- import seaborn as sns
- import pubchempy as pcp
- import scipy.sparse as sp
- from sklearn.metrics import roc_auc_score, average_precision_score
- import itertools as it
- import torch.nn.functional as F
-
-
- # ----------------------------------------------------------------------------
- # Model Evaluation Functions
- # ----------------------------------------------------------------------------
-
def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Return the area under the ROC curve for binary labels.

    Args:
        true_data: Ground-truth labels. The sanity check accepts any value in
            the closed interval [0, 1], although the message speaks of 0/1.
        predict_data: Predicted scores, one per label.

    Returns:
        The value computed by ``sklearn.metrics.roc_auc_score``.

    Raises:
        AssertionError: If a label lies outside [0, 1].
    """
    labels_in_range = torch.all((true_data >= 0) & (true_data <= 1))
    assert labels_in_range, "True labels must be 0 or 1"
    # scikit-learn works on host-side NumPy arrays.
    y_true = true_data.cpu().numpy()
    y_score = predict_data.cpu().numpy()
    return roc_auc_score(y_true, y_score)
-
-
def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Return the average precision (area under the precision-recall curve).

    Args:
        true_data: Ground-truth labels. The sanity check accepts any value in
            the closed interval [0, 1], although the message speaks of 0/1.
        predict_data: Predicted scores, one per label.

    Returns:
        The value computed by ``sklearn.metrics.average_precision_score``.

    Raises:
        AssertionError: If a label lies outside [0, 1].
    """
    labels_in_range = torch.all((true_data >= 0) & (true_data <= 1))
    assert labels_in_range, "True labels must be 0 or 1"
    # scikit-learn works on host-side NumPy arrays.
    y_true = true_data.cpu().numpy()
    y_score = predict_data.cpu().numpy()
    return average_precision_score(y_true, y_score)
-
-
def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Return the best achievable F1 score and the threshold that yields it.

    Every distinct value of ``predict_data`` is tried as a cut-off; a sample
    counts as predicted-positive when its score is ``>=`` the cut-off. All
    cut-offs are evaluated at once on an (n_thresholds, n_samples) grid, so
    memory grows with the product of the two.

    Args:
        true_data: Ground-truth labels (flattened internally).
        predict_data: Predicted scores (flattened internally).

    Returns:
        ``(max_f1, threshold)`` as Python floats.
    """
    candidates = torch.unique(predict_data)
    n_samples = true_data.size(0)
    labels_row = true_data.view(1, -1)
    grid_dtype = torch.get_default_dtype()

    # Row i holds the 0/1 predictions obtained with candidates[i].
    above_cutoff = predict_data.view(1, -1) >= candidates.view(-1, 1)
    predicted = above_cutoff.to(grid_dtype)

    # Number of samples where prediction and label agree: TP + TN.
    agreements = (predicted == labels_row).to(grid_dtype).sum(dim=1)
    true_pos = (predicted * labels_row).sum(dim=1)

    # n + 2*TP - (TP + TN) equals 2*TP + FP + FN, the usual F1 denominator.
    f1_per_candidate = (2 * true_pos) / (n_samples + 2 * true_pos - agreements)

    best = torch.argmax(f1_per_candidate)
    return f1_per_candidate[best].item(), candidates[best].item()
-
-
def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Return the share of samples whose thresholded prediction equals the label.

    Args:
        true_data: Ground-truth labels.
        predict_data: Predicted scores.
        threshold: Scores ``>= threshold`` are mapped to 1.0, the rest to 0.0.

    Returns:
        Number of matches divided by ``true_data.size(0)``, as a Python float.
    """
    hard_labels = (predict_data >= threshold).to(torch.get_default_dtype())
    n_correct = (hard_labels == true_data).sum().float()
    return (n_correct / true_data.size(0)).item()
-
-
def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Return TP / (TP + FP) for predictions binarised at ``threshold``.

    Args:
        true_data: Ground-truth labels.
        predict_data: Predicted scores.
        threshold: Scores ``>= threshold`` are treated as positive calls.

    Returns:
        The precision as a Python float. A small epsilon (1e-8) in the
        denominator keeps the result finite when nothing is called positive.
    """
    called_positive = (predict_data >= threshold).to(torch.get_default_dtype())
    true_pos = (true_data * called_positive).sum()
    false_pos = ((1 - true_data) * called_positive).sum()
    ratio = true_pos / (true_pos + false_pos + 1e-8)
    return ratio.item()
-
-
def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Return TP / (TP + FN) for predictions binarised at ``threshold``.

    Args:
        true_data: Ground-truth labels.
        predict_data: Predicted scores.
        threshold: Scores ``>= threshold`` are treated as positive calls.

    Returns:
        The recall as a Python float. A small epsilon (1e-8) in the
        denominator keeps the result finite when there are no positives.
    """
    called_positive = (predict_data >= threshold).to(torch.get_default_dtype())
    true_pos = (true_data * called_positive).sum()
    false_neg = (true_data * (1 - called_positive)).sum()
    ratio = true_pos / (true_pos + false_neg + 1e-8)
    return ratio.item()
-
-
def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Return the Matthews correlation coefficient at the given threshold.

    Args:
        true_data: Ground-truth labels.
        predict_data: Predicted scores.
        threshold: Scores ``>= threshold`` are treated as positive calls.

    Returns:
        ``(TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN) + 1e-8)`` as a
        Python float; the epsilon avoids a division by zero.
    """
    pred_pos = (predict_data >= threshold).to(torch.get_default_dtype())
    pred_neg = 1 - pred_pos
    actual_neg = 1 - true_data

    # Confusion-matrix cells.
    tp = (true_data * pred_pos).sum()
    tn = (actual_neg * pred_neg).sum()
    fp = (actual_neg * pred_pos).sum()
    fn = (true_data * pred_neg).sum()

    numerator = tp * tn - fp * fn
    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
    return (numerator / denominator).item()
-
-
def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]:
    """Compute the full set of binary-classification metrics in one call.

    The threshold-dependent metrics are all evaluated at the threshold that
    maximises F1 (as returned by ``f1_score_binary``).

    Args:
        true_data: Ground-truth labels; every value must lie in [0, 1].
        predict_data: Predicted scores.

    Returns:
        A 7-tuple ``(auc, ap, acc, f1, mcc, precision, recall)``.

    Raises:
        AssertionError: If a label lies outside [0, 1].
    """
    labels_in_range = torch.all((true_data >= 0) & (true_data <= 1))
    assert labels_in_range, "True labels must be 0 or 1"

    # Threshold-free ranking metrics.
    auc = roc_auc(true_data, predict_data)
    ap = ap_score(true_data, predict_data)

    # The F1-optimal threshold drives every remaining metric.
    f1, threshold = f1_score_binary(true_data, predict_data)
    acc, precision, recall, mcc = (
        metric(true_data, predict_data, threshold)
        for metric in (accuracy_binary, precision_binary, recall_binary, mcc_binary)
    )
    return auc, ap, acc, f1, mcc, precision, recall
-
-
def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Return ROC-AUC and average precision for the given predictions.

    Args:
        true_data: Ground-truth labels; every value must lie in [0, 1].
        predict_data: Predicted scores.

    Returns:
        ``(auc, ap)`` as computed by scikit-learn.

    Raises:
        AssertionError: If a label lies outside [0, 1].
    """
    labels_in_range = torch.all((true_data >= 0) & (true_data <= 1))
    assert labels_in_range, "True labels must be 0 or 1"
    # Move to host memory once and reuse the arrays for both metrics.
    y_true = true_data.cpu().numpy()
    y_score = predict_data.cpu().numpy()
    return roc_auc_score(y_true, y_score), average_precision_score(y_true, y_score)
-
-
- # ----------------------------------------------------------------------------
- # Loss Functions
- # ----------------------------------------------------------------------------
-
def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy restricted to the entries selected by ``masked``.

    Args:
        true_data: Target values.
        predict_data: Predicted probabilities (``nn.BCELoss`` expects values
            in [0, 1]).
        masked: Mask tensor; it is cast to bool, and only entries where it is
            true contribute to the loss.

    Returns:
        The mean BCE over the selected entries (scalar tensor).
    """
    keep = masked.to(torch.bool)
    selected_pred = predict_data.masked_select(keep)
    selected_true = true_data.masked_select(keep)
    criterion = nn.BCELoss()
    return criterion(selected_pred, selected_true)
-
-
def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Mean squared error with masked-out entries zeroed on both sides.

    Both inputs are multiplied by ``masked`` before the loss is taken, so
    masked-out positions contribute zero error. The mean is still taken over
    *all* elements, not only over the unmasked ones.

    Args:
        true_data: Target values.
        predict_data: Predicted values.
        masked: Multiplicative mask applied to both tensors.

    Returns:
        The MSE as a scalar tensor.
    """
    criterion = nn.MSELoss()
    return criterion(predict_data * masked, true_data * masked)
-
-
def prototypical_loss(
    cell_emb: torch.Tensor,
    drug_emb: torch.Tensor,
    adj_matrix: Union[torch.Tensor, np.ndarray],
    margin: float = 2.0
) -> torch.Tensor:
    """Margin ranking loss between observed pairs and randomly drawn pairs.

    Each non-zero entry ``(i, j)`` of ``adj_matrix`` is a positive pair scored
    by the dot product ``cell_emb[i] . drug_emb[j]``. An equal number of
    (cell, drug) index pairs is drawn uniformly at random as negatives; they
    are not filtered, so a random pair may coincide with a positive one.

    Args:
        cell_emb: Cell embeddings, indexed by the rows of ``adj_matrix``.
        drug_emb: Drug embeddings, indexed by the columns of ``adj_matrix``.
        adj_matrix: Adjacency information as a tensor, a dense NumPy array or
            a SciPy sparse matrix. Anything that is not already in COO format
            is converted.
        margin: Margin passed to ``F.margin_ranking_loss``.

    Returns:
        Scalar loss tensor.
    """
    if isinstance(adj_matrix, torch.Tensor):
        adj_matrix = adj_matrix.detach().cpu().numpy()
    # Dense arrays and non-COO sparse formats expose no .row/.col attributes,
    # so normalise everything to COO before reading the coordinates.
    if not (sp.issparse(adj_matrix) and adj_matrix.format == "coo"):
        adj_matrix = sp.coo_matrix(adj_matrix)

    # Explicit long tensors on the right device make the indexing independent
    # of the NumPy index dtype SciPy happens to use.
    rows = torch.as_tensor(adj_matrix.row, dtype=torch.long, device=cell_emb.device)
    cols = torch.as_tensor(adj_matrix.col, dtype=torch.long, device=drug_emb.device)
    pos_pairs = torch.sum(cell_emb[rows] * drug_emb[cols], dim=1)
    n_pos = rows.numel()

    # One random (cell, drug) pair per positive pair.
    cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device)
    drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device)
    neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1)

    # Target +1: positive scores should exceed negative scores by `margin`.
    labels = torch.ones_like(pos_pairs, device=cell_emb.device)
    return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin)
-
-
- # ----------------------------------------------------------------------------
- # Correlation and Normalization Functions
- # ----------------------------------------------------------------------------
-
def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Standardise a 2-D tensor to zero mean and unit standard deviation.

    Note the axis convention: statistics are reduced over axis ``1 - dim``.
    With ``dim=0`` every row is standardised across its columns; with
    ``dim=1`` every column is standardised across its rows.

    Args:
        tensor: Input tensor (the ``1 - dim`` arithmetic assumes two axes).
        dim: 0 or 1, see above.

    Returns:
        ``(tensor - mean) / (std + 1e-8)`` with the same shape as the input.
    """
    reduce_axis = 1 - dim
    centered = tensor - tensor.mean(dim=reduce_axis, keepdim=True)
    # Epsilon guards against a zero standard deviation.
    spread = tensor.std(dim=reduce_axis, keepdim=True) + 1e-8
    return centered / spread
-
-
def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Correlation between every row of ``x`` and every row of ``y``.

    Each row is centred and divided by its standard deviation (plus 1e-8);
    the standardised rows are then multiplied and divided by
    ``x.size(1) - 1``.

    Args:
        x: Matrix whose rows are the first set of vectors.
        y: Matrix whose rows are the second set of vectors; it needs the same
            number of columns as ``x`` for the matrix product.

    Returns:
        Matrix with one row per row of ``x`` and one column per row of ``y``.
    """
    def standardize_rows(matrix: torch.Tensor) -> torch.Tensor:
        row_mean = matrix.mean(dim=1, keepdim=True)
        row_std = matrix.std(dim=1, keepdim=True) + 1e-8
        return (matrix - row_mean) / row_std

    degrees_of_freedom = x.size(1) - 1
    return standardize_rows(x) @ standardize_rows(y).t() / degrees_of_freedom
-
-
- # ----------------------------------------------------------------------------
- # Distance and Similarity Functions
- # ----------------------------------------------------------------------------
-
def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Pairwise Euclidean distances between the rows or columns of a matrix.

    Uses the identity ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b`` on the Gram
    matrix.

    Args:
        tensor: 2-D input matrix.
        dim: Falsy (default 0) compares rows; truthy compares columns.

    Returns:
        Square matrix of distances.
    """
    gram = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t())
    sq_norms = torch.diag(gram)
    sq_dist = sq_norms.view(1, -1) + sq_norms.view(-1, 1) - 2 * gram
    # Rounding can push values that should be 0 slightly below zero, which
    # would turn into NaN under sqrt; clamp them back to 0 first.
    return torch.sqrt(sq_dist.clamp(min=0))
-
-
def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Exponential similarity ``exp(-d / (2 * sigma**2))`` between rows.

    ``d`` is the pairwise Euclidean distance (not its square) between the
    rows of ``tensor``.

    Args:
        tensor: 2-D feature matrix.
        sigma: Bandwidth tensor.
        normalize: When true, the input is first passed through
            ``torch_z_normalized(tensor, dim=1)``.

    Returns:
        Similarity matrix with one row and one column per row of ``tensor``.
    """
    features = torch_z_normalized(tensor, dim=1) if normalize else tensor
    distances = torch_euclidean_dist(features, dim=0)
    bandwidth = 2 * sigma.pow(2)
    return torch.exp(-distances / bandwidth)
-
-
def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor:
    """Build a row-normalised kernel from a square similarity matrix.

    The diagonal is fixed at 0.5 and the off-diagonal entries of each row are
    rescaled so that they sum to 0.5, which makes every row sum to 1.

    Args:
        exp_dist: Square similarity matrix.

    Returns:
        Kernel matrix of the same shape. A row whose off-diagonal entries sum
        to zero produces NaN (division by zero).
    """
    n = exp_dist.shape[0]
    # torch.diag() of a 2-D tensor returns the 1-D diagonal rather than an
    # identity matrix, so the identity has to be built with torch.eye().
    eye = torch.eye(n, device=exp_dist.device)
    off_diag = (1 - eye) * exp_dist
    off_diag_sum = off_diag.sum(dim=1, keepdim=True)
    return off_diag / (2 * off_diag_sum) + 0.5 * eye
-
-
def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the ``k`` largest entries of each row and normalise them to sum 1.

    Args:
        exp_dist: Square similarity matrix. It is left unmodified.
        k: Number of entries to keep per row.

    Returns:
        New tensor in which, per row, the ``n - k`` smallest entries are zero
        and the rest are divided by the sum of the ``k`` largest values.
    """
    n = exp_dist.shape[0]
    top_k = torch.topk(exp_dist, k, dim=1)
    drop_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices
    # Out-of-place scatter: zeroing entries in the caller's tensor would
    # corrupt their data and break autograd when the input is the saved
    # output of an op such as torch.exp.
    kept = exp_dist.scatter(1, drop_indices, 0)
    knn_sum = top_k.values.sum(dim=1, keepdim=True)
    return kept / knn_sum
-
-
def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return ``sigmoid(alpha * tensor)``.

    Args:
        tensor: Input values.
        alpha: Scale factor; it is materialised as a float32 scalar tensor on
            the same device as ``tensor`` before the multiplication.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    scale = torch.tensor(alpha, dtype=torch.float32, device=tensor.device)
    return (scale * tensor).sigmoid()
-
-
- # ----------------------------------------------------------------------------
- # Data Processing and Helper Functions
- # ----------------------------------------------------------------------------
-
def init_seeds(seed: int = 0) -> None:
    """Seed NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic.

    Python's built-in ``random`` module is not touched.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic kernels, and no auto-tuner picking algorithms per run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
-
-
def distribute_compute(
    lr_list: List[float],
    wd_list: List[float],
    scale_list: List[float],
    layer_size: List[int],
    sigma_list: List[float],
    beta_list: List[float],
    workers: int,
    id: int
) -> np.ndarray:
    """Return the share of the hyperparameter grid assigned to one worker.

    The Cartesian product of the six value lists is split into ``workers``
    nearly equal consecutive chunks with ``np.array_split``.

    Args:
        lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list:
            Candidate values; one grid point takes one value from each list,
            in this order.
        workers: Number of chunks to split the grid into.
        id: Index of the chunk to return.

    Returns:
        Array with one row per grid point belonging to chunk ``id``.
    """
    grid = list(it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list))
    chunks = np.array_split(grid, workers)
    return chunks[id]
-
-
def get_fingerprint(cid: int) -> np.ndarray:
    """Fetch a compound's PubChem fingerprint and expand it into a bit vector.

    The ``fingerprint`` attribute of the compound is read as a hexadecimal
    string; every character is expanded into its four binary digits.

    Args:
        cid: PubChem compound identifier passed to ``pcp.Compound.from_cid``.

    Returns:
        1-D ``int32`` array of 0/1 values, four entries per hex character.
    """
    hex_digits = pcp.Compound.from_cid(cid).fingerprint
    bits = []
    for digit in hex_digits:
        nibble = format(int(digit, 16), "04b")
        bits.extend(int(bit) for bit in nibble)
    return np.array(bits, dtype=np.int32)
-
-
def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None:
    """Download fingerprints and store each as ``<fpath>/<cid>.npy``.

    Args:
        cid_list: Compound IDs in processing order.
        last_cid: Last CID that was already handled. Processing resumes with
            the entry after its first occurrence in ``cid_list``; a value
            ``<= 0`` starts from the beginning.
        fpath: Output directory.

    Raises:
        IndexError: If ``last_cid > 0`` but does not occur in ``cid_list``.
    """
    if last_cid > 0:
        # Resume right behind the position of the last finished compound.
        start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1
    else:
        start_idx = 0

    if start_idx >= len(cid_list):
        print("All compounds have been processed!")
        return

    for cid in cid_list[start_idx:]:
        np.save(os.path.join(fpath, str(cid)), get_fingerprint(cid))
        print(f"CID {cid} processed successfully.")
        # Pause between requests.
        time.sleep(1)
-
-
def read_fingerprint_cid(path: str, n_bits: int = 920) -> Tuple[np.ndarray, List[int]]:
    """Load every ``<cid>.npy`` fingerprint found in a directory.

    Files are visited in lexicographic file-name order (so ``10.npy`` comes
    before ``2.npy``); fingerprints and CIDs are returned in that same order.
    Entries that do not end in ``.npy`` are ignored.

    Args:
        path: Directory containing the ``.npy`` files.
        n_bits: Length of one fingerprint, used to shape the result. Defaults
            to 920, the width that was previously hard-coded.

    Returns:
        ``(fingerprints, cids)`` where ``fingerprints`` has shape
        ``(n_files, n_bits)`` and ``cids`` holds the integer parsed from each
        file name.

    Raises:
        ValueError: If a ``.npy`` file name does not start with an integer or
            the loaded data cannot be reshaped to ``(-1, n_bits)``.
    """
    fingerprints = []
    cids = []
    for file_name in sorted(os.listdir(path)):
        if not file_name.endswith(".npy"):
            continue
        cids.append(int(file_name.split(".")[0]))
        fingerprints.append(np.load(os.path.join(path, file_name)))
    return np.array(fingerprints).reshape(-1, n_bits), cids
-
-
def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray:
    """Return the positions in ``data_for_index`` whose value occurs in ``data_for_cmp``.

    Args:
        data_for_index: Array whose positions are reported.
        data_for_cmp: Values to look for.

    Returns:
        Indices (along the first axis) of the matching elements.
    """
    is_shared = np.isin(data_for_index, data_for_cmp)
    return np.nonzero(is_shared)[0]
-
-
def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix:
    """Return ``adj_mat`` as a ``scipy.sparse.coo_matrix``.

    Args:
        adj_mat: Dense array or sparse matrix.

    Returns:
        The input itself when it already is a COO matrix, otherwise a new COO
        matrix built from it.
    """
    if sp.isspmatrix_coo(adj_mat):
        return adj_mat
    return sp.coo_matrix(adj_mat)
-
-
def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor:
    """Build a dense mask marking every positive and every negative edge.

    Args:
        positive: COO matrix whose stored coordinates are positive edges; its
            shape defines the shape of the mask.
        negative: COO matrix whose stored coordinates are negative edges.
        dtype: NumPy dtype of the dense mask before conversion to a tensor.

    Returns:
        Tensor of ``positive.shape``. Each stored coordinate contributes 1;
        a coordinate that appears more than once is summed by SciPy before
        the cast to ``dtype``.
    """
    rows = np.concatenate((positive.row, negative.row))
    cols = np.concatenate((positive.col, negative.col))
    flags = np.ones(rows.shape, dtype=rows.dtype)
    combined = sp.coo_matrix((flags, (rows, cols)), shape=positive.shape)
    dense = combined.toarray().astype(dtype)
    return torch.from_numpy(dense)
-
-
def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor:
    """Convert a sparse matrix into a dense float32 tensor.

    Args:
        positive: Sparse matrix to densify.
        identity: When true, an identity matrix of size ``positive.shape[0]``
            is added first.

    Returns:
        Dense ``torch.float32`` tensor.
    """
    if identity:
        positive = positive + sp.identity(positive.shape[0])
    dense = positive.toarray()
    return torch.from_numpy(dense).float()
-
-
def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray:
    """Remove the given values from an array.

    For every value of ``obj`` that is present in ``arr``, only its *first*
    occurrence in ``arr`` is removed; values that are absent are skipped.

    Args:
        arr: Source array.
        obj: Values to remove.

    Returns:
        A new array without the selected elements.
    """
    first_hits = []
    for value in obj:
        positions = np.where(arr == value)[0]
        if positions.size:
            first_hits.append(positions[0])
    return np.delete(arr, first_hits)
-
-
def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame:
    """Flatten a tensor or array into a single-row DataFrame.

    Args:
        tensor: Torch tensor (detached and moved to the CPU first) or NumPy
            array.

    Returns:
        DataFrame with one row holding all values in flattened order.
    """
    values = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
    single_row = values.reshape(1, -1)
    return pd.DataFrame(single_row)
-
-
def calculate_train_test_index(
    response: np.ndarray,
    pos_train_index: np.ndarray,
    pos_test_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Complete a positive train/test split with negative samples.

    As many negatives as there are positive test samples are drawn at random
    (without replacement) for the test set; all remaining negatives go to the
    training set.

    Args:
        response: Response vector; entries equal to 0 are the negatives.
        pos_train_index: Indices of the positive training samples.
        pos_test_index: Indices of the positive test samples.

    Returns:
        ``(train_index, test_index)``, each listing positives first and
        negatives after them.
    """
    negatives = np.where(response == 0)[0]
    n_test_neg = pos_test_index.shape[0]
    held_out_neg = np.random.choice(negatives, n_test_neg, replace=False)
    remaining_neg = np_delete_value(negatives, held_out_neg)

    train_index = np.hstack((pos_train_index, remaining_neg))
    test_index = np.hstack((pos_test_index, held_out_neg))
    return train_index, test_index
-
-
def dir_path(k: int = 1) -> str:
    """Return an ancestor directory of this source file.

    Args:
        k: Number of levels to climb above the directory that contains this
            file; ``k=0`` returns that directory itself.

    Returns:
        The directory path. Backslashes in the starting directory are
        replaced by forward slashes.
    """
    here = os.path.realpath(__file__)
    current = os.path.dirname(here).replace("\\", "/")
    for _level in range(k):
        current = os.path.dirname(current)
    return current
-
-
def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray:
    """Return the non-NaN values of one DataFrame row as float32.

    Args:
        data: Source frame.
        row: Positional row index (used with ``iloc``).

    Returns:
        1-D ``float32`` array with the NaN entries dropped.
    """
    values = np.asarray(data.iloc[row], dtype=np.float32)
    present = np.logical_not(np.isnan(values))
    return values[present]
-
-
def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame:
    """Return a copy of ``data`` with a constant ``"label"`` column.

    Args:
        data: Source frame; it is not modified.
        label: Value written into every row of the ``"label"`` column.

    Returns:
        New DataFrame carrying the extra (or overwritten) column.
    """
    return data.assign(label=label)
-
-
def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame:
    """Stack the given DataFrames on top of each other.

    Args:
        *data: Frames to concatenate, in order.

    Returns:
        One DataFrame with a fresh 0..n-1 index.
    """
    frames = list(data)
    return pd.concat(frames, ignore_index=True)
-
-
def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]:
    """Return padded lower and upper bounds of one column across several frames.

    Args:
        *data: Frames that all provide the column ``key``.
        key: Column label (keyword-only).

    Returns:
        ``(minimum - 0.1, maximum + 0.1)`` over the concatenated column.
    """
    column = pd.concat(list(data), ignore_index=True)[key]
    lower = column.min() - 0.1
    upper = column.max() + 0.1
    return lower, upper
-
-
def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str:
    """Remove every occurrence of ``sub`` and glue the remaining pieces together.

    Args:
        string: Text to process.
        sub: Separator to strip out.
        join_str: Text inserted between the surviving non-empty pieces.

    Returns:
        The non-empty pieces of ``string.split(sub)`` joined by ``join_str``.
    """
    return join_str.join(filter(None, string.split(sub)))
-
-
def get_best_index(fname: str) -> int:
    """Return the index of the AUC that is closest to the reported average AUC.

    The results file is read as one string with newlines removed. The AUC
    list is the text between the first ``:`` and the first occurrence of
    ``accs``; the average is the first whitespace-delimited token after the
    ``:`` that follows ``avg_aucs``.

    Args:
        fname: Path of the results file.

    Returns:
        Position of the AUC with the smallest absolute difference from the
        average.

    Raises:
        ValueError, SyntaxError: If the AUC list is not a plain Python
            literal.
    """
    # Local import keeps this block self-contained.
    import ast

    with open(fname, "r") as file:
        content = file.read().replace("\n", "")

    auc_str = content.split("accs")[0].split(":")[1]
    # Turn the space-separated list into a comma-separated one and drop the
    # comma that ends up in front of the closing bracket.
    auc_str = ",".join(part for part in auc_str.split(" ") if part).replace(",]", "]")
    # literal_eval only accepts literals, so text in the results file can
    # never be executed as code (eval() was used here before).
    aucs = np.array(ast.literal_eval(auc_str))

    avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0])
    return np.argmin(np.abs(aucs - avg_auc))
-
|