You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

DeepDRA.py 8.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import torch.nn as nn
  2. import torch
  3. import torch.optim as optim
  4. from torch.optim import lr_scheduler
  5. from autoencoder import Autoencoder
  6. from evaluation import Evaluation
  7. from mlp import MLP
  8. from utils import MODEL_FOLDER
  9. class DeepDRA(nn.Module):
  10. """
  11. DeepDRA (Deep Drug Response Anticipation) is a neural network model composed of two autoencoders for cell and drug modalities
  12. and an MLP for integrating the encoded features and making predictions.
  13. Parameters:
  14. - cell_modality_sizes (list): Sizes of the cell modality features.
  15. - drug_modality_sizes (list): Sizes of the drug modality features.
  16. - cell_ae_latent_dim (int): Latent dimension for the cell autoencoder.
  17. - drug_ae_latent_dim (int): Latent dimension for the drug autoencoder.
  18. - mlp_input_dim (int): Input dimension for the MLP.
  19. - mlp_output_dim (int): Output dimension for the MLP.
  20. """
  21. def __init__(self, cell_modality_sizes, drug_modality_sizes, cell_ae_latent_dim, drug_ae_latent_dim):
  22. super(DeepDRA, self).__init__()
  23. # Initialize cell and drug autoencoders
  24. self.cell_autoencoder = Autoencoder(sum(cell_modality_sizes), cell_ae_latent_dim)
  25. self.drug_autoencoder = Autoencoder(sum(drug_modality_sizes), drug_ae_latent_dim)
  26. # Store modality sizes
  27. self.cell_modality_sizes = cell_modality_sizes
  28. self.drug_modality_sizes = drug_modality_sizes
  29. # Initialize MLP
  30. self.mlp = MLP(cell_ae_latent_dim+drug_ae_latent_dim, 1)
  31. def forward(self, cell_x, drug_x):
  32. """
  33. Forward pass of the DeepDRA model.
  34. Parameters:
  35. - cell_x (torch.Tensor): Input tensor for cell modality.
  36. - drug_x (torch.Tensor): Input tensor for drug modality.
  37. Returns:
  38. - cell_decoded (torch.Tensor): Decoded tensor for the cell modality.
  39. - drug_decoded (torch.Tensor): Decoded tensor for the drug modality.
  40. - mlp_output (torch.Tensor): Output tensor from the MLP.
  41. """
  42. # Encode and decode cell modality
  43. cell_encoded = self.cell_autoencoder.encoder(cell_x)
  44. cell_decoded = self.cell_autoencoder.decoder(cell_encoded)
  45. # Encode and decode drug modality
  46. drug_encoded = self.drug_autoencoder.encoder(drug_x)
  47. drug_decoded = self.drug_autoencoder.decoder(drug_encoded)
  48. # Concatenate encoded cell and drug features and pass through MLP
  49. mlp_output = self.mlp(torch.cat((cell_encoded, drug_encoded), 1))
  50. return cell_decoded, drug_decoded, mlp_output
  51. def compute_l1_loss(self, w):
  52. """
  53. Computes L1 regularization loss.
  54. Parameters:
  55. - w (torch.Tensor): Input tensor.
  56. Returns:
  57. - loss (torch.Tensor): L1 regularization loss.
  58. """
  59. return torch.abs(w).sum()
  60. def compute_l2_loss(self, w):
  61. """
  62. Computes L2 regularization loss.
  63. Parameters:
  64. - w (torch.Tensor): Input tensor.
  65. Returns:
  66. - loss (torch.Tensor): L2 regularization loss.
  67. """
  68. return torch.square(w).sum()
  69. def train(model, train_loader, val_loader, num_epochs,class_weights):
  70. """
  71. Trains the DeepDRA (Deep Drug Response Anticipation) model.
  72. Parameters:
  73. - model (DeepDRA): The DeepDRA model to be trained.
  74. - train_loader (DataLoader): DataLoader for the training dataset.
  75. - num_epochs (int): Number of training epochs.
  76. """
  77. autoencoder_loss_fn = nn.MSELoss()
  78. mlp_loss_fn = nn.BCELoss()
  79. train_accuracies = []
  80. val_accuracies = []
  81. train_loss = []
  82. val_loss = []
  83. mlp_optimizer = optim.Adam(model.parameters(), lr=0.0005,)
  84. scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True)
  85. # Define weight parameters for each loss term
  86. cell_ae_weight = 1.0
  87. drug_ae_weight = 1.0
  88. mlp_weight = 1.0
  89. for epoch in range(num_epochs):
  90. model.train()
  91. total_train_loss = 0.0
  92. train_correct = 0
  93. train_total_samples = 0
  94. for batch_idx, (cell_data, drug_data, target) in enumerate(train_loader):
  95. mlp_optimizer.zero_grad()
  96. # Forward pass
  97. cell_decoded_output, drug_decoded_output, mlp_output = model(cell_data, drug_data)
  98. # Compute class weights for the current batch
  99. # batch_class_weights = class_weights[target.long()]
  100. # mlp_loss_fn = nn.BCEWithLogitsLoss(weight=batch_class_weights)
  101. # Compute losses
  102. cell_ae_loss = cell_ae_weight * autoencoder_loss_fn(cell_decoded_output, cell_data)
  103. drug_ae_loss = drug_ae_weight * autoencoder_loss_fn(drug_decoded_output, drug_data)
  104. mlp_loss = mlp_weight * mlp_loss_fn(mlp_output, target)
  105. # Total loss is the sum of autoencoder losses and MLP loss
  106. total_loss = drug_ae_loss + cell_ae_loss + mlp_loss
  107. # Backward pass and optimization
  108. total_loss.backward()
  109. torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
  110. mlp_optimizer.step()
  111. total_train_loss += total_loss.item()
  112. # Calculate accuracy
  113. train_predictions = torch.round(mlp_output)
  114. train_correct += (train_predictions == target).sum().item()
  115. train_total_samples += target.size(0)
  116. avg_train_loss = total_train_loss / len(train_loader)
  117. train_loss.append(avg_train_loss)
  118. # Validation
  119. model.eval()
  120. total_val_loss = 0.0
  121. correct = 0
  122. total_samples = 0
  123. with torch.no_grad():
  124. for val_batch_idx, (cell_data_val, drug_data_val, val_target) in enumerate(val_loader):
  125. cell_decoded_output_val, drug_decoded_output_val, mlp_output_val = model(cell_data_val, drug_data_val)
  126. # batch_class_weights = class_weights[val_target.long()]
  127. # mlp_loss_fn = nn.BCEWithLogitsLoss(weight=batch_class_weights)
  128. # Compute losses
  129. cell_ae_loss_val = cell_ae_weight * autoencoder_loss_fn(cell_decoded_output_val, cell_data_val)
  130. drug_ae_loss_val = drug_ae_weight * autoencoder_loss_fn(drug_decoded_output_val, drug_data_val)
  131. mlp_loss_val = mlp_weight * mlp_loss_fn(mlp_output_val, val_target)
  132. # Total loss is the sum of autoencoder losses and MLP loss
  133. total_val_loss = (drug_ae_loss_val + cell_ae_loss_val + mlp_loss_val).item()
  134. # Calculate accuracy
  135. val_predictions = torch.round(mlp_output_val)
  136. correct += (val_predictions == val_target).sum().item()
  137. total_samples += val_target.size(0)
  138. avg_val_loss = total_val_loss / len(val_loader)
  139. val_loss.append(avg_val_loss)
  140. train_accuracy = train_correct / train_total_samples
  141. train_accuracies.append(train_accuracy)
  142. val_accuracy = correct / total_samples
  143. val_accuracies.append(val_accuracy)
  144. print(
  145. 'Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Train Accuracy: {:.4f}, Val Accuracy: {:.4f}'.format(
  146. epoch + 1, num_epochs, avg_train_loss, avg_val_loss, train_accuracy,
  147. val_accuracy))
  148. # Learning rate scheduler step
  149. scheduler.step(total_train_loss)
  150. # Save the trained model
  151. torch.save(model.state_dict(), MODEL_FOLDER + 'DeepDRA.pth')
  152. def test(model, test_loader, reverse=False):
  153. """
  154. Tests the given model on the test dataset using evaluation metrics.
  155. Parameters:
  156. - model: The trained model to be evaluated.
  157. - test_loader: DataLoader for the test dataset.
  158. - reverse (bool): If True, reverse the predictions for evaluation.
  159. Returns:
  160. - result: The evaluation result based on the chosen metrics.
  161. """
  162. # Set model to evaluation mode
  163. model.eval()
  164. # Initialize lists to store predictions and ground truth labels
  165. all_predictions = []
  166. all_labels = []
  167. # Iterate over the test dataset
  168. for i, (test_cell_loader, test_drug_loader, labels) in enumerate(test_loader):
  169. # Forward pass through the model
  170. with torch.no_grad():
  171. decoded_cell_output, decoded_drug_output, mlp_output = model(test_cell_loader, test_drug_loader)
  172. # Apply reverse if specified
  173. predictions = 1 - mlp_output if reverse else mlp_output
  174. # Evaluate the predictions using the specified metrics
  175. result = Evaluation.evaluate(labels, predictions)
  176. return result