DeepTraCDR: Prediction of Cancer Drug Response using multimodal deep learning with Transformers

model.py 16KB

# model.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention, TransformerEncoder, TransformerEncoderLayer
from utils import torch_corr_x_y, cross_entropy_loss, prototypical_loss
from sklearn.metrics import roc_auc_score, average_precision_score


class ConstructAdjMatrix(nn.Module):
    """Module to construct normalized adjacency matrices for graph-based operations."""

    def __init__(self, original_adj_mat, device="cpu"):
        """
        Initialize the adjacency matrix constructor.

        Args:
            original_adj_mat (np.ndarray or torch.Tensor): Original adjacency matrix.
            device (str): Device to perform computations on (e.g., 'cpu' or 'cuda:0').
        """
        super().__init__()
        # Convert NumPy array to PyTorch tensor if necessary
        if isinstance(original_adj_mat, np.ndarray):
            original_adj_mat = torch.from_numpy(original_adj_mat).float()
        self.adj = original_adj_mat.to(device)
        self.device = device

    def forward(self):
        """
        Compute normalized adjacency matrices for cells and drugs.

        Returns:
            tuple: (agg_cell_lp, agg_drug_lp, self_cell_lp, self_drug_lp) normalized matrices.
        """
        with torch.no_grad():
            # Degree-normalized propagation matrices (D^-1/2 A D^-1/2 style)
            d_x = torch.diag(torch.pow(torch.sum(self.adj, dim=1) + 1, -0.5))
            d_y = torch.diag(torch.pow(torch.sum(self.adj, dim=0) + 1, -0.5))
            agg_cell_lp = torch.mm(torch.mm(d_x, self.adj), d_y)
            agg_drug_lp = torch.mm(torch.mm(d_y, self.adj.T), d_x)
            # Self-loop matrices: diagonal entries 1 + 1/(degree + 1)
            self_cell_lp = torch.diag(torch.add(torch.pow(torch.sum(self.adj, dim=1) + 1, -1), 1))
            self_drug_lp = torch.diag(torch.add(torch.pow(torch.sum(self.adj, dim=0) + 1, -1), 1))
        return (
            agg_cell_lp.to(self.device),
            agg_drug_lp.to(self.device),
            self_cell_lp.to(self.device),
            self_drug_lp.to(self.device)
        )
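

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal check of ConstructAdjMatrix on a synthetic 0/1 cell-drug matrix;
# the 8x5 shape is an arbitrary assumption, not one of the paper's datasets.
def _demo_construct_adj():
    adj = np.random.randint(0, 2, size=(8, 5)).astype(np.float32)  # 8 cells x 5 drugs
    agg_cell, agg_drug, self_cell, self_drug = ConstructAdjMatrix(adj, device="cpu")()
    # Cross-propagation matrices keep the bipartite shape; self-loop matrices are square diagonals
    assert agg_cell.shape == (8, 5) and agg_drug.shape == (5, 8)
    assert self_cell.shape == (8, 8) and self_drug.shape == (5, 5)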
class LoadFeature(nn.Module):
    """Module to load and preprocess cell and drug features."""

    def __init__(self, cell_exprs, drug_fingerprints, device="cpu"):
        """
        Initialize feature loading and preprocessing layers.

        Args:
            cell_exprs (np.ndarray): Cell expression data.
            drug_fingerprints (list): List of drug fingerprint arrays.
            device (str): Device to perform computations on.
        """
        super().__init__()
        self.device = device
        self.cell_exprs = torch.from_numpy(cell_exprs).float().to(device)
        self.drug_fingerprints = [torch.from_numpy(fp).float().to(device) for fp in drug_fingerprints]
        # Projection layers mapping each fingerprint view into a shared 512-d space
        self.drug_proj = nn.ModuleList([
            nn.Sequential(
                nn.Linear(fp.shape[1], 512),
                nn.BatchNorm1d(512),
                nn.GELU(),
                nn.Dropout(0.3)
            ).to(device) for fp in drug_fingerprints
        ])
        # Transformer encoder fusing the projected fingerprint views
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=512,
                nhead=8,
                dim_feedforward=2048,
                batch_first=True
            ),
            num_layers=3
        ).to(device)
        # Normalization layers
        self.cell_norm = nn.LayerNorm(cell_exprs.shape[1]).to(device)
        self.drug_norm = nn.LayerNorm(512).to(device)
        # Encoder for cell expression features
        self.cell_encoder = nn.Sequential(
            nn.Linear(cell_exprs.shape[1], 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512)
        ).to(device)

    def forward(self):
        """
        Process cell and drug features through encoding and transformation.

        Returns:
            tuple: (cell_encoded, drug_feat) encoded cell and drug features.
        """
        # Normalize and encode cell features
        cell_feat = self.cell_norm(self.cell_exprs)
        cell_encoded = self.cell_encoder(cell_feat)
        # Project each fingerprint view, stack as a sequence, and fuse with the transformer
        projected = [proj(fp) for proj, fp in zip(self.drug_proj, self.drug_fingerprints)]
        stacked = torch.stack(projected, dim=1)
        drug_feat = self.transformer(stacked)
        # Mean-pool over views and normalize
        drug_feat = self.drug_norm(drug_feat.mean(dim=1))
        return cell_encoded, drug_feat
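

# --- Illustrative usage sketch (not part of the original module) ---
# Runs LoadFeature on random data: 1000 genes and two fingerprint views whose
# widths (881, 167) merely mimic PubChem/MACCS sizes; all shapes are assumptions.
def _demo_load_feature():
    cell_exprs = np.random.rand(8, 1000).astype(np.float32)
    fingerprints = [np.random.rand(5, 881).astype(np.float32),
                    np.random.rand(5, 167).astype(np.float32)]
    cell_encoded, drug_feat = LoadFeature(cell_exprs, fingerprints)()
    # Both modalities land in the shared 512-d space
    assert cell_encoded.shape == (8, 512) and drug_feat.shape == (5, 512)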
class GEncoder(nn.Module):
    """Graph encoder module for processing cell and drug embeddings."""

    def __init__(self, agg_c_lp, agg_d_lp, self_c_lp, self_d_lp, device="cpu"):
        """
        Initialize the graph encoder.

        Args:
            agg_c_lp (torch.Tensor): Aggregated cell Laplacian matrix.
            agg_d_lp (torch.Tensor): Aggregated drug Laplacian matrix.
            self_c_lp (torch.Tensor): Self-loop cell Laplacian matrix.
            self_d_lp (torch.Tensor): Self-loop drug Laplacian matrix.
            device (str): Device to perform computations on.
        """
        super().__init__()
        self.agg_c_lp = agg_c_lp
        self.agg_d_lp = agg_d_lp
        self.self_c_lp = self_c_lp
        self.self_d_lp = self_d_lp
        self.device = device
        # Cell feature encoder
        self.cell_encoder = nn.Sequential(
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512)
        ).to(device)
        # Drug feature encoder
        self.drug_encoder = nn.Sequential(
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512)
        ).to(device)
        # Cross-attention from cells (queries) to drugs (keys/values)
        self.attention = MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True).to(device)
        self.residual = nn.Linear(512, 512).to(device)
        # Fusion head (defined but not used in forward)
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.2)
        ).to(device)

    def forward(self, cell_f, drug_f):
        """
        Encode cell and drug features using graph-based aggregation and attention.

        Args:
            cell_f (torch.Tensor): Cell features.
            drug_f (torch.Tensor): Drug features.

        Returns:
            tuple: (cell_emb, drug_emb) encoded embeddings.
        """
        # Aggregate features over the bipartite graph (cross-neighbors + self-loops)
        cell_agg = torch.mm(self.agg_c_lp, drug_f) + torch.mm(self.self_c_lp, cell_f)
        drug_agg = torch.mm(self.agg_d_lp, cell_f) + torch.mm(self.self_d_lp, drug_f)
        # Encode aggregated features
        cell_fc = self.cell_encoder(cell_agg)
        drug_fc = self.drug_encoder(drug_agg)
        # Cross-attention: cells attend to drug representations
        attn_output, _ = self.attention(
            query=cell_fc.unsqueeze(0),
            key=drug_fc.unsqueeze(0),
            value=drug_fc.unsqueeze(0)
        )
        attn_output = attn_output.squeeze(0)
        cell_emb = cell_fc + self.residual(attn_output)
        # Apply final activation
        cell_emb = F.gelu(cell_emb)
        drug_emb = F.gelu(drug_fc)
        return cell_emb, drug_emb
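

# --- Illustrative usage sketch (not part of the original module) ---
# Wires ConstructAdjMatrix outputs into GEncoder with random 512-d features,
# mirroring what LoadFeature produces; the 8x5 graph is an assumption.
def _demo_gencoder():
    adj = np.random.randint(0, 2, size=(8, 5)).astype(np.float32)
    agg_c, agg_d, self_c, self_d = ConstructAdjMatrix(adj)()
    encoder = GEncoder(agg_c, agg_d, self_c, self_d)
    cell_emb, drug_emb = encoder(torch.randn(8, 512), torch.randn(5, 512))
    assert cell_emb.shape == (8, 512) and drug_emb.shape == (5, 512)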
class GDecoder(nn.Module):
    """Decoder module to predict interaction scores from embeddings."""

    def __init__(self, emb_dim, gamma):
        """
        Initialize the decoder.

        Args:
            emb_dim (int): Embedding dimension.
            gamma (float): Scaling factor for output scores.
        """
        super().__init__()
        self.gamma = gamma
        self.decoder = nn.Sequential(
            nn.Linear(2 * emb_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 1)
        )
        # Learnable blend between MLP scores and the correlation term
        self.corr_weight = nn.Parameter(torch.tensor(0.5))

    def forward(self, cell_emb, drug_emb):
        """
        Decode embeddings to predict interaction scores.

        Args:
            cell_emb (torch.Tensor): Cell embeddings.
            drug_emb (torch.Tensor): Drug embeddings.

        Returns:
            torch.Tensor: Predicted interaction scores.
        """
        # Expand embeddings to every (cell, drug) pair
        cell_exp = cell_emb.unsqueeze(1).repeat(1, drug_emb.size(0), 1)
        drug_exp = drug_emb.unsqueeze(0).repeat(cell_emb.size(0), 1, 1)
        combined = torch.cat([cell_exp, drug_exp], dim=-1)
        # Decode pairwise embeddings into raw scores
        scores = self.decoder(combined.view(-1, 2 * cell_emb.size(1))).view(cell_emb.size(0), drug_emb.size(0))
        corr = torch_corr_x_y(cell_emb, drug_emb)
        # Blend scores with the correlation matrix and squash to (0, 1)
        return torch.sigmoid(self.gamma * (self.corr_weight * scores + (1 - self.corr_weight) * corr))
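

# --- Illustrative usage sketch (not part of the original module) ---
# Decodes random embeddings into a score matrix; gamma=8.0 is an arbitrary
# scale chosen for the example, not a tuned value. Relies on utils.torch_corr_x_y
# returning a (n_cells, n_drugs) correlation matrix, as the decoder assumes.
def _demo_gdecoder():
    decoder = GDecoder(emb_dim=512, gamma=8.0)
    scores = decoder(torch.randn(8, 512), torch.randn(5, 512))
    assert scores.shape == (8, 5)  # one sigmoid probability per cell-drug pair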
class DeepTraCDR(nn.Module):
    """DeepTraCDR model for cell-drug interaction prediction."""

    def __init__(self, adj_mat, cell_exprs, drug_finger, layer_size, gamma, device="cpu"):
        """
        Initialize the DeepTraCDR model.

        Args:
            adj_mat (np.ndarray or torch.Tensor): Adjacency matrix.
            cell_exprs (np.ndarray): Cell expression data.
            drug_finger (list): Drug fingerprints.
            layer_size (list): Layer sizes for the model.
            gamma (float): Scaling factor for decoder.
            device (str): Device to perform computations on.
        """
        super().__init__()
        self.device = device
        if isinstance(adj_mat, np.ndarray):
            adj_mat = torch.from_numpy(adj_mat).float()
        self.adj_mat = adj_mat.to(device)
        # Initialize submodules
        self.construct_adj = ConstructAdjMatrix(self.adj_mat, device=device)
        self.load_feat = LoadFeature(cell_exprs, drug_finger, device=device)
        # Compute fixed adjacency matrices once
        agg_c, agg_d, self_c, self_d = self.construct_adj()
        self.encoder = GEncoder(agg_c, agg_d, self_c, self_d, device=device).to(device)
        self.decoder = GDecoder(512, gamma).to(device)  # emb_dim fixed to 512

    def forward(self):
        """
        Forward pass through the DeepTraCDR model.

        Returns:
            tuple: (predicted_scores, cell_emb, drug_emb) predicted scores and embeddings.
        """
        cell_f, drug_f = self.load_feat()
        cell_emb, drug_emb = self.encoder(cell_f, drug_f)
        return self.decoder(cell_emb, drug_emb), cell_emb, drug_emb

    def get_cell_embeddings(self):
        """
        Retrieve cell embeddings from the model.

        Returns:
            torch.Tensor: Cell embeddings.
        """
        cell_f, drug_f = self.load_feat()
        cell_emb, _ = self.encoder(cell_f, drug_f)
        return cell_emb
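

# --- Illustrative usage sketch (not part of the original module) ---
# End-to-end forward pass on synthetic data. layer_size is accepted but unused
# by the model (emb_dim is fixed to 512), so an empty list is passed here.
def _demo_deeptracdr():
    adj = np.random.randint(0, 2, size=(8, 5)).astype(np.float32)
    cell_exprs = np.random.rand(8, 1000).astype(np.float32)
    drug_finger = [np.random.rand(5, 881).astype(np.float32)]
    model = DeepTraCDR(adj, cell_exprs, drug_finger, layer_size=[], gamma=8.0)
    scores, cell_emb, drug_emb = model()
    assert scores.shape == (8, 5) and cell_emb.shape == (8, 512)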
class Optimizer:
    """Optimizer class for training the DeepTraCDR model with early stopping."""

    def __init__(self, model, train_data, test_data, test_mask, train_mask, adj_matrix, evaluate_fun,
                 lr=0.001, wd=1e-05, epochs=200, test_freq=20, patience=50, device="cpu"):
        """
        Initialize the optimizer.

        Args:
            model (nn.Module): DeepTraCDR model to optimize.
            train_data (torch.Tensor): Training data.
            test_data (torch.Tensor): Test data.
            test_mask (torch.Tensor): Test mask.
            train_mask (torch.Tensor): Training mask.
            adj_matrix (torch.Tensor): Adjacency matrix.
            evaluate_fun (callable): Evaluation function.
            lr (float): Learning rate.
            wd (float): Weight decay.
            epochs (int): Number of training epochs.
            test_freq (int): Frequency of evaluation.
            patience (int): Patience for early stopping.
            device (str): Device to perform computations on.
        """
        self.model = model.to(device)
        self.train_data = train_data.float().to(device)
        self.test_data = test_data.float().to(device)
        self.train_mask = train_mask.to(device)
        self.test_mask_bool = test_mask.to(device).bool()
        self.adj_matrix = adj_matrix.to(device)
        self.evaluate_fun = evaluate_fun
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=wd)
        self.epochs = epochs
        self.test_freq = test_freq
        self.patience = patience
        self.best_auc = 0.0
        self.best_auprc = 0.0
        self.best_weights = None
        self.counter = 0
        self.device = device
        self.best_epoch_auc = None
        self.best_epoch_auprc = None

    def train(self):
        """
        Train the model with early stopping and evaluate performance.

        Returns:
            tuple: (true_data, final_pred_masked, best_auc, best_auprc) evaluation results.
        """
        true_data = torch.masked_select(self.test_data, self.test_mask_bool).cpu().numpy()
        for epoch in range(self.epochs):
            self.model.train()
            # Forward pass and combined loss (cross-entropy + prototypical)
            pred_train, cell_emb, drug_emb = self.model()
            ce_loss = cross_entropy_loss(self.train_data, pred_train, self.train_mask)
            proto_loss = prototypical_loss(cell_emb, drug_emb, self.adj_matrix)
            total_loss = 0.7 * ce_loss + 0.3 * proto_loss
            # Backpropagation
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            self.model.eval()
            with torch.no_grad():
                # Evaluate training performance (masked_select requires a boolean mask)
                train_pred, _, _ = self.model()
                train_pred_masked = torch.masked_select(train_pred, self.train_mask.bool()).cpu().numpy()
                train_true_data = torch.masked_select(self.train_data, self.train_mask.bool()).cpu().numpy()
                try:
                    train_auc = roc_auc_score(train_true_data, train_pred_masked)
                    train_auprc = average_precision_score(train_true_data, train_pred_masked)
                except ValueError:
                    train_auc, train_auprc = 0.0, 0.0
                # Evaluate test performance
                pred_eval, _, _ = self.model()
                pred_masked = torch.masked_select(pred_eval, self.test_mask_bool).cpu().numpy()
                try:
                    auc = roc_auc_score(true_data, pred_masked)
                    auprc = average_precision_score(true_data, pred_masked)
                except ValueError:
                    auc, auprc = 0.0, 0.0
                # Update best metrics and snapshot weights on improvement
                if auc > self.best_auc:
                    self.best_auc = auc
                    self.best_auprc = auprc
                    self.best_weights = self.model.state_dict().copy()
                    self.counter = 0
                    self.best_epoch_auc = auc
                    self.best_epoch_auprc = auprc
                else:
                    self.counter += 1
            # Log progress
            if epoch % self.test_freq == 0 or epoch == self.epochs - 1:
                print(f"Epoch {epoch}: Loss={total_loss.item():.4f}, Train AUC={train_auc:.4f}, "
                      f"Train AUPRC={train_auprc:.4f}, Test AUC={auc:.4f}, Test AUPRC={auprc:.4f}")
            # Check early stopping
            if self.counter >= self.patience:
                print(f"\nEarly stopping triggered at epoch {epoch}!")
                print(f"No improvement in AUC for {self.patience} consecutive epochs.")
                break
        # Restore best weights
        if self.best_weights is not None:
            self.model.load_state_dict(self.best_weights)
        # Final evaluation with the restored weights
        self.model.eval()
        with torch.no_grad():
            final_pred, _, _ = self.model()
            final_pred_masked = torch.masked_select(final_pred, self.test_mask_bool).cpu().numpy()
            best_auc = roc_auc_score(true_data, final_pred_masked)
            best_auprc = average_precision_score(true_data, final_pred_masked)
        # Log final results
        print("\nBest Metrics After Training (on Test Data):")
        print(f"AUC: {self.best_auc:.4f}")
        print(f"AUPRC: {self.best_auprc:.4f}")
        return true_data, final_pred_masked, self.best_auc, self.best_auprc
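

# --- Illustrative training sketch (not part of the original module) ---
# A minimal end-to-end run on toy data, assuming the loss helpers imported from
# utils above behave as their names suggest. Masks are random and disjoint, and
# the final AUC computation assumes both classes appear in the tiny test split.
if __name__ == "__main__":
    adj = np.random.randint(0, 2, size=(8, 5)).astype(np.float32)
    cell_exprs = np.random.rand(8, 1000).astype(np.float32)
    drug_finger = [np.random.rand(5, 881).astype(np.float32)]
    model = DeepTraCDR(adj, cell_exprs, drug_finger, layer_size=[], gamma=8.0)
    labels = torch.from_numpy(adj)
    train_mask = torch.rand(8, 5) < 0.8  # ~80% of pairs observed during training
    test_mask = ~train_mask
    opt = Optimizer(model, labels, labels, test_mask, train_mask,
                    torch.from_numpy(adj), evaluate_fun=None,
                    epochs=5, test_freq=1, patience=3)
    true_vals, preds, auc, auprc = opt.train()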