import itertools as it
import os
import time
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pubchempy as pcp
import scipy.sparse as sp
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score

# ----------------------------------------------------------------------------
# Model Evaluation Functions
# ----------------------------------------------------------------------------

def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Calculate ROC-AUC score for binary classification."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    return roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())

def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Calculate Average Precision (area under Precision-Recall curve)."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    return average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())

def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Calculate F1 score and the optimal threshold for binary classification."""
    thresholds = torch.unique(predict_data)
    n_samples = true_data.size(0)
    ones = torch.ones((thresholds.size(0), n_samples), device=true_data.device)
    zeros = torch.zeros((thresholds.size(0), n_samples), device=true_data.device)
    # Binarize the predictions against every candidate threshold at once:
    # row i holds the 0/1 predictions obtained with thresholds[i].
    predict_value = torch.where(predict_data.view(1, -1) >= thresholds.view(-1, 1), ones, zeros)

    # tpn = tp + tn (correct predictions per threshold), so the denominator
    # n_samples + 2*tp - tpn equals 2*tp + fp + fn, i.e. F1 = 2*tp / (2*tp + fp + fn).
    tpn = torch.sum(torch.where(predict_value == true_data.view(1, -1), ones, zeros), dim=1)
    tp = torch.sum(predict_value * true_data.view(1, -1), dim=1)
    scores = (2 * tp) / (n_samples + 2 * tp - tpn)

    max_f1_score = torch.max(scores)
    threshold = thresholds[torch.argmax(scores)]
    return max_f1_score.item(), threshold.item()

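# Minimal usage sketch (hypothetical tensors): pick the F1-optimal threshold,
# then reuse it for the thresholded metrics defined below.
#
#   >>> y_true = torch.tensor([0.0, 1.0, 1.0, 0.0])
#   >>> y_pred = torch.tensor([0.2, 0.8, 0.6, 0.4])
#   >>> f1, thr = f1_score_binary(y_true, y_pred)
#   >>> acc = accuracy_binary(y_true, y_pred, thr)
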
def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate accuracy using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    correct = torch.sum(predict_value == true_data).float()
    return (correct / true_data.size(0)).item()

def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate precision using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    tp = torch.sum(true_data * predict_value)
    fp = torch.sum((1 - true_data) * predict_value)
    return (tp / (tp + fp + 1e-8)).item()

def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate recall using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    tp = torch.sum(true_data * predict_value)
    fn = torch.sum(true_data * (1 - predict_value))
    return (tp / (tp + fn + 1e-8)).item()

def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate Matthews Correlation Coefficient (MCC) using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    true_neg = 1 - true_data
    predict_neg = 1 - predict_value

    tp = torch.sum(true_data * predict_value)
    tn = torch.sum(true_neg * predict_neg)
    fp = torch.sum(true_neg * predict_value)
    fn = torch.sum(true_data * predict_neg)

    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
    return ((tp * tn - fp * fn) / denominator).item()

def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]:
    """Evaluate ROC-AUC, AP, accuracy, F1, MCC, precision, and recall."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    auc = roc_auc(true_data, predict_data)
    ap = ap_score(true_data, predict_data)
    f1, threshold = f1_score_binary(true_data, predict_data)
    acc = accuracy_binary(true_data, predict_data, threshold)
    precision = precision_binary(true_data, predict_data, threshold)
    recall = recall_binary(true_data, predict_data, threshold)
    mcc = mcc_binary(true_data, predict_data, threshold)
    return auc, ap, acc, f1, mcc, precision, recall

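# Example of evaluating the full metric suite on hypothetical scores:
#
#   >>> y_true = torch.randint(0, 2, (100,)).float()
#   >>> y_pred = torch.rand(100)
#   >>> auc, ap, acc, f1, mcc, precision, recall = evaluate_all(y_true, y_pred)
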
def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Calculate ROC-AUC and Average Precision."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    auc = roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
    ap = average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
    return auc, ap

# ----------------------------------------------------------------------------
# Loss Functions
# ----------------------------------------------------------------------------

def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Calculate masked binary cross-entropy loss."""
    masked = masked.to(torch.bool)
    true_data = torch.masked_select(true_data, masked)
    pred_data = torch.masked_select(predict_data, masked)
    return nn.BCELoss()(pred_data, true_data)

def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Calculate masked mean squared error loss."""
    # Masked-out entries are zeroed on both sides, so they contribute no error,
    # but note that the mean still runs over all entries of the matrix.
    true_data = true_data * masked
    predict_data = predict_data * masked
    return nn.MSELoss()(predict_data, true_data)

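# Masked losses are used when only part of the response matrix is observed.
# A minimal sketch (hypothetical shapes and data):
#
#   >>> truth = torch.rand(4, 5).round()         # observed 0/1 responses
#   >>> pred = torch.rand(4, 5)                  # model outputs in (0, 1)
#   >>> observed = torch.randint(0, 2, (4, 5))   # 1 where an entry is observed
#   >>> loss = cross_entropy_loss(truth, pred, observed)
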
def prototypical_loss(
    cell_emb: torch.Tensor,
    drug_emb: torch.Tensor,
    adj_matrix: Union[torch.Tensor, np.ndarray],
    margin: float = 2.0
) -> torch.Tensor:
    """Calculate prototypical loss for positive and negative pairs."""
    if isinstance(adj_matrix, torch.Tensor):
        adj_matrix = sp.coo_matrix(adj_matrix.detach().cpu().numpy())

    # Positive pairs: inner products of embeddings for observed (cell, drug) edges.
    pos_pairs = torch.sum(cell_emb[adj_matrix.row] * drug_emb[adj_matrix.col], dim=1)
    n_pos = len(adj_matrix.row)

    # Negative pairs are sampled uniformly at random; a small fraction may
    # collide with true positives.
    cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device)
    drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device)
    neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1)

    # Rank every positive pair above its sampled negative by at least `margin`.
    labels = torch.ones_like(pos_pairs, device=cell_emb.device)
    return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin)

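# Usage sketch with hypothetical embeddings and a random adjacency matrix:
#
#   >>> cell_emb, drug_emb = torch.rand(30, 16), torch.rand(10, 16)
#   >>> adj = (torch.rand(30, 10) > 0.9).float()   # observed positive edges
#   >>> loss = prototypical_loss(cell_emb, drug_emb, adj)
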
# ----------------------------------------------------------------------------
# Correlation and Normalization Functions
# ----------------------------------------------------------------------------

def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Z-normalize a 2-D tensor; `dim` selects the feature axis, so statistics
    are computed along the opposite axis (e.g., dim=1 standardizes each column
    across rows)."""
    mean = tensor.mean(dim=1 - dim, keepdim=True)
    std = tensor.std(dim=1 - dim, keepdim=True) + 1e-8
    return (tensor - mean) / std

def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Pearson correlation matrix between row vectors of two matrices."""
    x_center = x - x.mean(dim=1, keepdim=True)
    y_center = y - y.mean(dim=1, keepdim=True)

    x_std = x.std(dim=1, keepdim=True) + 1e-8
    y_std = y.std(dim=1, keepdim=True) + 1e-8

    x_norm = x_center / x_std
    y_norm = y_center / y_std

    # The unbiased std above pairs with the (n - 1) normalization here.
    corr_matrix = x_norm @ y_norm.t() / (x.size(1) - 1)
    return corr_matrix

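# Sanity-check sketch (hypothetical data): entry (i, j) is the Pearson
# correlation between x[i] and y[j], matching np.corrcoef row-for-row.
#
#   >>> x, y = torch.rand(3, 10), torch.rand(4, 10)
#   >>> c = torch_corr_x_y(x, y)   # shape (3, 4)
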
# ----------------------------------------------------------------------------
# Distance and Similarity Functions
# ----------------------------------------------------------------------------

def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Calculate pairwise Euclidean distances between rows (dim=0) or columns
    (dim=1) of a tensor."""
    # Use the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b on the Gram matrix.
    tensor_mul = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t())
    diag = torch.diag(tensor_mul)
    n_diag = diag.size(0)
    tensor_diag = diag.repeat(n_diag, 1)
    diag = diag.view(n_diag, -1)
    # Clamp at zero before the square root: rounding can make tiny negatives.
    dist = torch.sqrt(torch.clamp(tensor_diag + diag - 2 * tensor_mul, min=0))
    return dist

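# With dim=0 this agrees with torch.cdist up to floating-point error, e.g.:
#
#   >>> t = torch.rand(5, 8)
#   >>> torch.allclose(torch_euclidean_dist(t), torch.cdist(t, t), atol=1e-5)
#   True
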
def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Calculate exponential similarity based on Euclidean distance."""
    if normalize:
        tensor = torch_z_normalized(tensor, dim=1)
    tensor_dist = torch_euclidean_dist(tensor, dim=0)
    return torch.exp(-tensor_dist / (2 * sigma.pow(2)))

def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor:
    """Calculate full kernel matrix from exponential similarity."""
    n = exp_dist.shape[0]
    ones = torch.ones(n, n, device=exp_dist.device)
    diag = torch.diag(ones)
    # Row-normalize the off-diagonal similarities to sum to 0.5 and pin the
    # diagonal at 0.5, so every row of the kernel sums to 1.
    mask_diag = (ones - diag) * exp_dist
    mask_diag_sum = mask_diag.sum(dim=1, keepdim=True)
    mask_diag = mask_diag / (2 * mask_diag_sum) + 0.5 * diag
    return mask_diag

def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor:
    """Calculate sparse kernel using k-nearest neighbors."""
    n = exp_dist.shape[0]
    exp_dist = exp_dist.clone()  # avoid mutating the caller's tensor in place
    maxk = torch.topk(exp_dist, k, dim=1)
    # Zero out everything except each row's k largest similarities.
    mink_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices
    exp_dist[torch.arange(n, device=exp_dist.device).view(n, -1), mink_indices] = 0
    knn_sum = maxk.values.sum(dim=1, keepdim=True)
    return exp_dist / knn_sum

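# Typical similarity-network pipeline built from the pieces above
# (a sketch with a hypothetical feature matrix and parameters):
#
#   >>> features = torch.rand(20, 50)
#   >>> sim = exp_similarity(features, sigma=torch.tensor(2.0))
#   >>> p = full_kernel(sim)           # global transition kernel
#   >>> s = sparse_kernel(sim, k=5)    # local kNN kernel
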
def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor:
    """Apply scaled sigmoid transformation."""
    alpha = torch.tensor(alpha, dtype=torch.float32, device=tensor.device)
    return torch.sigmoid(alpha * tensor)

# ----------------------------------------------------------------------------
# Data Processing and Helper Functions
# ----------------------------------------------------------------------------

def init_seeds(seed: int = 0) -> None:
    """Initialize random seeds for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def distribute_compute(
    lr_list: List[float],
    wd_list: List[float],
    scale_list: List[float],
    layer_size: List[int],
    sigma_list: List[float],
    beta_list: List[float],
    workers: int,
    id: int
) -> np.ndarray:
    """Distribute hyperparameter combinations across workers."""
    all_combinations = [
        [lr, wd, sc, la, sg, bt]
        for lr, wd, sc, la, sg, bt in it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list)
    ]
    return np.array_split(all_combinations, workers)[id]

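# Worker `id` (0-based) receives its slice of the full hyperparameter grid,
# e.g. (hypothetical values):
#
#   >>> grid = distribute_compute([1e-3, 1e-4], [1e-5], [1.0], [64], [2.0],
#   ...                           [0.5], workers=2, id=0)
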
def get_fingerprint(cid: int) -> np.ndarray:
    """Retrieve the PubChem fingerprint for a given compound CID as a bit array."""
    compound = pcp.Compound.from_cid(cid)
    # The fingerprint is a hex string; expand each hex digit to 4 bits.
    fingerprint = "".join(f"{int(digit, 16):04b}" for digit in compound.fingerprint)
    return np.array([int(b) for b in fingerprint], dtype=np.int32)

def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None:
    """Save fingerprints for a list of compound CIDs to disk, resuming after
    `last_cid` (pass 0 to start from the beginning)."""
    start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1 if last_cid > 0 else 0
    for cid in cid_list[start_idx:]:
        fingerprint = get_fingerprint(cid)
        np.save(os.path.join(fpath, str(cid)), fingerprint)
        print(f"CID {cid} processed successfully.")
        time.sleep(1)  # rate-limit requests to the PubChem API
    if start_idx >= len(cid_list):
        print("All compounds have been processed!")

def read_fingerprint_cid(path: str) -> Tuple[np.ndarray, List[int]]:
    """Read fingerprints from .npy files in the specified directory."""
    fingerprint = []
    cids = []
    for file_name in sorted(os.listdir(path)):
        if file_name.endswith(".npy"):
            cid = int(file_name.split(".")[0])
            fing = np.load(os.path.join(path, file_name))
            fingerprint.append(fing)
            cids.append(cid)
    # Each PubChem fingerprint decodes to 920 bits (230 hex digits x 4).
    fingerprint = np.array(fingerprint).reshape(-1, 920)
    return fingerprint, cids

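# Fetching is network-bound and rate-limited; fingerprints are cached as one
# .npy file per CID and re-read in bulk, e.g. (hypothetical CIDs and path):
#
#   >>> save_fingerprint([2244, 3672], last_cid=0, fpath="fingerprints/")
#   >>> fps, cids = read_fingerprint_cid("fingerprints/")
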
def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray:
    """Find indices of elements in data_for_index that exist in data_for_cmp."""
    return np.where(np.isin(data_for_index, data_for_cmp))[0]

def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix:
    """Convert input matrix to scipy.sparse.coo_matrix format."""
    if not sp.isspmatrix_coo(adj_mat):
        adj_mat = sp.coo_matrix(adj_mat)
    return adj_mat

def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor:
    """Create a mask combining positive and negative edges."""
    row = np.hstack((positive.row, negative.row))
    col = np.hstack((positive.col, negative.col))
    data = np.ones_like(row)
    masked = sp.coo_matrix((data, (row, col)), shape=positive.shape).toarray().astype(dtype)
    return torch.from_numpy(masked)

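# The mask marks which adjacency entries contribute to a masked loss, e.g.
# (train_pos_adj and train_neg_adj are hypothetical 0/1 matrices):
#
#   >>> pos = to_coo_matrix(train_pos_adj)
#   >>> neg = to_coo_matrix(train_neg_adj)
#   >>> train_mask = mask(pos, neg)   # 1 where an entry is supervised
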
def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor:
    """Convert sparse matrix to torch.Tensor, optionally adding an identity matrix."""
    data = positive + sp.identity(positive.shape[0]) if identity else positive
    return torch.from_numpy(data.toarray()).float()

def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray:
    """Remove the first occurrence of each specified value from a NumPy array."""
    indices = [np.where(arr == x)[0][0] for x in obj if x in arr]
    return np.delete(arr, indices)

def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame:
    """Convert a tensor or array to a single-row pandas DataFrame."""
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()
    return pd.DataFrame(tensor.reshape(1, -1))

def calculate_train_test_index(
    response: np.ndarray,
    pos_train_index: np.ndarray,
    pos_test_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Calculate train and test indices combining positive and negative samples."""
    # Sample as many negative test examples as there are positive ones;
    # all remaining negatives go to the training split.
    neg_response_index = np.where(response == 0)[0]
    neg_test_index = np.random.choice(neg_response_index, pos_test_index.shape[0], replace=False)
    neg_train_index = np_delete_value(neg_response_index, neg_test_index)
    test_index = np.hstack((pos_test_index, neg_test_index))
    train_index = np.hstack((pos_train_index, neg_train_index))
    return train_index, test_index

def dir_path(k: int = 1) -> str:
    """Get directory path by traversing k levels up from the current file."""
    fpath = os.path.realpath(__file__)
    dir_name = os.path.dirname(fpath).replace("\\", "/")
    for _ in range(k):
        dir_name = os.path.dirname(dir_name)
    return dir_name

def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray:
    """Extract non-NaN data from a specific row of a DataFrame."""
    target = np.array(data.iloc[row], dtype=np.float32)
    return target[~np.isnan(target)]

def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame:
    """Add a label column to a DataFrame."""
    data = data.copy()
    data["label"] = label
    return data

def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame:
    """Concatenate multiple DataFrames vertically."""
    return pd.concat(data, ignore_index=True)

def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]:
    """Calculate axis limits: the min and max of `key` across the given
    DataFrames, padded by 0.1 on each side."""
    temp = pd.concat(data, ignore_index=True)
    return temp[key].min() - 0.1, temp[key].max() + 0.1

def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str:
    """Remove all occurrences of a substring and join the remaining parts with
    `join_str`."""
    parts = string.split(sub)
    parts = [p for p in parts if p]
    return join_str.join(parts)

def get_best_index(fname: str) -> int:
    """Find the index of the AUC closest to the average AUC from a results file."""
    with open(fname, "r") as file:
        content = file.read().replace("\n", "")

    # Recover the AUC list, which is printed with spaces as separators.
    auc_str = content.split("accs")[0].split(":")[1]
    auc_str = delete_all_sub_str(auc_str, " ", ",").replace(",]", "]")
    aucs = np.array(eval(auc_str))

    avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0])
    return int(np.argmin(np.abs(aucs - avg_auc)))

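# Assumed log format (an illustration; the real files are produced elsewhere):
# a single record like "aucs:[0.81 0.84 0.79] accs:[...] avg_aucs:0.813 ...",
# from which the run whose AUC is closest to the average is selected:
#
#   >>> best_run = get_best_index("results/log.txt")   # hypothetical path
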
def gather_color_code(*string: str) -> List[Tuple[float, float, float]]:
    """Map color names to seaborn color palette codes."""
    color_str = ["blue", "orange", "green", "red", "purple", "brown", "pink", "grey", "yellow", "cyan"]
    palette = sns.color_palette()
    color_map = dict(zip(color_str, palette))
    return [color_map[color] for color in string]