
DeepDRA.py

import torch.nn as nn
import torch
import torch.optim as optim
from torch.optim import lr_scheduler

from autoencoder import Autoencoder
from evaluation import Evaluation
from mlp import MLP
from utils import MODEL_FOLDER


class DeepDRA(nn.Module):
    """
    DeepDRA (Deep Drug Response Anticipation) is a neural network model composed of two autoencoders
    for the cell and drug modalities and an MLP for integrating the encoded features and making predictions.

    Parameters:
    - cell_modality_sizes (list): Sizes of the cell modality features.
    - drug_modality_sizes (list): Sizes of the drug modality features.
    - cell_ae_latent_dim (int): Latent dimension for the cell autoencoder.
    - drug_ae_latent_dim (int): Latent dimension for the drug autoencoder.
    - mlp_input_dim (int): Input dimension for the MLP.
    - mlp_output_dim (int): Output dimension for the MLP.
    """

    def __init__(self, cell_modality_sizes, drug_modality_sizes, cell_ae_latent_dim, drug_ae_latent_dim,
                 mlp_input_dim, mlp_output_dim):
        super(DeepDRA, self).__init__()

        # Initialize cell and drug autoencoders
        self.cell_autoencoder = Autoencoder(sum(cell_modality_sizes), cell_ae_latent_dim)
        self.drug_autoencoder = Autoencoder(sum(drug_modality_sizes), drug_ae_latent_dim)

        # Store modality sizes
        self.cell_modality_sizes = cell_modality_sizes
        self.drug_modality_sizes = drug_modality_sizes

        # Initialize MLP
        self.mlp = MLP(mlp_input_dim, mlp_output_dim)

    def forward(self, cell_x, drug_x):
        """
        Forward pass of the DeepDRA model.

        Parameters:
        - cell_x (torch.Tensor): Input tensor for the cell modality.
        - drug_x (torch.Tensor): Input tensor for the drug modality.

        Returns:
        - cell_decoded (torch.Tensor): Decoded tensor for the cell modality.
        - drug_decoded (torch.Tensor): Decoded tensor for the drug modality.
        - mlp_output (torch.Tensor): Output tensor from the MLP.
        """
        # Encode and decode the cell modality
        cell_encoded = self.cell_autoencoder.encoder(cell_x)
        cell_decoded = self.cell_autoencoder.decoder(cell_encoded)

        # Encode and decode the drug modality
        drug_encoded = self.drug_autoencoder.encoder(drug_x)
        drug_decoded = self.drug_autoencoder.decoder(drug_encoded)

        # Concatenate the encoded cell and drug features and pass them through the MLP
        mlp_output = self.mlp(torch.cat((cell_encoded, drug_encoded), 1))

        return cell_decoded, drug_decoded, mlp_output

    def compute_l1_loss(self, w):
        """
        Computes the L1 regularization loss.

        Parameters:
        - w (torch.Tensor): Input tensor.

        Returns:
        - loss (torch.Tensor): L1 regularization loss.
        """
        return torch.abs(w).sum()

    def compute_l2_loss(self, w):
        """
        Computes the L2 regularization loss.

        Parameters:
        - w (torch.Tensor): Input tensor.

        Returns:
        - loss (torch.Tensor): L2 regularization loss.
        """
        return torch.square(w).sum()
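

# Illustrative sketch (not part of the original file): how the pieces above fit together.
# The modality sizes and latent dimensions below are made-up assumptions; the only
# constraint implied by forward() is that mlp_input_dim equals
# cell_ae_latent_dim + drug_ae_latent_dim, since the two encodings are concatenated.
def _example_forward_pass():
    cell_sizes = [700, 300]  # assumed sizes of two cell modalities
    drug_sizes = [200, 56]   # assumed sizes of two drug modalities
    model = DeepDRA(cell_sizes, drug_sizes,
                    cell_ae_latent_dim=50, drug_ae_latent_dim=25,
                    mlp_input_dim=75, mlp_output_dim=1)

    cell_x = torch.randn(8, sum(cell_sizes))  # synthetic batch of 8 cell profiles
    drug_x = torch.randn(8, sum(drug_sizes))  # synthetic batch of 8 drug profiles
    cell_decoded, drug_decoded, prediction = model(cell_x, drug_x)

    # The regularization helpers are not called by train() below, but one possible use
    # is an L1 penalty summed over all model parameters:
    l1_penalty = sum(model.compute_l1_loss(p) for p in model.parameters())
    print(cell_decoded.shape, drug_decoded.shape, prediction.shape, l1_penalty.item())
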
def train(model, train_loader, num_epochs):
    """
    Trains the DeepDRA (Deep Drug Response Anticipation) model.

    Parameters:
    - model (DeepDRA): The DeepDRA model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - num_epochs (int): Number of training epochs.
    """
    autoencoder_loss_fn = nn.MSELoss()
    mlp_loss_fn = nn.BCELoss()

    mlp_optimizer = optim.Adam(model.parameters(), lr=0.0005)
    scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True)

    for epoch in range(num_epochs):
        for batch_idx, (cell_data, drug_data, target) in enumerate(train_loader):
            mlp_optimizer.zero_grad()

            # Forward pass
            cell_decoded_output, drug_decoded_output, mlp_output = model(cell_data, drug_data)

            # Compute losses
            cell_ae_loss = autoencoder_loss_fn(cell_decoded_output, cell_data)
            drug_ae_loss = autoencoder_loss_fn(drug_decoded_output, drug_data)
            mlp_loss = mlp_loss_fn(mlp_output, target)

            # Total loss is the sum of the autoencoder losses and the MLP loss
            total_loss = drug_ae_loss + cell_ae_loss + mlp_loss

            # Backward pass and optimization
            total_loss.backward()
            mlp_optimizer.step()

            # Print progress
            if batch_idx % 200 == 0:
                print('Epoch [{}/{}], Total Loss: {:.4f}'.format(
                    epoch + 1, num_epochs, total_loss.item()))

        # Learning rate scheduler step (uses the loss of the last batch in the epoch)
        scheduler.step(total_loss)

    # Save the trained model
    torch.save(model.state_dict(), MODEL_FOLDER + 'DeepDRA.pth')
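

# Illustrative sketch (not part of the original file): building a DataLoader whose
# batches match the (cell_data, drug_data, target) unpacking inside train(). All
# tensors and dimensions here are synthetic assumptions; targets are floats in
# [0, 1] because train() applies nn.BCELoss to the MLP output.
def _example_training_run():
    from torch.utils.data import TensorDataset, DataLoader

    cell_sizes, drug_sizes = [700, 300], [200, 56]
    model = DeepDRA(cell_sizes, drug_sizes,
                    cell_ae_latent_dim=50, drug_ae_latent_dim=25,
                    mlp_input_dim=75, mlp_output_dim=1)

    cell_x = torch.randn(64, sum(cell_sizes))
    drug_x = torch.randn(64, sum(drug_sizes))
    targets = torch.randint(0, 2, (64, 1)).float()

    loader = DataLoader(TensorDataset(cell_x, drug_x, targets), batch_size=16, shuffle=True)
    train(model, loader, num_epochs=5)  # also writes MODEL_FOLDER + 'DeepDRA.pth'
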
def test(model, test_loader, reverse=False):
    """
    Tests the given model on the test dataset using evaluation metrics.

    Parameters:
    - model: The trained model to be evaluated.
    - test_loader: DataLoader for the test dataset.
    - reverse (bool): If True, reverse the predictions for evaluation.

    Returns:
    - result: The evaluation result based on the chosen metrics.
    """
    # Set model to evaluation mode
    model.eval()

    # Initialize lists to store predictions and ground truth labels
    all_predictions = []
    all_labels = []

    # Iterate over the test dataset
    for i, (test_cell_data, test_drug_data, labels) in enumerate(test_loader):
        # Forward pass through the model
        with torch.no_grad():
            decoded_cell_output, decoded_drug_output, mlp_output = model(test_cell_data, test_drug_data)

        # Apply reverse if specified
        predictions = 1 - mlp_output if reverse else mlp_output

        # Collect predictions and ground-truth labels across batches
        all_predictions.append(predictions)
        all_labels.append(labels)

    # Evaluate the accumulated predictions using the specified metrics
    result = Evaluation.evaluate(torch.cat(all_labels), torch.cat(all_predictions))

    return result
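

# Illustrative sketch (not part of the original file): evaluating a trained model with
# test(). The DataLoader layout mirrors the one assumed for training; the metrics in
# `result` come from the local Evaluation.evaluate implementation, whose exact outputs
# are not shown in this file.
def _example_evaluation(model):
    from torch.utils.data import TensorDataset, DataLoader

    cell_x = torch.randn(32, 1000)  # assumed sum(cell_modality_sizes) == 1000
    drug_x = torch.randn(32, 256)   # assumed sum(drug_modality_sizes) == 256
    labels = torch.randint(0, 2, (32, 1)).float()

    test_loader = DataLoader(TensorDataset(cell_x, drug_x, labels), batch_size=32)
    result = test(model, test_loader)
    return result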