DeepTraCDR: Prediction of Cancer Drug Response using multimodal deep learning with Transformers

utils.py (16 KB)

import os
import time
from typing import Tuple, List, Union, Optional
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
import pubchempy as pcp
import scipy.sparse as sp
from sklearn.metrics import roc_auc_score, average_precision_score
import itertools as it
import torch.nn.functional as F
# ----------------------------------------------------------------------------
# Model Evaluation Functions
# ----------------------------------------------------------------------------
def roc_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Calculate the ROC-AUC score for binary classification."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    return roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())


def ap_score(true_data: torch.Tensor, predict_data: torch.Tensor) -> float:
    """Calculate Average Precision (area under the precision-recall curve)."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    return average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())


def f1_score_binary(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Calculate the maximum F1 score over all candidate thresholds and the threshold that attains it."""
    thresholds = torch.unique(predict_data)
    n_samples = true_data.size(0)
    ones = torch.ones((thresholds.size(0), n_samples), device=true_data.device)
    zeros = torch.zeros((thresholds.size(0), n_samples), device=true_data.device)
    predict_value = torch.where(predict_data.view(1, -1) >= thresholds.view(-1, 1), ones, zeros)
    # tpn = TP + TN per threshold; with n = TP + TN + FP + FN this gives
    # F1 = 2*TP / (2*TP + FP + FN) = 2*TP / (n + 2*TP - (TP + TN)).
    tpn = torch.sum(torch.where(predict_value == true_data.view(1, -1), ones, zeros), dim=1)
    tp = torch.sum(predict_value * true_data.view(1, -1), dim=1)
    scores = (2 * tp) / (n_samples + 2 * tp - tpn)
    max_f1_score = torch.max(scores)
    threshold = thresholds[torch.argmax(scores)]
    return max_f1_score.item(), threshold.item()


def accuracy_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate accuracy using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    correct = torch.sum(predict_value == true_data).float()
    return (correct / true_data.size(0)).item()


def precision_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate precision using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    tp = torch.sum(true_data * predict_value)
    fp = torch.sum((1 - true_data) * predict_value)
    return (tp / (tp + fp + 1e-8)).item()


def recall_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate recall using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    tp = torch.sum(true_data * predict_value)
    fn = torch.sum(true_data * (1 - predict_value))
    return (tp / (tp + fn + 1e-8)).item()


def mcc_binary(true_data: torch.Tensor, predict_data: torch.Tensor, threshold: float) -> float:
    """Calculate the Matthews correlation coefficient (MCC) using the specified threshold."""
    predict_value = torch.where(predict_data >= threshold, 1.0, 0.0)
    true_neg = 1 - true_data
    predict_neg = 1 - predict_value
    tp = torch.sum(true_data * predict_value)
    tn = torch.sum(true_neg * predict_neg)
    fp = torch.sum(true_neg * predict_value)
    fn = torch.sum(true_data * predict_neg)
    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
    return ((tp * tn - fp * fn) / denominator).item()


def evaluate_all(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, ...]:
    """Evaluate ROC-AUC, AP, accuracy, F1, MCC, precision, and recall in one call."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    auc = roc_auc(true_data, predict_data)
    ap = ap_score(true_data, predict_data)
    f1, threshold = f1_score_binary(true_data, predict_data)
    acc = accuracy_binary(true_data, predict_data, threshold)
    precision = precision_binary(true_data, predict_data, threshold)
    recall = recall_binary(true_data, predict_data, threshold)
    mcc = mcc_binary(true_data, predict_data, threshold)
    return auc, ap, acc, f1, mcc, precision, recall


def evaluate_auc(true_data: torch.Tensor, predict_data: torch.Tensor) -> Tuple[float, float]:
    """Calculate ROC-AUC and Average Precision."""
    assert torch.all((true_data >= 0) & (true_data <= 1)), "True labels must be 0 or 1"
    auc = roc_auc_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
    ap = average_precision_score(true_data.cpu().numpy(), predict_data.cpu().numpy())
    return auc, ap
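

# A minimal sanity check for the metric helpers above (an addition, not part of
# the original file): the tensor sizes and random labels are illustrative
# assumptions only.
def _demo_evaluate_all() -> None:
    torch.manual_seed(0)
    labels = torch.randint(0, 2, (100,)).float()  # toy binary ground truth
    scores = torch.rand(100)                      # toy model scores in [0, 1]
    auc, ap, acc, f1, mcc, precision, recall = evaluate_all(labels, scores)
    print(f"AUC={auc:.3f} AP={ap:.3f} ACC={acc:.3f} F1={f1:.3f} "
          f"MCC={mcc:.3f} P={precision:.3f} R={recall:.3f}")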
# ----------------------------------------------------------------------------
# Loss Functions
# ----------------------------------------------------------------------------
def cross_entropy_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Calculate masked binary cross-entropy loss over the observed entries only."""
    masked = masked.to(torch.bool)
    true_data = torch.masked_select(true_data, masked)
    pred_data = torch.masked_select(predict_data, masked)
    return nn.BCELoss()(pred_data, true_data)


def mse_loss(true_data: torch.Tensor, predict_data: torch.Tensor, masked: torch.Tensor) -> torch.Tensor:
    """Calculate masked mean squared error loss (unobserved entries are zeroed out)."""
    true_data = true_data * masked
    predict_data = predict_data * masked
    return nn.MSELoss()(predict_data, true_data)


def prototypical_loss(
    cell_emb: torch.Tensor,
    drug_emb: torch.Tensor,
    adj_matrix: Union[torch.Tensor, np.ndarray],
    margin: float = 2.0
) -> torch.Tensor:
    """Calculate a margin ranking loss between positive pairs and randomly sampled negative pairs."""
    # Convert dense input (torch.Tensor or np.ndarray) to COO format so that
    # the .row/.col index arrays are available.
    if isinstance(adj_matrix, torch.Tensor):
        adj_matrix = adj_matrix.detach().cpu().numpy()
    if not sp.isspmatrix_coo(adj_matrix):
        adj_matrix = sp.coo_matrix(adj_matrix)
    pos_pairs = torch.sum(cell_emb[adj_matrix.row] * drug_emb[adj_matrix.col], dim=1)
    n_pos = len(adj_matrix.row)
    # Negative pairs are drawn uniformly at random; occasional collisions with
    # positive pairs are tolerated.
    cell_neg = torch.randint(0, cell_emb.size(0), (n_pos,), device=cell_emb.device)
    drug_neg = torch.randint(0, drug_emb.size(0), (n_pos,), device=drug_emb.device)
    neg_pairs = torch.sum(cell_emb[cell_neg] * drug_emb[drug_neg], dim=1)
    labels = torch.ones_like(pos_pairs, device=cell_emb.device)
    return F.margin_ranking_loss(pos_pairs, neg_pairs, labels, margin=margin)
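

# A small illustrative sketch (added, not from the original file) of how the
# masked losses are meant to be called: `adj` plays the role of the observed
# cell-by-drug response matrix and `obs_mask` marks which entries are known.
def _demo_masked_losses() -> None:
    torch.manual_seed(0)
    adj = torch.randint(0, 2, (4, 6)).float()       # toy binary response matrix
    pred = torch.rand(4, 6)                         # model output in (0, 1)
    obs_mask = torch.randint(0, 2, (4, 6)).float()  # 1 = observed, 0 = unknown
    print("masked BCE:", cross_entropy_loss(adj, pred, obs_mask).item())
    print("masked MSE:", mse_loss(adj, pred, obs_mask).item())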
# ----------------------------------------------------------------------------
# Correlation and Normalization Functions
# ----------------------------------------------------------------------------
def torch_z_normalized(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Z-normalize the tensor; `dim` selects which axis indexes the vectors being
    normalized, so the mean and std are computed along the opposite axis (1 - dim)."""
    mean = tensor.mean(dim=1 - dim, keepdim=True)
    std = tensor.std(dim=1 - dim, keepdim=True) + 1e-8
    return (tensor - mean) / std


def torch_corr_x_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute the Pearson correlation matrix between the row vectors of two matrices."""
    x_center = x - x.mean(dim=1, keepdim=True)
    y_center = y - y.mean(dim=1, keepdim=True)
    x_std = x.std(dim=1, keepdim=True) + 1e-8
    y_std = y.std(dim=1, keepdim=True) + 1e-8
    x_norm = x_center / x_std
    y_norm = y_center / y_std
    corr_matrix = x_norm @ y_norm.t() / (x.size(1) - 1)
    return corr_matrix
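

# Quick illustrative check (an addition, not original code): for x == y the
# row-wise correlation from torch_corr_x_y should match numpy's corrcoef.
def _demo_corr() -> None:
    torch.manual_seed(0)
    x = torch.rand(3, 10)
    ours = torch_corr_x_y(x, x)
    ref = np.corrcoef(x.numpy())
    print("max |diff| vs np.corrcoef:", np.abs(ours.numpy() - ref).max())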
# ----------------------------------------------------------------------------
# Distance and Similarity Functions
# ----------------------------------------------------------------------------
def torch_euclidean_dist(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Calculate pairwise Euclidean distances between rows (dim=0) or columns (dim=1) of a tensor."""
    tensor_mul = torch.mm(tensor.t(), tensor) if dim else torch.mm(tensor, tensor.t())
    diag = torch.diag(tensor_mul)
    n_diag = diag.size(0)
    tensor_diag = diag.repeat(n_diag, 1)
    diag = diag.view(n_diag, -1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; clamping guards against small
    # negative values from floating-point round-off before the square root.
    dist = torch.sqrt(torch.clamp(tensor_diag + diag - 2 * tensor_mul, min=0.0))
    return dist


def exp_similarity(tensor: torch.Tensor, sigma: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Calculate exponential (Gaussian-style) similarity based on Euclidean distance."""
    if normalize:
        tensor = torch_z_normalized(tensor, dim=1)
    tensor_dist = torch_euclidean_dist(tensor, dim=0)
    return torch.exp(-tensor_dist / (2 * sigma.pow(2)))


def full_kernel(exp_dist: torch.Tensor) -> torch.Tensor:
    """Calculate the full kernel matrix from an exponential similarity matrix."""
    n = exp_dist.shape[0]
    ones = torch.ones(n, n, device=exp_dist.device)
    diag = torch.diag(ones)
    # Zero the diagonal, normalize each row's off-diagonal mass to 1/2, and
    # place 1/2 on the diagonal so every row sums to 1.
    mask_diag = (ones - diag) * exp_dist
    mask_diag_sum = mask_diag.sum(dim=1, keepdim=True)
    mask_diag = mask_diag / (2 * mask_diag_sum) + 0.5 * diag
    return mask_diag


def sparse_kernel(exp_dist: torch.Tensor, k: int) -> torch.Tensor:
    """Calculate a sparse kernel that keeps only each row's k nearest neighbors."""
    n = exp_dist.shape[0]
    exp_dist = exp_dist.clone()  # avoid mutating the caller's tensor in place
    maxk = torch.topk(exp_dist, k, dim=1)
    mink_indices = torch.topk(exp_dist, n - k, dim=1, largest=False).indices
    exp_dist[torch.arange(n, device=exp_dist.device).view(n, -1), mink_indices] = 0
    knn_sum = maxk.values.sum(dim=1, keepdim=True)
    return exp_dist / knn_sum


def scale_sigmoid(tensor: torch.Tensor, alpha: float) -> torch.Tensor:
    """Apply a scaled sigmoid transformation, sigmoid(alpha * x)."""
    alpha = torch.tensor(alpha, dtype=torch.float32, device=tensor.device)
    return torch.sigmoid(alpha * tensor)
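

# Illustrative sketch (added, not original): building the dense and kNN-sparse
# similarity kernels for a toy feature matrix; sigma and k are arbitrary choices.
def _demo_kernels() -> None:
    torch.manual_seed(0)
    features = torch.rand(8, 5)
    sim = exp_similarity(features, sigma=torch.tensor(1.0))
    print("full kernel row sums:", full_kernel(sim).sum(dim=1))     # each ~1.0
    print("sparse kernel row sums:", sparse_kernel(sim, k=3).sum(dim=1))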
# ----------------------------------------------------------------------------
# Data Processing and Helper Functions
# ----------------------------------------------------------------------------
def init_seeds(seed: int = 0) -> None:
    """Initialize random seeds for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def distribute_compute(
    lr_list: List[float],
    wd_list: List[float],
    scale_list: List[float],
    layer_size: List[int],
    sigma_list: List[float],
    beta_list: List[float],
    workers: int,
    id: int
) -> np.ndarray:
    """Distribute hyperparameter combinations across workers and return the slice for worker `id`."""
    all_combinations = [
        [lr, wd, sc, la, sg, bt]
        for lr, wd, sc, la, sg, bt in it.product(lr_list, wd_list, scale_list, layer_size, sigma_list, beta_list)
    ]
    return np.array_split(all_combinations, workers)[id]
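

# Tiny usage sketch (added for illustration): split a 2x2x1x1x1x1 hyperparameter
# grid across two workers; worker 0 receives the first half of the combinations.
def _demo_distribute_compute() -> None:
    chunk = distribute_compute(
        lr_list=[1e-3, 1e-4], wd_list=[0.0, 1e-5], scale_list=[1.0],
        layer_size=[512], sigma_list=[2.0], beta_list=[0.5],
        workers=2, id=0,
    )
    print(chunk)  # 2 of the 4 combinations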
def get_fingerprint(cid: int) -> np.ndarray:
    """Retrieve the PubChem fingerprint for a given compound CID as a 0/1 bit array."""
    compound = pcp.Compound.from_cid(cid)
    # The fingerprint is a hex string; expand each hex digit into 4 bits.
    fingerprint = "".join(f"{int(bit, 16):04b}" for bit in compound.fingerprint)
    return np.array([int(b) for b in fingerprint], dtype=np.int32)


def save_fingerprint(cid_list: List[int], last_cid: int, fpath: str) -> None:
    """Save fingerprints for a list of compound CIDs to disk, resuming after `last_cid`."""
    start_idx = np.where(np.array(cid_list) == last_cid)[0][0] + 1 if last_cid > 0 else 0
    for cid in cid_list[start_idx:]:
        fingerprint = get_fingerprint(cid)
        np.save(os.path.join(fpath, str(cid)), fingerprint)
        print(f"CID {cid} processed successfully.")
        time.sleep(1)  # throttle requests to the PubChem API
    if start_idx >= len(cid_list):
        print("All compounds have been processed!")


def read_fingerprint_cid(path: str) -> Tuple[np.ndarray, List[int]]:
    """Read fingerprints from .npy files in the specified directory."""
    fingerprint = []
    cids = []
    for file_name in sorted(os.listdir(path)):
        if file_name.endswith(".npy"):
            cid = int(file_name.split(".")[0])
            fing = np.load(os.path.join(path, file_name))
            fingerprint.append(fing)
            cids.append(cid)
    fingerprint = np.array(fingerprint).reshape(-1, 920)
    return fingerprint, cids
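

# Round-trip sketch (added; uses synthetic data rather than real PubChem CIDs):
# write two fake 920-bit fingerprints to a temp directory and read them back.
def _demo_read_fingerprint_cid() -> None:
    import tempfile
    with tempfile.TemporaryDirectory() as tmp:
        for fake_cid in (11, 22):
            np.save(os.path.join(tmp, str(fake_cid)), np.zeros(920, dtype=np.int32))
        fps, cids = read_fingerprint_cid(tmp)
        print(fps.shape, cids)  # (2, 920) [11, 22]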
def common_data_index(data_for_index: np.ndarray, data_for_cmp: np.ndarray) -> np.ndarray:
    """Find indices of elements in data_for_index that also exist in data_for_cmp."""
    return np.where(np.isin(data_for_index, data_for_cmp))[0]


def to_coo_matrix(adj_mat: Union[np.ndarray, sp.coo_matrix]) -> sp.coo_matrix:
    """Convert the input matrix to scipy.sparse.coo_matrix format."""
    if not sp.isspmatrix_coo(adj_mat):
        adj_mat = sp.coo_matrix(adj_mat)
    return adj_mat


def mask(positive: sp.coo_matrix, negative: sp.coo_matrix, dtype: type = int) -> torch.Tensor:
    """Create a mask marking the union of positive and negative edges."""
    row = np.hstack((positive.row, negative.row))
    col = np.hstack((positive.col, negative.col))
    data = np.ones_like(row)
    masked = sp.coo_matrix((data, (row, col)), shape=positive.shape).toarray().astype(dtype)
    return torch.from_numpy(masked)


def to_tensor(positive: sp.coo_matrix, identity: bool = False) -> torch.Tensor:
    """Convert a sparse matrix to a dense torch.Tensor, optionally adding the identity matrix."""
    data = positive + sp.identity(positive.shape[0]) if identity else positive
    return torch.from_numpy(data.toarray()).float()
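

# Toy sketch (an added example): build an observation mask and a dense label
# tensor from positive/negative edge sets of a 3x3 adjacency matrix.
def _demo_mask_and_tensor() -> None:
    pos = sp.coo_matrix(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 0]]))
    neg = sp.coo_matrix(np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]]))
    print(mask(pos, neg))  # 1 where an edge is observed (positive or negative)
    print(to_tensor(pos))  # dense float tensor of the positive edges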
def np_delete_value(arr: np.ndarray, obj: np.ndarray) -> np.ndarray:
    """Remove the specified values from a NumPy array."""
    indices = [np.where(arr == x)[0][0] for x in obj if x in arr]
    return np.delete(arr, indices)


def translate_result(tensor: Union[torch.Tensor, np.ndarray]) -> pd.DataFrame:
    """Convert a tensor or array to a single-row pandas DataFrame."""
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()
    return pd.DataFrame(tensor.reshape(1, -1))


def calculate_train_test_index(
    response: np.ndarray,
    pos_train_index: np.ndarray,
    pos_test_index: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Build train and test indices by pairing the positive splits with sampled negatives."""
    neg_response_index = np.where(response == 0)[0]
    # Sample as many negative test examples as there are positive ones; the
    # remaining negatives all go to the training split.
    neg_test_index = np.random.choice(neg_response_index, pos_test_index.shape[0], replace=False)
    neg_train_index = np_delete_value(neg_response_index, neg_test_index)
    test_index = np.hstack((pos_test_index, neg_test_index))
    train_index = np.hstack((pos_train_index, neg_train_index))
    return train_index, test_index
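

# Worked example (added): a flat response vector with 4 positives; the first
# three positives train, the last one tests, and negatives are sampled to match.
def _demo_train_test_index() -> None:
    np.random.seed(0)
    response = np.array([1, 0, 1, 0, 1, 0, 0, 1, 0, 0])
    pos_index = np.where(response == 1)[0]
    train_idx, test_idx = calculate_train_test_index(
        response, pos_train_index=pos_index[:3], pos_test_index=pos_index[3:]
    )
    print("train:", train_idx, "test:", test_idx)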
def dir_path(k: int = 1) -> str:
    """Get the directory path obtained by traversing k levels up from the current file."""
    fpath = os.path.realpath(__file__)
    dir_name = os.path.dirname(fpath).replace("\\", "/")
    for _ in range(k):
        dir_name = os.path.dirname(dir_name)
    return dir_name


def extract_row_data(data: pd.DataFrame, row: int) -> np.ndarray:
    """Extract the non-NaN values from a specific row of a DataFrame."""
    target = np.array(data.iloc[row], dtype=np.float32)
    return target[~np.isnan(target)]


def transfer_data(data: pd.DataFrame, label: str) -> pd.DataFrame:
    """Return a copy of the DataFrame with an added label column."""
    data = data.copy()
    data["label"] = label
    return data


def link_data_frame(*data: pd.DataFrame) -> pd.DataFrame:
    """Concatenate multiple DataFrames vertically."""
    return pd.concat(data, ignore_index=True)


def calculate_limit(*data: pd.DataFrame, key: Union[str, int]) -> Tuple[float, float]:
    """Calculate padded min and max values of a column across multiple DataFrames."""
    temp = pd.concat(data, ignore_index=True)
    return temp[key].min() - 0.1, temp[key].max() + 0.1


def delete_all_sub_str(string: str, sub: str, join_str: str = "") -> str:
    """Remove all occurrences of a substring, joining the remaining parts with join_str."""
    parts = string.split(sub)
    parts = [p for p in parts if p]
    return join_str.join(parts)
def get_best_index(fname: str) -> int:
    """Find the index of the AUC closest to the average AUC in a results file."""
    with open(fname, "r") as file:
        content = file.read().replace("\n", "")
    # The file stores a printed list of AUCs before an "accs" section and the
    # average after an "avg_aucs" marker; parse both back out of the text.
    auc_str = content.split("accs")[0].split(":")[1]
    auc_str = delete_all_sub_str(auc_str, " ", ",").replace(",]", "]")
    aucs = np.array(eval(auc_str))
    avg_auc = float(content.split("avg_aucs")[1].split(":")[1].split()[0])
    return int(np.argmin(np.abs(aucs - avg_auc)))
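

# Optional smoke test (added, not in the original utils.py): run the
# illustrative demos above when this module is executed directly.
if __name__ == "__main__":
    init_seeds(0)
    _demo_evaluate_all()
    _demo_masked_losses()
    _demo_corr()
    _demo_kernels()
    _demo_distribute_compute()
    _demo_read_fingerprint_cid()
    _demo_mask_and_tensor()
    _demo_train_test_index()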