123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- import torch.nn as nn
- import torch
- import torch.optim as optim
- from torch.optim import lr_scheduler
- from autoencoder import Autoencoder
- from evaluation import Evaluation
- from mlp import MLP
- from utils import MODEL_FOLDER
-
-
- class DeepDRA(nn.Module):
- """
- DeepDRA (Deep Drug Response Anticipation) is a neural network model composed of two autoencoders for cell and drug modalities
- and an MLP for integrating the encoded features and making predictions.
-
- Parameters:
- - cell_modality_sizes (list): Sizes of the cell modality features.
- - drug_modality_sizes (list): Sizes of the drug modality features.
- - cell_ae_latent_dim (int): Latent dimension for the cell autoencoder.
- - drug_ae_latent_dim (int): Latent dimension for the drug autoencoder.
- - mlp_input_dim (int): Input dimension for the MLP.
- - mlp_output_dim (int): Output dimension for the MLP.
- """
-
- def __init__(self, cell_modality_sizes, drug_modality_sizes, cell_ae_latent_dim, drug_ae_latent_dim, mlp_input_dim,
- mlp_output_dim):
- super(DeepDRA, self).__init__()
-
- # Initialize cell and drug autoencoders
- self.cell_autoencoder = Autoencoder(sum(cell_modality_sizes), cell_ae_latent_dim)
- self.drug_autoencoder = Autoencoder(sum(drug_modality_sizes), drug_ae_latent_dim)
-
- # Store modality sizes
- self.cell_modality_sizes = cell_modality_sizes
- self.drug_modality_sizes = drug_modality_sizes
-
- # Initialize MLP
- self.mlp = MLP(mlp_input_dim, mlp_output_dim)
-
- def forward(self, cell_x, drug_x):
- """
- Forward pass of the DeepDRA model.
-
- Parameters:
- - cell_x (torch.Tensor): Input tensor for cell modality.
- - drug_x (torch.Tensor): Input tensor for drug modality.
-
- Returns:
- - cell_decoded (torch.Tensor): Decoded tensor for the cell modality.
- - drug_decoded (torch.Tensor): Decoded tensor for the drug modality.
- - mlp_output (torch.Tensor): Output tensor from the MLP.
- """
- # Encode and decode cell modality
- cell_encoded = self.cell_autoencoder.encoder(cell_x)
- cell_decoded = self.cell_autoencoder.decoder(cell_encoded)
-
- # Encode and decode drug modality
- drug_encoded = self.drug_autoencoder.encoder(drug_x)
- drug_decoded = self.drug_autoencoder.decoder(drug_encoded)
-
- # Concatenate encoded cell and drug features and pass through MLP
- mlp_output = self.mlp(torch.cat((cell_encoded, drug_encoded), 1))
-
- return cell_decoded, drug_decoded, mlp_output
-
- def compute_l1_loss(self, w):
- """
- Computes L1 regularization loss.
-
- Parameters:
- - w (torch.Tensor): Input tensor.
-
- Returns:
- - loss (torch.Tensor): L1 regularization loss.
- """
- return torch.abs(w).sum()
-
- def compute_l2_loss(self, w):
- """
- Computes L2 regularization loss.
-
- Parameters:
- - w (torch.Tensor): Input tensor.
-
- Returns:
- - loss (torch.Tensor): L2 regularization loss.
- """
- return torch.square(w).sum()
-
-
- def train(model, train_loader, num_epochs):
- """
- Trains the DeepDRA (Deep Drug Response Anticipation) model.
-
- Parameters:
- - model (DeepDRA): The DeepDRA model to be trained.
- - train_loader (DataLoader): DataLoader for the training dataset.
- - num_epochs (int): Number of training epochs.
- """
- autoencoder_loss_fn = nn.MSELoss()
- mlp_loss_fn = nn.BCELoss()
-
- mlp_optimizer = optim.Adam(model.parameters(), lr=0.0005)
- scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True)
-
- for epoch in range(num_epochs):
- for batch_idx, (cell_data, drug_data, target) in enumerate(train_loader):
- mlp_optimizer.zero_grad()
-
- # Forward pass
- cell_decoded_output, drug_decoded_output, mlp_output = model(cell_data, drug_data)
-
- # Compute losses
- cell_ae_loss = autoencoder_loss_fn(cell_decoded_output, cell_data)
- drug_ae_loss = autoencoder_loss_fn(drug_decoded_output, drug_data)
- mlp_loss = mlp_loss_fn(mlp_output, target)
-
- # Total loss is the sum of autoencoder losses and MLP loss
- total_loss = drug_ae_loss + cell_ae_loss + mlp_loss
-
- # Backward pass and optimization
- total_loss.backward()
- mlp_optimizer.step()
-
- # Print progress
- if batch_idx % 200 == 0:
- print('Epoch [{}/{}], Total Loss: {:.4f}'.format(
- epoch + 1, num_epochs, total_loss.item()))
-
- # Learning rate scheduler step
- scheduler.step(total_loss)
-
- # Save the trained model
- torch.save(model.state_dict(), MODEL_FOLDER + 'DeepDRA.pth')
-
-
- def test(model, test_loader, reverse=False):
- """
- Tests the given model on the test dataset using evaluation metrics.
-
- Parameters:
- - model: The trained model to be evaluated.
- - test_loader: DataLoader for the test dataset.
- - reverse (bool): If True, reverse the predictions for evaluation.
-
- Returns:
- - result: The evaluation result based on the chosen metrics.
- """
- # Set model to evaluation mode
- model.eval()
-
- # Initialize lists to store predictions and ground truth labels
- all_predictions = []
- all_labels = []
-
- # Iterate over the test dataset
- for i, (test_cell_loader, test_drug_loader, labels) in enumerate(test_loader):
- # Forward pass through the model
- with torch.no_grad():
- decoded_cell_output, decoded_drug_output, mlp_output = model(test_cell_loader, test_drug_loader)
-
- # Apply reverse if specified
- predictions = 1 - mlp_output if reverse else mlp_output
-
- # Evaluate the predictions using the specified metrics
- result = Evaluation.evaluate(labels, predictions)
-
- return result
|