Browse Source

Update 'utils.py'

master
Zahra Asgari 3 days ago
parent
commit
608be828ab
1 changed files with 395 additions and 401 deletions
  1. 395
    401
      utils.py

+ 395
- 401
utils.py View File

import os
import time
from typing import Tuple, List, Union, Optional
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
import pubchempy as pcp
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score, average_precision_score
import itertools as it
import torch.nn.functional as F
# ----------------------------------------------------------------------------
# Model Evaluation Functions
# ----------------------------------------------------------------------------
def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
"""Calculate ROC-AUC score for binary classification."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
return roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
"""Calculate Average Precision (area under Precision-Recall curve)."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
return average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
"""Calculate F1 score and the optimal threshold for binary classification."""
thresholds = torch.unique(predict_data)
n_samples = true_data.size(0)
ones = torch.ones((thresholds.size(0), n_samples), device=true_data.device)
zeros = torch.zeros((thresholds.size(0), n_samples), device=true_data.device)
predict_value = torch.where(predict_data.view(1, -1) >= thresholds.view(-1, 1), ones, zeros)
tpn = torch.sum(torch.where(predict_value == true_data.view(1, -1), ones, zeros), dim=1)
tp = torch.sum(predict_value * true_data.view(1, -1), dim=1)
scores = (2 * tp) / (n_samples + 2 * tp - tpn)
max_f1_score = torch.max(scores)
threshold = thresholds[torch.argmax(scores)]
return max_f1_score.item(), threshold.item()
def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate accuracy using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
correct = torch.sum(predict_value == true_data).float()
return (correct / true_data.size(0)).item()
def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate precision using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
tp = torch.sum(true_data * predict_value)
fp = torch.sum((1 - true_data) * predict_value)
return (tp / (tp + fp + 1e-8)).item()
def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate recall using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
tp = torch.sum(true_data * predict_value)
fn = torch.sum(true_data * (1 - predict_value))
return (tp / (tp + fn + 1e-8)).item()
def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate Matthews Correlation Coefficient (MCC) using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
true_neg = 1 - true_data
predict_neg = 1 - predict_value
tp = torch.sum(true_data * predict_value)
tn = torch.sum(true_neg * predict_neg)
fp = torch.sum(true_neg * predict_value)
fn = torch.sum(true_data * predict_neg)
denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
return ((tp * tn - fp * fn) / denominator).item()
def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]:
"""Evaluate multiple metrics: ROC-AUC, AP, accuracy, F1, and MCC."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
auc = roc_auc(true_data, predict_data)
ap = ap_score(true_data, predict_data)
f1, threshold = f1_score_binary(true_data, predict_data)
acc = accuracy_binary(true_data, predict_data, threshold)
precision = precision_binary(true_data, predict_data, threshold)
recall = recall_binary(true_data, predict_data, threshold)
mcc = mcc_binary(true_data, predict_data, threshold)
return auc, ap, acc, f1, mcc, precision, recall
def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
"""Calculate ROC-AUC and Average Precision."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
auc = roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
ap = average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
return auc, ap
# ----------------------------------------------------------------------------
# Loss Functions
# ----------------------------------------------------------------------------
def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
"""Calculate masked binary cross-entropy loss."""
masked = masked.to(torch.bool)
true_data = torch.masked_select(true_data, masked)
pred_data = torch.masked_select(predict_data, masked)
return nn.BCELoss()(pred_data, true_data)
def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
"""Calculate masked mean squared error loss."""
true_data = true_data * masked
predict_data = predict_data * masked
return nn.MSELoss()(predict_data, true_data)
def prototypical_loss(
cell_emb: torch.Tensor,
drug_emb: torch.Tensor,
adj_matrix: Union[torch.Tensor, np.ndarray],
margin: float = 2.0
) -> torch.Tensor:
"""Calculate prototypical loss for positive and negative pairs."""
if isinstance(adj_matrix, torch.Tensor):
adj_matrix = sp.coo_matrix(adj_matrix.detach().cpu().numpy())
pos_pairs = torch.sum(cell_emb[adj_matrix.row] * drug_emb[adj_matrix.col], dim=1)
n_pos = len(adj_matrix.row)
cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device)
drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device)
neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1)
labels = torch.ones_like(pos_pairs, device=cell_emb.device)
return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin)
# ----------------------------------------------------------------------------
# Correlation and Normalization Functions
# ----------------------------------------------------------------------------
def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""Apply z-normalization along the specified dimension."""
mean = tensor.mean(dim=1 - dim, keepdim=True)
std = tensor.std(dim=1 - dim, keepdim=True) + 1e-8
return (tensor - mean) / std
def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Compute correlation matrix between row vectors of two matrices."""
x_center = x - x.mean(dim=1, keepdim=True)
y_center = y - y.mean(dim=1, keepdim=True)
x_std = x.std(dim=1, keepdim=True) + 1e-8
y_std = y.std(dim=1, keepdim=True) + 1e-8
x_norm = x_center / x_std
y_norm = y_center / y_std
corr_matrix = x_norm @ y_norm.t() / (x.size(1) - 1)
return corr_matrix
# ----------------------------------------------------------------------------
# Distance and Similarity Functions
# ----------------------------------------------------------------------------
def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""Calculate Euclidean distance between rows or columns of a tensor."""
tensor_mul = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t())
diag = torch.diag(tensor_mul)
n_diag = diag.size(0)
tensor_diag = diag.repeat(n_diag, 1)
diag = diag.view(n_diag, -1)
dist = torch.sqrt(tensor_diag + diag - 2 * tensor_mul)
return dist
def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor:
"""Calculate exponential similarity based on Euclidean distance."""
if normalize:
tensor = torch_z_normalized(tensor, dim=1)
tensor_dist = torch_euclidean_dist(tensor, dim=0)
return torch.exp(-tensor_dist / (2 * sigma.pow(2)))
def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor:
"""Calculate full kernel matrix from exponential similarity."""
n = exp_dist.shape[0]
ones = torch.ones(n, n, device=exp_dist.device)
diag = torch.diag(ones)
mask_diag = (ones - diag) * exp_dist
mask_diag_sum = mask_diag.sum(dim=1, keepdim=True)
mask_diag = mask_diag / (2 * mask_diag_sum) + 0.5 * diag
return mask_diag
def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor:
"""Calculate sparse kernel using k-nearest neighbors."""
n = exp_dist.shape[0]
maxk = torch.topk(exp_dist, k, dim=1)
mink_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices
exp_dist[torch.arange(n, device=exp_dist.device).view(n, -1), mink_indices] = 0
knn_sum = maxk.values.sum(dim=1, keepdim=True)
return exp_dist / knn_sum
def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor:
"""Apply scaled sigmoid transformation."""
alpha = torch.tensor(alpha, dtype=torch.float32, device=tensor.device)
return torch.sigmoid(alpha * tensor)
# ----------------------------------------------------------------------------
# Data Processing and Helper Functions
# ----------------------------------------------------------------------------
def init_seeds(seed: int = 0) -> None:
"""Initialize random seeds for reproducibility."""
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def distribute_compute(
lr_list: List[float],
wd_list: List[float],
scale_list: List[float],
layer_size: List[int],
sigma_list: List[float],
beta_list: List[float],
workers: int,
id: int
) -> np.ndarray:
"""Distribute hyperparameter combinations across workers."""
all_combinations = [
[lr, wd, sc, la, sg, bt]
for lr, wd, sc, la, sg, bt in it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list)
]
return np.array_split(all_combinations, workers)[id]
def get_fingerprint(cid: int) -> np.ndarray:
"""Retrieve PubChem fingerprint for a given compound CID."""
compound = pcp.Compound.from_cid(cid)
fingerprint = "".join(f"{int(bit, 16):04b}" for bit in compound.fingerprint)
return np.array([int(b) for b in fingerprint], dtype=np.int32)
def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None:
"""Save fingerprints for a list of compound CIDs to disk."""
start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1 if last_cid > 0 else 0
for cid in cid_list[start_idx:]:
fingerprint = get_fingerprint(cid)
np.save(os.path.join(fpath, str(cid)), fingerprint)
print(f"CID {cid} processed successfully.")
time.sleep(1)
if start_idx >= len(cid_list):
print("All compounds have been processed!")
def read_fingerprint_cid(path: str) -> Tuple[np.ndarray, List[int]]:
"""Read fingerprints from .npy files in the specified directory."""
fingerprint = []
cids = []
for file_name in sorted(os.listdir(path)):
if file_name.endswith(".npy"):
cid = int(file_name.split(".")[0])
fing = np.load(os.path.join(path, file_name))
fingerprint.append(fing)
cids.append(cid)
fingerprint = np.array(fingerprint).reshape(-1, 920)
return fingerprint, cids
def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray:
"""Find indices of elements in data_for_index that exist in data_for_cmp."""
return np.where(np.isin(data_for_index, data_for_cmp))[0]
def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix:
"""Convert input matrix to scipy.sparse.coo_matrix format."""
if not sp.isspmatrix_coo(adj_mat):
adj_mat = sp.coo_matrix(adj_mat)
return adj_mat
def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor:
"""Create a mask combining positive and negative edges."""
row = np.hstack((positive.row, negative.row))
col = np.hstack((positive.col, negative.col))
data = np.ones_like(row)
masked = sp.coo_matrix((data, (row, col)), shape=positive.shape).toarray().astype(dtype)
return torch.from_numpy(masked)
def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor:
"""Convert sparse matrix to torch.Tensor, optionally adding identity matrix."""
data = positive + sp.identity(positive.shape[0]) if identity else positive
return torch.from_numpy(data.toarray()).float()
def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray:
"""Remove specified values from a NumPy array."""
indices = [np.where(arr == x)[0][0] for x in obj if x in arr]
return np.delete(arr, indices)
def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame:
"""Convert tensor or array to a pandas DataFrame."""
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return pd.DataFrame(tensor.reshape(1, -1))
def calculate_train_test_index(
response: np.ndarray,
pos_train_index: np.ndarray,
pos_test_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Calculate train and test indices combining positive and negative samples."""
neg_response_index = np.where(response == 0)[0]
neg_test_index = np.random.choice(neg_response_index, pos_test_index.shape[0], replace=False)
neg_train_index = np_delete_value(neg_response_index, neg_test_index)
test_index = np.hstack((pos_test_index, neg_test_index))
train_index = np.hstack((pos_train_index, neg_train_index))
return train_index, test_index
def dir_path(k: int = 1) -> str:
"""Get directory path by traversing k levels up from current file."""
fpath = os.path.realpath(__file__)
dir_name = os.path.dirname(fpath).replace("\\", "/")
for _ in range(k):
dir_name = os.path.dirname(dir_name)
return dir_name
def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray:
"""Extract non-NaN data from a specific row of a DataFrame."""
target = np.array(data.iloc[row], dtype=np.float32)
return target[~np.isnan(target)]
def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame:
"""Add a label column to a DataFrame."""
data = data.copy()
data["label"] = label
return data
def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame:
"""Concatenate multiple DataFrames vertically."""
return pd.concat(data, ignore_index=True)
def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]:
"""Calculate min and max values of a key across multiple DataFrames."""
temp = pd.concat(data, ignore_index=True)
return temp[key].min() - 0.1, temp[key].max() + 0.1
def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str:
"""Remove all occurrences of a substring and join with specified string."""
parts = string.split(sub)
parts = [p for p in parts if p]
return join_str.join(parts)
def get_best_index(fname: str) -> int:
"""Find the index of the AUC closest to the average AUC from a results file."""
with open(fname, "r") as file:
content = file.read().replace("\n", "")
auc_str = content.split("accs")[0].split(":")[1]
auc_str = delete_all_sub_str(auc_str, " ", ",").replace(",]", "]")
aucs = np.array(eval(auc_str))
avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0])
return np.argmin(np.abs(aucs - avg_auc))
def gather_color_code(*string: str) -> List[Tuple[float, float, float]]:
"""Map color names to seaborn color palette codes."""
color_str = ["bluea0"] = ["blue", "orange", "green", "red", "purple", "brown", "pink", "grey", "yellow", "cyan"]
palette = sns.color_palette()
color_map = dict(zip(color_str, palette))
return [color_map[color] for color in string]
import os
import time
from typing import Tuple, List, Union, Optional
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
import pubchempy as pcp
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score, average_precision_score
import itertools as it
import torch.nn.functional as F


# ----------------------------------------------------------------------------
# Model Evaluation Functions
# ----------------------------------------------------------------------------

def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
"""Calculate ROC-AUC score for binary classification."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
return roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())


def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
"""Calculate Average Precision (area under Precision-Recall curve)."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
return average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())


def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
"""Calculate F1 score and the optimal threshold for binary classification."""
thresholds = torch.unique(predict_data)
n_samples = true_data.size(0)
ones = torch.ones((thresholds.size(0), n_samples), device=true_data.device)
zeros = torch.zeros((thresholds.size(0), n_samples), device=true_data.device)
predict_value = torch.where(predict_data.view(1, -1) >= thresholds.view(-1, 1), ones, zeros)
tpn = torch.sum(torch.where(predict_value == true_data.view(1, -1), ones, zeros), dim=1)
tp = torch.sum(predict_value * true_data.view(1, -1), dim=1)
scores = (2 * tp) / (n_samples + 2 * tp - tpn)
max_f1_score = torch.max(scores)
threshold = thresholds[torch.argmax(scores)]
return max_f1_score.item(), threshold.item()


def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate accuracy using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
correct = torch.sum(predict_value == true_data).float()
return (correct / true_data.size(0)).item()


def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate precision using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
tp = torch.sum(true_data * predict_value)
fp = torch.sum((1 - true_data) * predict_value)
return (tp / (tp + fp + 1e-8)).item()


def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate recall using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
tp = torch.sum(true_data * predict_value)
fn = torch.sum(true_data * (1 - predict_value))
return (tp / (tp + fn + 1e-8)).item()


def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
"""Calculate Matthews Correlation Coefficient (MCC) using the specified threshold."""
predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
true_neg = 1 - true_data
predict_neg = 1 - predict_value
tp = torch.sum(true_data * predict_value)
tn = torch.sum(true_neg * predict_neg)
fp = torch.sum(true_neg * predict_value)
fn = torch.sum(true_data * predict_neg)
denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
return ((tp * tn - fp * fn) / denominator).item()


def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]:
"""Evaluate multiple metrics: ROC-AUC, AP, accuracy, F1, and MCC."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
auc = roc_auc(true_data, predict_data)
ap = ap_score(true_data, predict_data)
f1, threshold = f1_score_binary(true_data, predict_data)
acc = accuracy_binary(true_data, predict_data, threshold)
precision = precision_binary(true_data, predict_data, threshold)
recall = recall_binary(true_data, predict_data, threshold)
mcc = mcc_binary(true_data, predict_data, threshold)
return auc, ap, acc, f1, mcc, precision, recall


def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
"""Calculate ROC-AUC and Average Precision."""
assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
auc = roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
ap = average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
return auc, ap


# ----------------------------------------------------------------------------
# Loss Functions
# ----------------------------------------------------------------------------

def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
"""Calculate masked binary cross-entropy loss."""
masked = masked.to(torch.bool)
true_data = torch.masked_select(true_data, masked)
pred_data = torch.masked_select(predict_data, masked)
return nn.BCELoss()(pred_data, true_data)


def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
"""Calculate masked mean squared error loss."""
true_data = true_data * masked
predict_data = predict_data * masked
return nn.MSELoss()(predict_data, true_data)


def prototypical_loss(
cell_emb: torch.Tensor,
drug_emb: torch.Tensor,
adj_matrix: Union[torch.Tensor, np.ndarray],
margin: float = 2.0
) -> torch.Tensor:
"""Calculate prototypical loss for positive and negative pairs."""
if isinstance(adj_matrix, torch.Tensor):
adj_matrix = sp.coo_matrix(adj_matrix.detach().cpu().numpy())
pos_pairs = torch.sum(cell_emb[adj_matrix.row] * drug_emb[adj_matrix.col], dim=1)
n_pos = len(adj_matrix.row)
cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device)
drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device)
neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1)
labels = torch.ones_like(pos_pairs, device=cell_emb.device)
return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin)


# ----------------------------------------------------------------------------
# Correlation and Normalization Functions
# ----------------------------------------------------------------------------

def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""Apply z-normalization along the specified dimension."""
mean = tensor.mean(dim=1 - dim, keepdim=True)
std = tensor.std(dim=1 - dim, keepdim=True) + 1e-8
return (tensor - mean) / std


def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Compute correlation matrix between row vectors of two matrices."""
x_center = x - x.mean(dim=1, keepdim=True)
y_center = y - y.mean(dim=1, keepdim=True)
x_std = x.std(dim=1, keepdim=True) + 1e-8
y_std = y.std(dim=1, keepdim=True) + 1e-8
x_norm = x_center / x_std
y_norm = y_center / y_std
corr_matrix = x_norm @ y_norm.t() / (x.size(1) - 1)
return corr_matrix


# ----------------------------------------------------------------------------
# Distance and Similarity Functions
# ----------------------------------------------------------------------------

def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""Calculate Euclidean distance between rows or columns of a tensor."""
tensor_mul = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t())
diag = torch.diag(tensor_mul)
n_diag = diag.size(0)
tensor_diag = diag.repeat(n_diag, 1)
diag = diag.view(n_diag, -1)
dist = torch.sqrt(tensor_diag + diag - 2 * tensor_mul)
return dist


def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor:
"""Calculate exponential similarity based on Euclidean distance."""
if normalize:
tensor = torch_z_normalized(tensor, dim=1)
tensor_dist = torch_euclidean_dist(tensor, dim=0)
return torch.exp(-tensor_dist / (2 * sigma.pow(2)))


def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor:
"""Calculate full kernel matrix from exponential similarity."""
n = exp_dist.shape[0]
ones = torch.ones(n, n, device=exp_dist.device)
diag = torch.diag(ones)
mask_diag = (ones - diag) * exp_dist
mask_diag_sum = mask_diag.sum(dim=1, keepdim=True)
mask_diag = mask_diag / (2 * mask_diag_sum) + 0.5 * diag
return mask_diag


def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor:
"""Calculate sparse kernel using k-nearest neighbors."""
n = exp_dist.shape[0]
maxk = torch.topk(exp_dist, k, dim=1)
mink_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices
exp_dist[torch.arange(n, device=exp_dist.device).view(n, -1), mink_indices] = 0
knn_sum = maxk.values.sum(dim=1, keepdim=True)
return exp_dist / knn_sum


def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor:
"""Apply scaled sigmoid transformation."""
alpha = torch.tensor(alpha, dtype=torch.float32, device=tensor.device)
return torch.sigmoid(alpha * tensor)


# ----------------------------------------------------------------------------
# Data Processing and Helper Functions
# ----------------------------------------------------------------------------

def init_seeds(seed: int = 0) -> None:
"""Initialize random seeds for reproducibility."""
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def distribute_compute(
lr_list: List[float],
wd_list: List[float],
scale_list: List[float],
layer_size: List[int],
sigma_list: List[float],
beta_list: List[float],
workers: int,
id: int
) -> np.ndarray:
"""Distribute hyperparameter combinations across workers."""
all_combinations = [
[lr, wd, sc, la, sg, bt]
for lr, wd, sc, la, sg, bt in it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list)
]
return np.array_split(all_combinations, workers)[id]


def get_fingerprint(cid: int) -> np.ndarray:
"""Retrieve PubChem fingerprint for a given compound CID."""
compound = pcp.Compound.from_cid(cid)
fingerprint = "".join(f"{int(bit, 16):04b}" for bit in compound.fingerprint)
return np.array([int(b) for b in fingerprint], dtype=np.int32)


def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None:
"""Save fingerprints for a list of compound CIDs to disk."""
start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1 if last_cid > 0 else 0
for cid in cid_list[start_idx:]:
fingerprint = get_fingerprint(cid)
np.save(os.path.join(fpath, str(cid)), fingerprint)
print(f"CID {cid} processed successfully.")
time.sleep(1)
if start_idx >= len(cid_list):
print("All compounds have been processed!")


def read_fingerprint_cid(path: str) -> Tuple[np.ndarray, List[int]]:
"""Read fingerprints from .npy files in the specified directory."""
fingerprint = []
cids = []
for file_name in sorted(os.listdir(path)):
if file_name.endswith(".npy"):
cid = int(file_name.split(".")[0])
fing = np.load(os.path.join(path, file_name))
fingerprint.append(fing)
cids.append(cid)
fingerprint = np.array(fingerprint).reshape(-1, 920)
return fingerprint, cids


def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray:
"""Find indices of elements in data_for_index that exist in data_for_cmp."""
return np.where(np.isin(data_for_index, data_for_cmp))[0]


def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix:
"""Convert input matrix to scipy.sparse.coo_matrix format."""
if not sp.isspmatrix_coo(adj_mat):
adj_mat = sp.coo_matrix(adj_mat)
return adj_mat


def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor:
"""Create a mask combining positive and negative edges."""
row = np.hstack((positive.row, negative.row))
col = np.hstack((positive.col, negative.col))
data = np.ones_like(row)
masked = sp.coo_matrix((data, (row, col)), shape=positive.shape).toarray().astype(dtype)
return torch.from_numpy(masked)


def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor:
"""Convert sparse matrix to torch.Tensor, optionally adding identity matrix."""
data = positive + sp.identity(positive.shape[0]) if identity else positive
return torch.from_numpy(data.toarray()).float()


def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray:
"""Remove specified values from a NumPy array."""
indices = [np.where(arr == x)[0][0] for x in obj if x in arr]
return np.delete(arr, indices)


def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame:
"""Convert tensor or array to a pandas DataFrame."""
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return pd.DataFrame(tensor.reshape(1, -1))


def calculate_train_test_index(
response: np.ndarray,
pos_train_index: np.ndarray,
pos_test_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Calculate train and test indices combining positive and negative samples."""
neg_response_index = np.where(response == 0)[0]
neg_test_index = np.random.choice(neg_response_index, pos_test_index.shape[0], replace=False)
neg_train_index = np_delete_value(neg_response_index, neg_test_index)
test_index = np.hstack((pos_test_index, neg_test_index))
train_index = np.hstack((pos_train_index, neg_train_index))
return train_index, test_index


def dir_path(k: int = 1) -> str:
"""Get directory path by traversing k levels up from current file."""
fpath = os.path.realpath(__file__)
dir_name = os.path.dirname(fpath).replace("\\", "/")
for _ in range(k):
dir_name = os.path.dirname(dir_name)
return dir_name


def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray:
"""Extract non-NaN data from a specific row of a DataFrame."""
target = np.array(data.iloc[row], dtype=np.float32)
return target[~np.isnan(target)]


def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame:
"""Add a label column to a DataFrame."""
data = data.copy()
data["label"] = label
return data


def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame:
"""Concatenate multiple DataFrames vertically."""
return pd.concat(data, ignore_index=True)


def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]:
"""Calculate min and max values of a key across multiple DataFrames."""
temp = pd.concat(data, ignore_index=True)
return temp[key].min() - 0.1, temp[key].max() + 0.1


def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str:
"""Remove all occurrences of a substring and join with specified string."""
parts = string.split(sub)
parts = [p for p in parts if p]
return join_str.join(parts)


def get_best_index(fname: str) -> int:
"""Find the index of the AUC closest to the average AUC from a results file."""
with open(fname, "r") as file:
content = file.read().replace("\n", "")
auc_str = content.split("accs")[0].split(":")[1]
auc_str = delete_all_sub_str(auc_str, " ", ",").replace(",]", "]")
aucs = np.array(eval(auc_str))
avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0])
return np.argmin(np.abs(aucs - avg_auc))



Loading…
Cancel
Save