import os
import time
from typing import Tuple, List, Union

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
import pubchempy as pcp
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score, average_precision_score
import itertools as it
import torch.nn.functional as F


# ----------------------------------------------------------------------------
# Model Evaluation Functions
# ----------------------------------------------------------------------------
def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Calculate ROC-AUC score for binary classification."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must lie in [0, 1]"
    return roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())


def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Calculate Average Precision (area under the precision-recall curve)."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must lie in [0, 1]"
    return average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())


def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Calculate F1 score and the optimal threshold for binary classification."""
    thresholds = torch.unique(predict_data)
    n_samples = true_data.size(0)
    ones = torch.ones((thresholds.size(0), n_samples), device=true_data.device)
    zeros = torch.zeros((thresholds.size(0), n_samples), device=true_data.device)
    # Binarize the predictions against every candidate threshold at once.
    predict_value = torch.where(predict_data.view(1, -1) >= thresholds.view(-1, 1), ones, zeros)
    # tpn = TP + TN (correct predictions) per threshold; tp = TP per threshold.
    tpn = torch.sum(torch.where(predict_value == true_data.view(1, -1), ones, zeros), dim=1)
    tp = torch.sum(predict_value * true_data.view(1, -1), dim=1)
    # F1 = 2TP / (2TP + FP + FN), where FP + FN = n_samples - (TP + TN).
    scores = (2 * tp) / (n_samples + 2 * tp - tpn)
    max_f1_score = torch.max(scores)
    threshold = thresholds[torch.argmax(scores)]
    return max_f1_score.item(), threshold.item()


def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate accuracy using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    correct = torch.sum(predict_value == true_data).float()
    return (correct / true_data.size(0)).item()


def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate precision using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    tp = torch.sum(true_data * predict_value)
    fp = torch.sum((1 - true_data) * predict_value)
    return (tp / (tp + fp + 1e-8)).item()


def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate recall using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    tp = torch.sum(true_data * predict_value)
    fn = torch.sum(true_data * (1 - predict_value))
    return (tp / (tp + fn + 1e-8)).item()


def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate Matthews Correlation Coefficient (MCC) using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    true_neg = 1 - true_data
    predict_neg = 1 - predict_value
    tp = torch.sum(true_data * predict_value)
    tn = torch.sum(true_neg * predict_neg)
    fp = torch.sum(true_neg * predict_value)
    fn = torch.sum(true_data * predict_neg)
    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
    return ((tp * tn - fp * fn) / denominator).item()


def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]:
    """Evaluate ROC-AUC, AP, accuracy, F1, MCC, precision, and recall at the best-F1 threshold."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must lie in [0, 1]"
    auc = roc_auc(true_data, predict_data)
    ap = ap_score(true_data, predict_data)
    f1, threshold = f1_score_binary(true_data, predict_data)
    acc = accuracy_binary(true_data, predict_data, threshold)
    precision = precision_binary(true_data, predict_data, threshold)
    recall = recall_binary(true_data, predict_data, threshold)
    mcc = mcc_binary(true_data, predict_data, threshold)
    return auc, ap, acc, f1, mcc, precision, recall


def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Calculate ROC-AUC and Average Precision."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must lie in [0, 1]"
    auc = roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
    ap = average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
    return auc, ap
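

# --- Usage sketch (not part of the original module) --------------------------
# A minimal smoke test of the evaluation helpers on hand-picked synthetic
# labels and scores; the values below are illustrative assumptions.
def _demo_metrics() -> None:
    true = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0, 0.0])
    pred = torch.tensor([0.2, 0.8, 0.6, 0.4, 0.9, 0.1])
    auc, ap, acc, f1, mcc, precision, recall = evaluate_all(true, pred)
    print(f"AUC={auc:.3f} AP={ap:.3f} ACC={acc:.3f} F1={f1:.3f} "
          f"MCC={mcc:.3f} P={precision:.3f} R={recall:.3f}")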


# ----------------------------------------------------------------------------
# Loss Functions
# ----------------------------------------------------------------------------
def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Calculate binary cross-entropy loss over the masked entries only."""
    masked = masked.to(torch.bool)
    true_data = torch.masked_select(true_data, masked)
    pred_data = torch.masked_select(predict_data, masked)
    return nn.BCELoss()(pred_data, true_data)


def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Calculate masked mean squared error loss.

    Entries outside the mask are zeroed on both sides, so they add no error
    but still count toward the mean.
    """
    true_data = true_data * masked
    predict_data = predict_data * masked
    return nn.MSELoss()(predict_data, true_data)


def prototypical_loss(
    cell_emb: torch.Tensor,
    drug_emb: torch.Tensor,
    adj_matrix: Union[torch.Tensor, np.ndarray],
    margin: float = 2.0
) -> torch.Tensor:
    """Calculate the prototypical (margin ranking) loss for positive and sampled negative pairs."""
    if isinstance(adj_matrix, torch.Tensor):
        adj_matrix = sp.coo_matrix(adj_matrix.detach().cpu().numpy())
    elif not sp.isspmatrix_coo(adj_matrix):
        adj_matrix = sp.coo_matrix(adj_matrix)
    # Scores of observed (positive) cell-drug pairs.
    pos_pairs = torch.sum(cell_emb[adj_matrix.row] * drug_emb[adj_matrix.col], dim=1)
    # Sample one random (likely negative) pair per positive pair.
    n_pos = len(adj_matrix.row)
    cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device)
    drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device)
    neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1)
    labels = torch.ones_like(pos_pairs)
    return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin)
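

# --- Usage sketch (not part of the original module) --------------------------
# Exercises the masked losses and the ranking loss on a small, hard-coded
# 3-cell x 2-drug problem; shapes and values are illustrative assumptions.
def _demo_losses() -> None:
    torch.manual_seed(0)
    adj = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])      # observed responses
    pred = torch.rand(3, 2)                                       # model outputs in (0, 1)
    masked = torch.tensor([[1.0, 1.0], [1.0, 0.0], [1.0, 1.0]])   # known entries
    print("BCE:", cross_entropy_loss(adj, pred, masked).item())
    print("MSE:", mse_loss(adj, pred, masked).item())
    cell_emb, drug_emb = torch.randn(3, 8), torch.randn(2, 8)
    print("ranking:", prototypical_loss(cell_emb, drug_emb, adj).item())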


# ----------------------------------------------------------------------------
# Correlation and Normalization Functions
# ----------------------------------------------------------------------------
def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Apply z-normalization to a 2-D tensor.

    Statistics are computed over the other dimension (1 - dim): dim=0
    normalizes each row, dim=1 normalizes each column (feature).
    """
    mean = tensor.mean(dim=1 - dim, keepdim=True)
    std = tensor.std(dim=1 - dim, keepdim=True) + 1e-8
    return (tensor - mean) / std


def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Pearson correlation matrix between row vectors of two matrices."""
    x_center = x - x.mean(dim=1, keepdim=True)
    y_center = y - y.mean(dim=1, keepdim=True)
    x_std = x.std(dim=1, keepdim=True) + 1e-8
    y_std = y.std(dim=1, keepdim=True) + 1e-8
    x_norm = x_center / x_std
    y_norm = y_center / y_std
    # Sample correlation: divide by n - 1 to match the unbiased std above.
    corr_matrix = x_norm @ y_norm.t() / (x.size(1) - 1)
    return corr_matrix
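

# --- Usage sketch (not part of the original module) --------------------------
# Checks torch_corr_x_y against numpy's corrcoef on random row vectors; the
# cross-correlation block of np.corrcoef should match to float precision.
def _demo_corr() -> None:
    torch.manual_seed(0)
    x, y = torch.randn(3, 10), torch.randn(4, 10)
    ours = torch_corr_x_y(x, y).numpy()
    ref = np.corrcoef(x.numpy(), y.numpy())[:3, 3:]
    print("max abs diff:", np.abs(ours - ref).max())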


# ----------------------------------------------------------------------------
# Distance and Similarity Functions
# ----------------------------------------------------------------------------
def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Calculate pairwise Euclidean distances between rows (dim=0) or columns (dim=1)."""
    tensor_mul = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t())
    diag = torch.diag(tensor_mul)
    n_diag = diag.size(0)
    tensor_diag = diag.repeat(n_diag, 1)
    diag = diag.view(n_diag, -1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clamp guards against small
    # negative values from floating-point round-off before the sqrt.
    dist = torch.sqrt(torch.clamp(tensor_diag + diag - 2 * tensor_mul, min=0.0))
    return dist


def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Calculate exponential similarity based on Euclidean distance."""
    if normalize:
        tensor = torch_z_normalized(tensor, dim=1)
    tensor_dist = torch_euclidean_dist(tensor, dim=0)
    return torch.exp(-tensor_dist / (2 * sigma.pow(2)))


def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor:
    """Calculate the full kernel matrix from exponential similarity.

    Off-diagonal entries are row-normalized to sum to 0.5 and the diagonal is
    set to 0.5, so every row sums to 1.
    """
    n = exp_dist.shape[0]
    ones = torch.ones(n, n, device=exp_dist.device)
    diag = torch.diag(ones)
    mask_diag = (ones - diag) * exp_dist
    mask_diag_sum = mask_diag.sum(dim=1, keepdim=True)
    mask_diag = mask_diag / (2 * mask_diag_sum) + 0.5 * diag
    return mask_diag


def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor:
    """Calculate a sparse kernel that keeps only the k nearest neighbors per row."""
    exp_dist = exp_dist.clone()  # avoid mutating the caller's tensor in place
    n = exp_dist.shape[0]
    maxk = torch.topk(exp_dist, k, dim=1)
    mink_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices
    exp_dist[torch.arange(n, device=exp_dist.device).view(n, -1), mink_indices] = 0
    knn_sum = maxk.values.sum(dim=1, keepdim=True)
    return exp_dist / knn_sum


def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor:
    """Apply a scaled sigmoid transformation: sigmoid(alpha * x)."""
    alpha = torch.tensor(alpha, dtype=torch.float32, device=tensor.device)
    return torch.sigmoid(alpha * tensor)
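

# --- Usage sketch (not part of the original module) --------------------------
# Builds the full and k-nearest-neighbor kernels from exponential similarity;
# sigma and k are illustrative values, not tuned defaults. Every row of the
# full kernel sums to 1 by construction.
def _demo_kernels() -> None:
    torch.manual_seed(0)
    feats = torch.randn(6, 12)
    sim = exp_similarity(feats, sigma=torch.tensor(1.0))
    print("full kernel row sums:", full_kernel(sim).sum(dim=1))
    print("sparse kernel nonzeros per row:",
          (sparse_kernel(sim, k=3) > 0).sum(dim=1))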


# ----------------------------------------------------------------------------
# Data Processing and Helper Functions
# ----------------------------------------------------------------------------
def init_seeds(seed: int = 0) -> None:
    """Initialize random seeds for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def distribute_compute(
    lr_list: List[float],
    wd_list: List[float],
    scale_list: List[float],
    layer_size: List[int],
    sigma_list: List[float],
    beta_list: List[float],
    workers: int,
    id: int
) -> np.ndarray:
    """Distribute hyperparameter combinations across workers."""
    all_combinations = [
        [lr, wd, sc, la, sg, bt]
        for lr, wd, sc, la, sg, bt in it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list)
    ]
    return np.array_split(all_combinations, workers)[id]
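

# --- Usage sketch (not part of the original module) --------------------------
# Splits a small hyperparameter grid (4 combinations) across two workers;
# worker 0 receives the first half. All values are illustrative.
def _demo_distribute() -> None:
    chunk = distribute_compute(
        lr_list=[1e-3, 1e-4], wd_list=[0.0, 1e-5], scale_list=[1.0],
        layer_size=[64], sigma_list=[2.0], beta_list=[0.5],
        workers=2, id=0,
    )
    print(chunk)  # 2 of the 4 [lr, wd, scale, layer, sigma, beta] rows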


def get_fingerprint(cid: int) -> np.ndarray:
    """Retrieve the PubChem fingerprint for a given compound CID as a bit array."""
    compound = pcp.Compound.from_cid(cid)
    # PubChemPy returns the fingerprint as a hex string; expand each hex digit
    # into its 4-bit binary representation.
    fingerprint = "".join(f"{int(bit, 16):04b}" for bit in compound.fingerprint)
    return np.array([int(b) for b in fingerprint], dtype=np.int32)


def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None:
    """Save fingerprints for a list of compound CIDs to disk, resuming after last_cid."""
    start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1 if last_cid > 0 else 0
    for cid in cid_list[start_idx:]:
        fingerprint = get_fingerprint(cid)
        np.save(os.path.join(fpath, str(cid)), fingerprint)
        print(f"CID {cid} processed successfully.")
        time.sleep(1)  # throttle requests to the PubChem API
    if start_idx >= len(cid_list):
        print("All compounds have been processed!")


def read_fingerprint_cid(path: str) -> Tuple[np.ndarray, List[int]]:
    """Read fingerprints from .npy files in the specified directory."""
    fingerprint = []
    cids = []
    for file_name in sorted(os.listdir(path)):
        if file_name.endswith(".npy"):
            cid = int(file_name.split(".")[0])
            fing = np.load(os.path.join(path, file_name))
            fingerprint.append(fing)
            cids.append(cid)
    # Each stored fingerprint is expected to hold 920 bits, matching the
    # hex-to-binary expansion in get_fingerprint.
    fingerprint = np.array(fingerprint).reshape(-1, 920)
    return fingerprint, cids


def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray:
    """Find indices of elements in data_for_index that exist in data_for_cmp."""
    return np.where(np.isin(data_for_index, data_for_cmp))[0]


def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix:
    """Convert input matrix to scipy.sparse.coo_matrix format."""
    if not sp.isspmatrix_coo(adj_mat):
        adj_mat = sp.coo_matrix(adj_mat)
    return adj_mat


def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor:
    """Create a mask combining positive and negative edges."""
    row = np.hstack((positive.row, negative.row))
    col = np.hstack((positive.col, negative.col))
    data = np.ones_like(row)
    masked = sp.coo_matrix((data, (row, col)), shape=positive.shape).toarray().astype(dtype)
    return torch.from_numpy(masked)


def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor:
    """Convert sparse matrix to torch.Tensor, optionally adding identity matrix."""
    data = positive + sp.identity(positive.shape[0]) if identity else positive
    return torch.from_numpy(data.toarray()).float()
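

# --- Usage sketch (not part of the original module) --------------------------
# Builds a combined positive/negative mask and a dense tensor (with identity
# added) from tiny hand-written adjacency matrices.
def _demo_sparse_helpers() -> None:
    pos = to_coo_matrix(np.array([[1, 0], [0, 1]]))
    neg = to_coo_matrix(np.array([[0, 1], [0, 0]]))
    print(mask(pos, neg))                  # 1 where either edge set has an entry
    print(to_tensor(pos, identity=True))   # pos + I as a float tensor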


def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray:
    """Remove specified values from a NumPy array."""
    indices = [np.where(arr == x)[0][0] for x in obj if x in arr]
    return np.delete(arr, indices)


def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame:
    """Convert tensor or array to a pandas DataFrame."""
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()
    return pd.DataFrame(tensor.reshape(1, -1))


def calculate_train_test_index(
    response: np.ndarray,
    pos_train_index: np.ndarray,
    pos_test_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Calculate train and test indices combining positive and negative samples."""
    neg_response_index = np.where(response == 0)[0]
    neg_test_index = np.random.choice(neg_response_index, pos_test_index.shape[0], replace=False)
    neg_train_index = np_delete_value(neg_response_index, neg_test_index)
    test_index = np.hstack((pos_test_index, neg_test_index))
    train_index = np.hstack((pos_train_index, neg_train_index))
    return train_index, test_index
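

# --- Usage sketch (not part of the original module) --------------------------
# Splits a toy response vector: positives are pre-split by the caller, and a
# matching number of negatives is sampled for the test fold.
def _demo_split() -> None:
    np.random.seed(0)
    response = np.array([1, 0, 1, 0, 0, 1, 0, 0])
    pos_index = np.where(response == 1)[0]
    train_idx, test_idx = calculate_train_test_index(response, pos_index[:2], pos_index[2:])
    print("train:", train_idx, "test:", test_idx)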


def dir_path(k: int = 1) -> str:
    """Get directory path by traversing k levels up from current file."""
    fpath = os.path.realpath(__file__)
    dir_name = os.path.dirname(fpath).replace("\\", "/")
    for _ in range(k):
        dir_name = os.path.dirname(dir_name)
    return dir_name


def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray:
    """Extract non-NaN data from a specific row of a DataFrame."""
    target = np.array(data.iloc[row], dtype=np.float32)
    return target[~np.isnan(target)]


def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame:
    """Add a label column to a DataFrame."""
    data = data.copy()
    data["label"] = label
    return data


def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame:
    """Concatenate multiple DataFrames vertically."""
    return pd.concat(data, ignore_index=True)


def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]:
    """Calculate min and max values of a key across multiple DataFrames."""
    temp = pd.concat(data, ignore_index=True)
    return temp[key].min() - 0.1, temp[key].max() + 0.1


def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str:
    """Remove all occurrences of a substring and join with specified string."""
    parts = string.split(sub)
    parts = [p for p in parts if p]
    return join_str.join(parts)


def get_best_index(fname: str) -> int:
    """Find the index of the AUC closest to the average AUC in a results file."""
    with open(fname, "r") as file:
        content = file.read().replace("\n", "")
    auc_str = content.split("accs")[0].split(":")[1]
    auc_str = delete_all_sub_str(auc_str, " ", ",").replace(",]", "]")
    # The file stores the AUC list as a printed Python literal; eval assumes
    # the file was produced by our own training script, never untrusted input.
    aucs = np.array(eval(auc_str))
    avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0])
    return int(np.argmin(np.abs(aucs - avg_auc)))


def gather_color_code(*string: str) -> List[Tuple[float, float, float]]:
    """Map color names to seaborn color palette codes."""
    color_str = ["blue", "orange", "green", "red", "purple", "brown", "pink", "grey", "yellow", "cyan"]
    palette = sns.color_palette()
    color_map = dict(zip(color_str, palette))
    return [color_map[color] for color in string]