import torch.nn as nn | |||||
import torch | |||||
import torch.optim as optim | |||||
from torch.optim import lr_scheduler | |||||
from autoencoder import Autoencoder | |||||
from evaluation import Evaluation | |||||
from mlp import MLP | |||||
from utils import * | |||||
class DeepDRA(nn.Module): | |||||
""" | |||||
DeepDRA (Deep Drug Response Anticipation) is a neural network model composed of two autoencoders for cell and drug modalities | |||||
and an MLP for integrating the encoded features and making predictions. | |||||
Parameters: | |||||
- cell_modality_sizes (list): Sizes of the cell modality features. | |||||
- drug_modality_sizes (list): Sizes of the drug modality features. | |||||
- cell_ae_latent_dim (int): Latent dimension for the cell autoencoder. | |||||
- drug_ae_latent_dim (int): Latent dimension for the drug autoencoder. | |||||
- mlp_input_dim (int): Input dimension for the MLP. | |||||
- mlp_output_dim (int): Output dimension for the MLP. | |||||
""" | |||||
def __init__(self, cell_modality_sizes, drug_modality_sizes, cell_ae_latent_dim, drug_ae_latent_dim, mlp_input_dim, | |||||
mlp_output_dim): | |||||
super(DeepDRA, self).__init__() | |||||
# Initialize cell and drug autoencoders | |||||
self.cell_autoencoder = Autoencoder(sum(cell_modality_sizes), cell_ae_latent_dim) | |||||
self.drug_autoencoder = Autoencoder(sum(drug_modality_sizes), drug_ae_latent_dim) | |||||
# Store modality sizes | |||||
self.cell_modality_sizes = cell_modality_sizes | |||||
self.drug_modality_sizes = drug_modality_sizes | |||||
# Initialize MLP | |||||
self.mlp = MLP(mlp_input_dim, mlp_output_dim) | |||||
def forward(self, cell_x, drug_x): | |||||
""" | |||||
Forward pass of the DeepDRA model. | |||||
Parameters: | |||||
- cell_x (torch.Tensor): Input tensor for cell modality. | |||||
- drug_x (torch.Tensor): Input tensor for drug modality. | |||||
Returns: | |||||
- cell_decoded (torch.Tensor): Decoded tensor for the cell modality. | |||||
- drug_decoded (torch.Tensor): Decoded tensor for the drug modality. | |||||
- mlp_output (torch.Tensor): Output tensor from the MLP. | |||||
""" | |||||
# Encode and decode cell modality | |||||
cell_encoded = self.cell_autoencoder.encoder(cell_x) | |||||
cell_decoded = self.cell_autoencoder.decoder(cell_encoded) | |||||
# Encode and decode drug modality | |||||
drug_encoded = self.drug_autoencoder.encoder(drug_x) | |||||
drug_decoded = self.drug_autoencoder.decoder(drug_encoded) | |||||
# Concatenate encoded cell and drug features and pass through MLP | |||||
mlp_output = self.mlp(torch.cat((cell_encoded, drug_encoded), 1)) | |||||
return cell_decoded, drug_decoded, mlp_output | |||||
def compute_l1_loss(self, w): | |||||
""" | |||||
Computes L1 regularization loss. | |||||
Parameters: | |||||
- w (torch.Tensor): Input tensor. | |||||
Returns: | |||||
- loss (torch.Tensor): L1 regularization loss. | |||||
""" | |||||
return torch.abs(w).sum() | |||||
def compute_l2_loss(self, w): | |||||
""" | |||||
Computes L2 regularization loss. | |||||
Parameters: | |||||
- w (torch.Tensor): Input tensor. | |||||
Returns: | |||||
- loss (torch.Tensor): L2 regularization loss. | |||||
""" | |||||
return torch.square(w).sum() | |||||
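# Minimal usage sketch (illustrative only): the modality sizes and batch shapes below
# are made up; real values come from RawDataLoader.prepare_input_data.
def _demo_deepdra_forward():
    model = DeepDRA(cell_modality_sizes=[100, 50], drug_modality_sizes=[60],
                    cell_ae_latent_dim=50, drug_ae_latent_dim=50,
                    mlp_input_dim=100, mlp_output_dim=1)
    cell_x = torch.randn(8, 150)  # batch of 8, sum of cell modality sizes
    drug_x = torch.randn(8, 60)  # batch of 8, sum of drug modality sizes
    cell_decoded, drug_decoded, mlp_output = model(cell_x, drug_x)
    # cell_decoded: (8, 150), drug_decoded: (8, 60), mlp_output: (8, 1) scores in [0, 1]
    return mlp_output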
def train(model, train_loader, num_epochs): | |||||
""" | |||||
Trains the DeepDRA (Deep Drug Response Anticipation) model. | |||||
Parameters: | |||||
- model (DeepDRA): The DeepDRA model to be trained. | |||||
- train_loader (DataLoader): DataLoader for the training dataset. | |||||
- num_epochs (int): Number of training epochs. | |||||
""" | |||||
autoencoder_loss_fn = nn.MSELoss() | |||||
mlp_loss_fn = nn.BCELoss() | |||||
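    # Note: this single optimizer jointly updates both autoencoders and the MLP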
mlp_optimizer = optim.Adam(model.parameters(), lr=0.0005) | |||||
scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True) | |||||
for epoch in range(num_epochs): | |||||
for batch_idx, (cell_data, drug_data, target) in enumerate(train_loader): | |||||
mlp_optimizer.zero_grad() | |||||
# Forward pass | |||||
cell_decoded_output, drug_decoded_output, mlp_output = model(cell_data, drug_data) | |||||
# Compute losses | |||||
cell_ae_loss = autoencoder_loss_fn(cell_decoded_output, cell_data) | |||||
drug_ae_loss = autoencoder_loss_fn(drug_decoded_output, drug_data) | |||||
mlp_loss = mlp_loss_fn(mlp_output, target) | |||||
# Total loss is the sum of autoencoder losses and MLP loss | |||||
total_loss = drug_ae_loss + cell_ae_loss + mlp_loss | |||||
# Backward pass and optimization | |||||
total_loss.backward() | |||||
mlp_optimizer.step() | |||||
# Print progress | |||||
if batch_idx % 200 == 0: | |||||
print('Epoch [{}/{}], Total Loss: {:.4f}'.format( | |||||
epoch + 1, num_epochs, total_loss.item())) | |||||
# Learning rate scheduler step | |||||
scheduler.step(total_loss) | |||||
# Save the trained model | |||||
    torch.save(model.state_dict(), os.path.join(MODEL_FOLDER, 'DeepDRA.pth'))
def test(model, test_loader, reverse=False): | |||||
""" | |||||
Tests the given model on the test dataset using evaluation metrics. | |||||
Parameters: | |||||
- model: The trained model to be evaluated. | |||||
- test_loader: DataLoader for the test dataset. | |||||
- reverse (bool): If True, reverse the predictions for evaluation. | |||||
Returns: | |||||
- result: The evaluation result based on the chosen metrics. | |||||
""" | |||||
# Set model to evaluation mode | |||||
model.eval() | |||||
# Initialize lists to store predictions and ground truth labels | |||||
all_predictions = [] | |||||
all_labels = [] | |||||
# Iterate over the test dataset | |||||
    for i, (cell_data, drug_data, labels) in enumerate(test_loader):
        # Forward pass through the model
        with torch.no_grad():
            decoded_cell_output, decoded_drug_output, mlp_output = model(cell_data, drug_data)
# Apply reverse if specified | |||||
predictions = 1 - mlp_output if reverse else mlp_output | |||||
# Evaluate the predictions using the specified metrics | |||||
result = Evaluation.evaluate(labels, predictions) | |||||
return result | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.optim as optim | |||||
class Autoencoder(nn.Module): | |||||
""" | |||||
Autoencoder neural network model for feature learning. | |||||
Parameters: | |||||
- input_dim (int): Dimensionality of the input features. | |||||
- latent_dim (int): Dimensionality of the latent space. | |||||
""" | |||||
def __init__(self, input_dim, latent_dim): | |||||
super(Autoencoder, self).__init__() | |||||
# Encoder architecture | |||||
self.encoder = nn.Sequential( | |||||
nn.Linear(input_dim, 256), | |||||
nn.ReLU(inplace=True), | |||||
nn.Linear(256, latent_dim), | |||||
nn.ReLU(inplace=True), | |||||
) | |||||
# Decoder architecture | |||||
self.decoder = nn.Sequential( | |||||
nn.Linear(latent_dim, 256), | |||||
nn.ReLU(inplace=True), | |||||
nn.Linear(256, input_dim), | |||||
) | |||||
def forward(self, x): | |||||
""" | |||||
Forward pass of the autoencoder. | |||||
Parameters: | |||||
- x (torch.Tensor): Input tensor. | |||||
Returns: | |||||
- decoded (torch.Tensor): Decoded output tensor. | |||||
""" | |||||
encoded = self.encoder(x) | |||||
decoded = self.decoder(encoded) | |||||
return decoded | |||||
def trainAutoencoder(model, train_loader, val_loader, num_epochs, name): | |||||
""" | |||||
Train the autoencoder model. | |||||
Parameters: | |||||
- model (Autoencoder): The autoencoder model to be trained. | |||||
- train_loader (DataLoader): DataLoader for the training dataset. | |||||
- val_loader (DataLoader): DataLoader for the validation dataset. | |||||
- num_epochs (int): Number of training epochs. | |||||
- name (str): Name to save the trained model. | |||||
Returns: | |||||
- None | |||||
""" | |||||
loss_fn = nn.MSELoss() | |||||
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-8) | |||||
train_loss = [] | |||||
val_final_loss = [] | |||||
for epoch in range(num_epochs): | |||||
# Training | |||||
model.train() | |||||
total_train_loss = 0.0 | |||||
for batch_idx, data in enumerate(train_loader): | |||||
data = data[0] | |||||
output = model(data) | |||||
loss = loss_fn(output, data) | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        train_loss.append(avg_train_loss)
# Validation | |||||
model.eval() | |||||
total_val_loss = 0.0 | |||||
with torch.no_grad(): | |||||
for val_batch_idx, (val_data) in enumerate(val_loader): | |||||
val_data = val_data[0] | |||||
val_output = model(val_data) | |||||
val_loss = loss_fn(val_output, val_data) | |||||
                total_val_loss += val_loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        val_final_loss.append(avg_val_loss)
print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.format( | |||||
epoch + 1, num_epochs, avg_train_loss, avg_val_loss)) | |||||
# Save the trained model | |||||
torch.save(model.state_dict(), "autoencoder" + name + '.pth') | |||||
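# Minimal usage sketch (illustrative only): input_dim and the batch shape are arbitrary.
def _demo_autoencoder_shapes():
    model = Autoencoder(input_dim=500, latent_dim=50)
    x = torch.randn(4, 500)
    encoded = model.encoder(x)  # (4, 50) latent representation
    decoded = model(x)  # (4, 500) reconstruction of the input
    return encoded.shape, decoded.shape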
from utils import * | |||||
import os | |||||
class RawDataLoader: | |||||
@staticmethod | |||||
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep): | |||||
""" | |||||
Load raw data and screening data, perform intersection, and adjust screening data. | |||||
Parameters: | |||||
- data_modalities (list): List of data modalities to load. | |||||
- raw_file_directory (str): Directory containing raw data files. | |||||
- screen_file_directory (str): Directory containing screening data files. | |||||
- sep (str): Separator used in the data files. | |||||
Returns: | |||||
- data (dict): Dictionary containing loaded raw data. | |||||
- drug_screen (pd.DataFrame): Adjusted and intersected screening data. | |||||
""" | |||||
# Step 1: Load raw data files for specified data modalities | |||||
data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | |||||
raw_file_directory=raw_file_directory) | |||||
# Step 2: Load drug data files for specified data modalities | |||||
drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | |||||
raw_file_directory=DRUG_DATA_FOLDER) | |||||
# Step 3: Update the 'data' dictionary with drug data | |||||
data.update(drug_data) | |||||
# Step 4: Load and adjust drug screening data | |||||
drug_screen = RawDataLoader.load_screening_files( | |||||
filename=screen_file_directory, | |||||
sep=sep) | |||||
drug_screen, data = RawDataLoader.adjust_screening_raw( | |||||
drug_screen=drug_screen, data_dict=data) | |||||
# Step 5: Return the loaded data and adjusted drug screening data | |||||
return data, drug_screen | |||||
@staticmethod | |||||
def intersect_features(data1, data2): | |||||
""" | |||||
Perform intersection of features between two datasets. | |||||
Parameters: | |||||
- data1 (pd.DataFrame): First dataset. | |||||
- data2 (pd.DataFrame): Second dataset. | |||||
Returns: | |||||
- data1 (pd.DataFrame): First dataset with common columns. | |||||
- data2 (pd.DataFrame): Second dataset with common columns. | |||||
""" | |||||
# Step 1: Find common columns between the two datasets | |||||
common_columns = list(set(data1.columns) & set(data2.columns)) | |||||
# Step 2: Filter data2 to include only common columns | |||||
data2 = data2[common_columns] | |||||
# Step 3: Filter data1 to include only common columns | |||||
data1 = data1[common_columns] | |||||
# Step 4: Return the datasets with intersected features | |||||
return data1, data2 | |||||
@staticmethod | |||||
def data_features_intersect(data1, data2): | |||||
""" | |||||
Intersect features between two datasets column-wise. | |||||
Parameters: | |||||
- data1 (dict): Dictionary containing data modalities. | |||||
- data2 (dict): Dictionary containing data modalities. | |||||
Returns: | |||||
- intersected_data1 (dict): Data1 with intersected features. | |||||
- intersected_data2 (dict): Data2 with intersected features. | |||||
""" | |||||
# Iterate over each data modality | |||||
for i in data1: | |||||
# Intersect features for each modality | |||||
data1[i], data2[i] = RawDataLoader.intersect_features(data1[i], data2[i]) | |||||
return data1, data2 | |||||
@staticmethod | |||||
def load_file(address, index_column=None): | |||||
""" | |||||
Load data from a file based on its format. | |||||
Parameters: | |||||
- address (str): File address. | |||||
- index_column (str): Name of the index column. | |||||
Returns: | |||||
- data (pd.DataFrame): Loaded data from the file. | |||||
""" | |||||
data = [] | |||||
try: | |||||
# Load data based on file format | |||||
if address.endswith('.txt') or address.endswith('.tsv'): | |||||
                data.append(pd.read_csv(address, sep='\t', index_col=index_column))
            elif address.endswith('.csv'):
                data.append(pd.read_csv(address, index_col=index_column))
elif address.endswith('.xlsx'): | |||||
data.append(pd.read_excel(address)) | |||||
except FileNotFoundError: | |||||
print(f'File not found at address: {address}') | |||||
return data[0] | |||||
@staticmethod | |||||
def load_raw_files(raw_file_directory, data_modalities, intersect=True): | |||||
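        """
        Load raw feature files from a directory, optionally intersecting cell line
        and drug names across modalities.
        Parameters:
        - raw_file_directory (str): Directory containing '*_raw.gzip' or '*_raw.tsv' files.
        - data_modalities (list): Modality prefixes to load (e.g., 'cell_mut', 'drug_desc').
        - intersect (bool): If True, restrict all modalities to common cell/drug names.
        Returns:
        - raw_dict (dict): Mapping from modality name to its pd.DataFrame.
        """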
raw_dict = {} | |||||
files = os.listdir(raw_file_directory) | |||||
cell_line_names = None | |||||
drug_names = None | |||||
for file in tqdm(files, 'Reading Raw Data Files...'): | |||||
            if any(file.startswith(x) for x in data_modalities):
if file.endswith('_raw.gzip'): | |||||
df = pd.read_parquet(os.path.join(raw_file_directory, file)) | |||||
elif file.endswith('_raw.tsv'): | |||||
df = pd.read_csv(os.path.join(raw_file_directory, file), sep='\t', index_col=0) | |||||
else: | |||||
continue | |||||
                if pd.api.types.is_numeric_dtype(df.index):
                    df = df.set_index(df.columns[0])
df = df.sort_index() | |||||
df = df.sort_index(axis=1) | |||||
df.columns = df.columns.str.replace('_cell_mut', '') | |||||
df.columns = df.columns.str.replace('_cell_CN', '') | |||||
df.columns = df.columns.str.replace('_cell_exp', '') | |||||
                # The drug_comp raw table contains NA values, so impute them with column means
                if df.isna().any().any():
                    df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df),
                                      columns=df.columns, index=df.index)
                if file.startswith('drug_comp') or file.startswith('drug_desc'):
                    # Z-score normalize the continuous drug feature tables
                    df = ((df - df.mean()) / df.std()).fillna(0)
if intersect: | |||||
if file.startswith('cell'): | |||||
if cell_line_names: | |||||
cell_line_names = cell_line_names.intersection(set(df.index)) | |||||
else: | |||||
cell_line_names = set(df.index) | |||||
elif file.startswith('drug'): | |||||
if drug_names: | |||||
drug_names = drug_names.intersection(set(df.index)) | |||||
else: | |||||
drug_names = set(df.index) | |||||
raw_dict[file[:file.find('_raw')]] = df | |||||
if intersect: | |||||
for key, value in raw_dict.items(): | |||||
if key.startswith('cell'): | |||||
data = value.loc[list(cell_line_names)] | |||||
raw_dict[key] = data.loc[~data.index.duplicated()] | |||||
elif key.startswith('drug'): | |||||
data = value.loc[list(drug_names)] | |||||
raw_dict[key] = data.loc[~data.index.duplicated()] | |||||
return raw_dict | |||||
@staticmethod | |||||
def load_screening_files(filename="AUC_matS_comb.tsv", sep=',', ): | |||||
df = pd.read_csv(filename, sep=sep, index_col=0) | |||||
return df | |||||
@staticmethod | |||||
def adjust_screening_raw(drug_screen, data_dict): | |||||
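        """
        Restrict the screening matrix and raw data to their common cell lines and drugs.
        Parameters:
        - drug_screen (pd.DataFrame): Screening matrix (cell lines x drugs).
        - data_dict (dict): Dictionary of raw data modalities.
        Returns:
        - drug_screen (pd.DataFrame): Screening matrix restricted to the common names.
        - data_dict (dict): Raw data restricted to the same common names.
        """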
        raw_cell_names = []
        raw_drug_names = []
for key, value in data_dict.items(): | |||||
if 'cell' in key: | |||||
if len(raw_cell_names) == 0: | |||||
raw_cell_names = value.index | |||||
else: | |||||
raw_cell_names = raw_cell_names.intersection(value.index) | |||||
elif 'drug' in key: | |||||
raw_drug_names = value.index | |||||
screening_cell_names = drug_screen.index | |||||
screening_drug_names = drug_screen.columns | |||||
common_cell_names = list(set(raw_cell_names).intersection(set(screening_cell_names))) | |||||
common_drug_names = list(set(raw_drug_names).intersection(set(screening_drug_names))) | |||||
for key, value in data_dict.items(): | |||||
if 'cell' in key: | |||||
data_dict[key] = value.loc[common_cell_names] | |||||
else: | |||||
data_dict[key] = value.loc[common_drug_names] | |||||
return drug_screen.loc[common_cell_names, common_drug_names], data_dict | |||||
@staticmethod | |||||
def prepare_input_data(data_dict, screening): | |||||
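        """
        Build paired (cell, drug) training matrices from the screening matrix.
        Screening entries equal to 1 are treated as resistant pairs (label 0) and
        entries equal to -1 as sensitive pairs (label 1).
        Parameters:
        - data_dict (dict): Dictionary of raw data modalities.
        - screening (pd.DataFrame): Screening matrix (cell lines x drugs).
        Returns:
        - X_cell (pd.DataFrame): Cell features, one row per (cell, drug) pair.
        - X_drug (pd.DataFrame): Drug features, one row per (cell, drug) pair.
        - Y (np.ndarray): Binary labels (0 = resistant, 1 = sensitive).
        - cell_data_sizes (list): Per-modality cell feature sizes.
        - drug_data_sizes (list): Per-modality drug feature sizes.
        """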
print('Preparing data...') | |||||
resistance = np.argwhere((screening.to_numpy() == 1)).tolist() | |||||
resistance.sort(key=lambda x: (x[1], x[0])) | |||||
resistance = np.array(resistance) | |||||
sensitive = np.argwhere((screening.to_numpy() == -1)).tolist() | |||||
sensitive.sort(key=lambda x: (x[1], x[0])) | |||||
sensitive = np.array(sensitive) | |||||
print("sensitive train data len:", len(sensitive)) | |||||
print("resistance train data len:", len(resistance)) | |||||
A_train_mask = np.ones(len(resistance), dtype=bool) | |||||
B_train_mask = np.ones(len(sensitive), dtype=bool) | |||||
resistance = resistance[A_train_mask] | |||||
sensitive = sensitive[B_train_mask] | |||||
cell_data_types = list(filter(lambda x: x.startswith('cell'), data_dict.keys())) | |||||
cell_data_types.sort() | |||||
cell_data = pd.concat( | |||||
[pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32) for | |||||
data_type in cell_data_types], axis=1) | |||||
cell_data_sizes = [data_dict[data_type].shape[1] for data_type in cell_data_types] | |||||
drug_data_types = list(filter(lambda x: x.startswith('drug'), data_dict.keys())) | |||||
drug_data_types.sort() | |||||
drug_data = pd.concat( | |||||
[pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32, ) | |||||
for data_type in drug_data_types], axis=1) | |||||
drug_data_sizes = [data_dict[data_type].shape[1] for data_type in drug_data_types] | |||||
Xp_cell = cell_data.iloc[resistance[:, 0], :] | |||||
Xp_drug = drug_data.iloc[resistance[:, 1], :] | |||||
Xp_cell = Xp_cell.reset_index(drop=True) | |||||
Xp_drug = Xp_drug.reset_index(drop=True) | |||||
Xp_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance] | |||||
Xp_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance] | |||||
Xn_cell = cell_data.iloc[sensitive[:, 0], :] | |||||
Xn_drug = drug_data.iloc[sensitive[:, 1], :] | |||||
Xn_cell = Xn_cell.reset_index(drop=True) | |||||
Xn_drug = Xn_drug.reset_index(drop=True) | |||||
Xn_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive] | |||||
Xn_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive] | |||||
X_cell = pd.concat([Xp_cell, Xn_cell]) | |||||
X_drug = pd.concat([Xp_drug, Xn_drug]) | |||||
Y = np.append(np.zeros(resistance.shape[0]), np.ones(sensitive.shape[0])) | |||||
return X_cell, X_drug, Y, cell_data_sizes, drug_data_sizes |
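# Illustrative sketch (not part of the pipeline): how prepare_input_data expands a
# tiny, made-up screening matrix into one training row per labelled (cell, drug) pair.
def _demo_prepare_input_data():
    screening = pd.DataFrame([[1, -1], [-1, 0]],
                             index=['cellA', 'cellB'], columns=['drug1', 'drug2'])
    data_dict = {
        'cell_mut': pd.DataFrame(np.eye(2), index=['cellA', 'cellB']),
        'drug_desc': pd.DataFrame(np.eye(2), index=['drug1', 'drug2']),
    }
    X_cell, X_drug, Y, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(data_dict, screening)
    # Three labelled pairs: (cellA, drug1) is resistant (0); (cellB, drug1) and
    # (cellA, drug2) are sensitive (1). The 0 entry is an unscreened pair and is skipped.
    return X_cell.shape, X_drug.shape, Y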
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, f1_score, precision_score, \ | |||||
recall_score, accuracy_score | |||||
from utils import * | |||||
class Evaluation: | |||||
@staticmethod | |||||
def plot_train_val_accuracy(train_accuracies, val_accuracies, num_epochs): | |||||
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        plt.title('Training and validation accuracy')
        plt.plot(range(1, num_epochs + 1), train_accuracies, label='train')
        plt.plot(range(1, num_epochs + 1), val_accuracies, label='validation')
        plt.legend()
        plt.show()
@staticmethod | |||||
def plot_train_val_loss(train_loss, val_loss, num_epochs): | |||||
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.title('Training and validation loss')
        plt.plot(range(1, num_epochs + 1), train_loss, label='train')
        plt.plot(range(1, num_epochs + 1), val_loss, label='validation')
        plt.legend()
        plt.show()
@staticmethod | |||||
def evaluate(all_targets, mlp_output, show_plot=True): | |||||
        # Threshold the raw scores at 0.5 to obtain binary predicted labels
        predicted_labels = np.where(mlp_output > 0.5, 1, 0)
        # Flatten predictions for the metric functions
        all_predictions = predicted_labels.reshape(-1)
# Calculate and print AUC | |||||
fpr, tpr, thresholds = metrics.roc_curve(all_targets, mlp_output) | |||||
auc = np.round(metrics.auc(fpr, tpr), 2) | |||||
# Calculate and print AUPRC | |||||
precision, recall, thresholds = metrics.precision_recall_curve(all_targets, mlp_output) | |||||
auprc = np.round(metrics.auc(recall, precision), 2) | |||||
# auprc = average_precision_score(all_targets, mlp_output) | |||||
print('Accuracy: {:.2f}'.format(np.round(accuracy_score(all_targets, all_predictions), 2))) | |||||
print('AUC: {:.2f}'.format(auc)) | |||||
print('AUPRC: {:.2f}'.format(auprc)) | |||||
# Calculate and print confusion matrix | |||||
        cm = confusion_matrix(all_targets, all_predictions)
        accuracy = cm.trace() / np.sum(cm)
        # Precision and recall computed for the resistant (label 0) class
        precision = cm[0, 0] / (cm[0, 0] + cm[1, 0])
        recall = cm[0, 0] / (cm[0, 0] + cm[0, 1])
        f1 = 2 * precision * recall / (precision + recall)
        print('Confusion matrix:\n', cm, sep='')
        print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1 score: {f1:.3f}')
if show_plot: | |||||
plt.xlabel('False Positive Rate') | |||||
plt.ylabel('True Positive Rate') | |||||
plt.title(f'ROC Curve: AUC={auc}') | |||||
plt.plot(fpr, tpr) | |||||
plt.show() | |||||
# print(f'AUC: {auc}') | |||||
plt.xlabel('Recall') | |||||
plt.ylabel('Precision') | |||||
plt.title(f'PR Curve: AUPRC={auprc}') | |||||
plt.plot(recall, precision) | |||||
plt.show() | |||||
            # Collect raw scores and targets for the score distribution plot
            prediction_targets = pd.DataFrame({'Prediction': mlp_output.numpy().reshape(-1),
                                               'Target': all_targets.numpy().reshape(-1)})
            non_responder_scores = prediction_targets.loc[prediction_targets['Target'] == 0, 'Prediction'].astype(
                np.float32).tolist()
            responder_scores = prediction_targets.loc[prediction_targets['Target'] == 1, 'Prediction'].astype(
                np.float32).tolist()
fig, ax = plt.subplots() | |||||
ax.set_ylabel("DeepDRA score") | |||||
xticklabels = ['Responder', 'Non Responder'] | |||||
ax.set_xticks([1, 2]) | |||||
ax.set_xticklabels(xticklabels) | |||||
            data_to_plot = [responder_scores, non_responder_scores]
            plt.ylim(0, 1)
            p_value = ttest_ind(non_responder_scores, responder_scores)[1]
            cancer = 'all'
            plt.title(
                f'Responder/non-responder scores for {cancer} cancers\n(t-test p-value ~= {p_value:.1e})')
bp = ax.violinplot(data_to_plot, showextrema=True, showmeans=True, showmedians=True) | |||||
bp['cmeans'].set_color('r') | |||||
bp['cmedians'].set_color('g') | |||||
plt.show() | |||||
        return {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1 score': f1, 'AUC': auc,
                'AUPRC': auprc}
@staticmethod | |||||
def add_results(result_list, current_result): | |||||
result_list['AUC'].append(current_result['AUC']) | |||||
result_list['AUPRC'].append(current_result['AUPRC']) | |||||
result_list['Accuracy'].append(current_result['Accuracy']) | |||||
result_list['Precision'].append(current_result['Precision']) | |||||
result_list['Recall'].append(current_result['Recall']) | |||||
result_list['F1 score'].append(current_result['F1 score']) | |||||
return result_list | |||||
@staticmethod | |||||
def show_final_results(result_list): | |||||
print("Final Results:") | |||||
for i in range(len(result_list["AUC"])): | |||||
accuracy = result_list['Accuracy'][i] | |||||
precision = result_list['Precision'][i] | |||||
recall = result_list['Recall'][i] | |||||
f1_score = result_list['F1 score'][i] | |||||
auc = result_list['AUC'][i] | |||||
auprc = result_list['AUPRC'][i] | |||||
            print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1 score: {f1_score:.3f}, AUC: {auc:.3f}, AUPRC: {auprc:.3f}')
avg_auc = np.mean(result_list['AUC']) | |||||
avg_auprc = np.mean(result_list['AUPRC']) | |||||
std_auprc = np.std(result_list['AUPRC']) | |||||
print(" Average AUC: {:.3f} \t Average AUPRC: {:.3f} \t Std AUPRC: {:.3f}".format(avg_auc, avg_auprc, std_auprc)) |
from imblearn.under_sampling import RandomUnderSampler | |||||
from sklearn.model_selection import train_test_split | |||||
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler | |||||
from DeepDRA import DeepDRA, train, test | |||||
from data_loader import RawDataLoader | |||||
from evaluation import Evaluation | |||||
from utils import * | |||||
from mlp import MLP | |||||
import random | |||||
import torch | |||||
import numpy as np | |||||
def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes): | |||||
""" | |||||
Train and evaluate the DeepDRA model. | |||||
Parameters: | |||||
    - x_cell_train (pd.DataFrame): Training data for the cell modality.
    - x_cell_test (pd.DataFrame): Test data for the cell modality.
    - x_drug_train (pd.DataFrame): Training data for the drug modality.
    - x_drug_test (pd.DataFrame): Test data for the drug modality.
- y_train (pd.Series): Training labels. | |||||
- y_test (pd.Series): Test labels. | |||||
- cell_sizes (list): Sizes of the cell modality features. | |||||
- drug_sizes (list): Sizes of the drug modality features. | |||||
Returns: | |||||
- result: Evaluation result on the test set. | |||||
""" | |||||
# Step 1: Define the batch size for training | |||||
batch_size = 64 | |||||
# Step 2: Instantiate the combined model | |||||
ae_latent_dim = 50 | |||||
mlp_input_dim = 2 * ae_latent_dim | |||||
mlp_output_dim = 1 | |||||
num_epochs = 20 | |||||
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim) | |||||
# Step 3: Convert your training data to PyTorch tensors | |||||
x_cell_train_tensor = torch.Tensor(x_cell_train.values) | |||||
x_drug_train_tensor = torch.Tensor(x_drug_train.values) | |||||
x_cell_train_tensor = torch.nn.functional.normalize(x_cell_train_tensor, dim=0) | |||||
x_drug_train_tensor = torch.nn.functional.normalize(x_drug_train_tensor, dim=0) | |||||
y_train_tensor = torch.Tensor(y_train) | |||||
y_train_tensor = y_train_tensor.unsqueeze(1) | |||||
# Step 4: Create a TensorDataset with the input features and target labels | |||||
train_dataset = TensorDataset(x_cell_train_tensor, x_drug_train_tensor, y_train_tensor) | |||||
# Step 5: Create the train_loader | |||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | |||||
# Step 6: Train the model | |||||
train(model, train_loader, num_epochs=num_epochs) | |||||
# Step 7: Save the trained model | |||||
torch.save(model, 'DeepDRA.pth') | |||||
# Step 8: Load the saved model | |||||
model = torch.load('DeepDRA.pth') | |||||
# Step 9: Convert your test data to PyTorch tensors | |||||
x_cell_test_tensor = torch.Tensor(x_cell_test.values) | |||||
x_drug_test_tensor = torch.Tensor(x_drug_test.values) | |||||
y_test_tensor = torch.Tensor(y_test) | |||||
# normalize data | |||||
x_cell_test_tensor = torch.nn.functional.normalize(x_cell_test_tensor, dim=0) | |||||
x_drug_test_tensor = torch.nn.functional.normalize(x_drug_test_tensor, dim=0) | |||||
# Step 10: Create a TensorDataset with the input features and target labels for testing | |||||
test_dataset = TensorDataset(x_cell_test_tensor, x_drug_test_tensor, y_test_tensor) | |||||
test_loader = DataLoader(test_dataset, batch_size=len(x_cell_test)) | |||||
# Step 11: Test the model | |||||
return test(model, test_loader) | |||||
def run(k, is_test=False): | |||||
""" | |||||
Run the training and evaluation process k times. | |||||
Parameters: | |||||
- k (int): Number of times to run the process. | |||||
- is_test (bool): If True, run on test data; otherwise, perform train-validation split. | |||||
Returns: | |||||
- history (dict): Dictionary containing evaluation metrics for each run. | |||||
""" | |||||
# Step 1: Initialize a dictionary to store evaluation metrics | |||||
history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []} | |||||
# Step 2: Load training data | |||||
train_data, train_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | |||||
raw_file_directory=GDSC_RAW_DATA_FOLDER, | |||||
screen_file_directory=GDSC_SCREENING_DATA_FOLDER, | |||||
sep="\t") | |||||
# Step 3: Load test data if applicable | |||||
if is_test: | |||||
test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | |||||
raw_file_directory=CCLE_RAW_DATA_FOLDER, | |||||
screen_file_directory=CCLE_SCREENING_DATA_FOLDER, | |||||
sep="\t") | |||||
train_data, test_data = RawDataLoader.data_features_intersect(train_data, test_data) | |||||
X_cell_test, X_drug_test, y_test, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(test_data, | |||||
test_drug_screen) | |||||
# Step 4: Prepare input data for training | |||||
X_cell_train, X_drug_train, y_train, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(train_data, | |||||
train_drug_screen) | |||||
# Step 5: Loop over k runs | |||||
for i in range(k): | |||||
print('Run {}'.format(i)) | |||||
# Step 6: If is_test is True, perform random under-sampling on the training data | |||||
if is_test: | |||||
rus = RandomUnderSampler(sampling_strategy="majority", random_state=RANDOM_SEED) | |||||
dataset = pd.concat([X_cell_train, X_drug_train], axis=1) | |||||
dataset.index = X_cell_train.index | |||||
dataset, y_train = rus.fit_resample(dataset, y_train) | |||||
X_cell_train = dataset.iloc[:, :sum(cell_sizes)] | |||||
X_drug_train = dataset.iloc[:, sum(cell_sizes):] | |||||
# Step 7: Train and evaluate the DeepDRA model on test data | |||||
results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes, | |||||
drug_sizes) | |||||
else: | |||||
            # Step 8: Split into training and validation sets without overwriting the full training set
            x_cell_tr, x_cell_val, x_drug_tr, x_drug_val, y_tr, y_val = train_test_split(
                X_cell_train, X_drug_train, y_train,
                test_size=0.2,
                random_state=RANDOM_SEED,
                shuffle=True)
            # Step 9: Train and evaluate the DeepDRA model on the split data
            results = train_DeepDRA(x_cell_tr, x_cell_val, x_drug_tr, x_drug_val, y_tr, y_val, cell_sizes,
                                    drug_sizes)
# Step 10: Add results to the history dictionary | |||||
Evaluation.add_results(history, results) | |||||
# Step 11: Display final results | |||||
Evaluation.show_final_results(history) | |||||
return history | |||||
if __name__ == '__main__': | |||||
torch.manual_seed(RANDOM_SEED) | |||||
random.seed(RANDOM_SEED) | |||||
np.random.seed(RANDOM_SEED) | |||||
run(10, is_test=True) |
from torch import nn | |||||
from torch import optim, no_grad | |||||
import torch | |||||
from evaluation import Evaluation | |||||
from utils import data_modalities_abbreviation | |||||
class EarlyStopper: | |||||
def __init__(self, patience=1, min_delta=0): | |||||
self.patience = patience | |||||
self.min_delta = min_delta | |||||
self.counter = 0 | |||||
self.min_validation_loss = float('inf') | |||||
def early_stop(self, validation_loss): | |||||
if validation_loss < self.min_validation_loss: | |||||
self.min_validation_loss = validation_loss | |||||
self.counter = 0 | |||||
elif validation_loss > (self.min_validation_loss + self.min_delta): | |||||
self.counter += 1 | |||||
if self.counter >= self.patience: | |||||
return True | |||||
return False | |||||
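# Usage sketch (illustrative only): the validation-loss sequence below is made up.
def _demo_early_stopper():
    stopper = EarlyStopper(patience=2, min_delta=0.0)
    for epoch, val_loss in enumerate([1.0, 0.8, 0.9, 0.95, 0.99]):
        if stopper.early_stop(val_loss):
            return epoch  # returns 3: the second consecutive epoch without improvement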
class MLP(nn.Module): | |||||
def __init__(self, input_dim, output_dim): | |||||
super(MLP, self).__init__() | |||||
self.mlp = nn.Sequential( | |||||
nn.Linear(input_dim, 128), | |||||
nn.ReLU(inplace=True), | |||||
nn.Linear(128, output_dim), | |||||
nn.Hardsigmoid(), | |||||
) | |||||
def forward(self, x): | |||||
return self.mlp(x) | |||||
def train_mlp(model, train_loader, val_loader, num_epochs): | |||||
mlp_loss_fn = nn.BCELoss() | |||||
mlp_optimizer = optim.Adadelta(model.parameters(), lr=0.01,) | |||||
# scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True) | |||||
train_accuracies = [] | |||||
val_accuracies = [] | |||||
train_loss = [] | |||||
val_loss = [] | |||||
early_stopper = EarlyStopper(patience=3, min_delta=0.05) | |||||
for epoch in range(num_epochs): | |||||
# Training | |||||
        model.train()
total_train_loss = 0.0 | |||||
train_correct = 0 | |||||
train_total_samples = 0 | |||||
for batch_idx, (data, target) in enumerate(train_loader): | |||||
mlp_optimizer.zero_grad() | |||||
mlp_output = model(data) | |||||
mlp_loss = mlp_loss_fn(mlp_output, target) | |||||
mlp_loss.backward() | |||||
mlp_optimizer.step() | |||||
total_train_loss += mlp_loss.item() | |||||
# Calculate accuracy | |||||
train_predictions = torch.round(mlp_output) | |||||
train_correct += (train_predictions == target).sum().item() | |||||
train_total_samples += target.size(0) | |||||
# if batch_idx % 200 == 0: | |||||
# after_lr = mlp_optimizer.param_groups[0]["lr"] | |||||
# print('Epoch [{}/{}], Batch [{}/{}], Total Loss: {:.4f}, Learning Rate: {:.8f},'.format( | |||||
# epoch + 1, num_epochs, batch_idx + 1, len(train_loader), mlp_loss.item(), after_lr)) | |||||
avg_train_loss = total_train_loss / len(train_loader) | |||||
train_loss.append(avg_train_loss) | |||||
# Validation | |||||
model.eval() | |||||
total_val_loss = 0.0 | |||||
correct = 0 | |||||
total_samples = 0 | |||||
with torch.no_grad(): | |||||
for val_batch_idx, (data, val_target) in enumerate(val_loader): | |||||
val_mlp_output = model(data) | |||||
val_mlp_loss = mlp_loss_fn(val_mlp_output, val_target) | |||||
total_val_loss += val_mlp_loss.item() | |||||
# Calculate accuracy | |||||
val_predictions = torch.round(val_mlp_output) | |||||
correct += (val_predictions == val_target).sum().item() | |||||
total_samples += val_target.size(0) | |||||
avg_val_loss = total_val_loss / len(val_loader) | |||||
val_loss.append(avg_val_loss) | |||||
train_accuracy = train_correct / train_total_samples | |||||
train_accuracies.append(train_accuracy) | |||||
val_accuracy = correct / total_samples | |||||
val_accuracies.append(val_accuracy) | |||||
print( | |||||
'Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Train Accuracy: {:.4f}, Val Accuracy: {:.4f}'.format( | |||||
epoch + 1, num_epochs, avg_train_loss, avg_val_loss, train_accuracy, | |||||
val_accuracy)) | |||||
# if early_stopper.early_stop(avg_val_loss): | |||||
# break | |||||
        # scheduler.step(avg_val_loss)
Evaluation.plot_train_val_accuracy(train_accuracies, val_accuracies, epoch+1) | |||||
Evaluation.plot_train_val_loss(train_loss, val_loss, epoch+1) | |||||
def test_mlp(model, test_loader): | |||||
for i, (data, labels) in enumerate(test_loader): | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
mlp_output = model(data) | |||||
return Evaluation.evaluate(labels, mlp_output) | |||||
import os | |||||
import pandas as pd | |||||
import numpy as np | |||||
from tqdm import tqdm | |||||
import sklearn as sk | |||||
from matplotlib import pyplot as plt | |||||
from scipy.spatial.distance import pdist, squareform | |||||
import h2o | |||||
from h2o.estimators import H2ODeepLearningEstimator | |||||
from sklearn.impute import SimpleImputer | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.optim as optim | |||||
from torch.utils.data import DataLoader, TensorDataset | |||||
import pickle | |||||
from sklearn import metrics | |||||
from copy import deepcopy | |||||
import pyreadr | |||||
import requests | |||||
from time import time | |||||
from math import ceil | |||||
from statsmodels.stats.weightstats import ttest_ind | |||||
import torch.optim.lr_scheduler as lr_scheduler | |||||
from sklearn.model_selection import KFold | |||||
DATA_FOLDER = 'data' | |||||
RES_DATA_FOLDER = os.path.join(DATA_FOLDER, 'res') | |||||
TEST_DATA_FOLDER = os.path.join(DATA_FOLDER, 'final_test_data') | |||||
TEST_TCGA_DATA_FOLDER = os.path.join(DATA_FOLDER, 'TCGA_test_data') | |||||
SIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'similarity_data') | |||||
RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'raw_data') | |||||
RAW_BOTH_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_GDSC_Data') | |||||
DRUG_DATA_FOLDER = os.path.join(DATA_FOLDER, 'drug_data') | |||||
NEW_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'new_raw_data') | |||||
GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') | |||||
CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_raw') | |||||
CTRP_FOLDER = os.path.join(DATA_FOLDER, 'CTRP') | |||||
GDSC_FOLDER = os.path.join(DATA_FOLDER, 'GDSC') | |||||
CCLE_FOLDER = os.path.join(DATA_FOLDER, 'CCLE') | |||||
MODEL_FOLDER = os.path.join(DATA_FOLDER, 'model') | |||||
CTRP_EXPERIMENT_FILE = os.path.join(CTRP_FOLDER, 'v20.meta.per_experiment.txt') | |||||
CTRP_COMPOUND_FILE = os.path.join(CTRP_FOLDER, 'v20.meta.per_compound.txt') | |||||
CTRP_CELLLINE_FILE = os.path.join(CTRP_FOLDER, 'v20.meta.per_cell_line.txt') | |||||
CTRP_AUC_FILE = os.path.join(CTRP_FOLDER, 'v20.data.curves_post_qc.txt') | |||||
GDSC_AUC_FILE = os.path.join(GDSC_FOLDER, 'GDSC2_fitted_dose_response.csv') | |||||
GDSC_cnv_data_FILE = os.path.join(GDSC_FOLDER, 'cnv_abs_copy_number_picnic_20191101.csv') | |||||
GDSC_methy_data_FILE = os.path.join(GDSC_FOLDER, 'F2_METH_CELL_DATA.txt') | |||||
GDSC_methy_sampleIds_FILE = os.path.join(GDSC_FOLDER, 'methSampleId_2_cosmicIds.xlsx') | |||||
GDSC_exp_data_FILE = os.path.join(GDSC_FOLDER, 'Cell_line_RMA_proc_basalExp.txt') | |||||
GDSC_exp_sampleIds_FILE = os.path.join(GDSC_FOLDER, 'E-MTAB-3610.sdrf.txt') | |||||
GDSC_mut_data_FILE = os.path.join(GDSC_FOLDER, 'mutations_all_20230202.csv') | |||||
GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') | |||||
CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') | |||||
BOTH_SCREENING_DATA_FOLDER = os.path.join(RAW_BOTH_DATA_FOLDER, 'drug_screening_matrix_gdsc_ctrp.tsv') | |||||
CCLE_mut_data_FILE = os.path.join(CCLE_FOLDER, 'CCLE_mutations.csv') | |||||
TABLE_RESULTS_FILE = os.path.join(DATA_FOLDER, 'drug_screening_table.tsv') | |||||
MATRIX_RESULTS_FILE = os.path.join(DATA_FOLDER, 'drug_screening_matrix.tsv') | |||||
MODEL_FILE = os.path.join(MODEL_FOLDER, 'trained_model_V1_EMDP.sav') | |||||
TEST_FILE = os.path.join(TEST_DATA_FOLDER, 'test.gzip') | |||||
RESULT_FILE = os.path.join(RES_DATA_FOLDER, 'result.tsv') | |||||
TCGA_DATA_FOLDER = os.path.join(DATA_FOLDER, 'TCGA_test_data') | |||||
TCGA_SCREENING_DATA = os.path.join(TCGA_DATA_FOLDER, 'TCGA_screening_matrix.tsv') | |||||
BUILD_SIM_MATRICES = True # Make this variable True to build similarity matrices from raw data | |||||
SIM_KERNEL = {'cell_CN': ('euclidean', 0.001), 'cell_exp': ('euclidean', 0.01), 'cell_methy': ('euclidean', 0.1), | |||||
'cell_mut': ('jaccard', 1), 'drug_DT': ('jaccard', 1), 'drug_comp': ('euclidean', 0.001), | |||||
'drug_desc': ('euclidean', 0.001), 'drug_finger': ('euclidean', 0.001)} | |||||
SAVE_MODEL = False # Change it to True to save the trained model | |||||
VARIATIONAL_AUTOENCODERS = False | |||||
# DATA_MODALITIES=['cell_CN','cell_exp','cell_methy','cell_mut','drug_comp','drug_DT'] # Change this list to only consider specific data modalities | |||||
DATA_MODALITIES = ['cell_mut', 'drug_desc', 'drug_finger'] | |||||
RANDOM_SEED = 42 # Must be used wherever can be used | |||||
def data_modalities_abbreviation(): | |||||
abb = [] | |||||
if 'cell_CN' in DATA_MODALITIES: | |||||
abb.append('C') | |||||
if 'cell_exp' in DATA_MODALITIES: | |||||
abb.append('E') | |||||
if 'cell_mut' in DATA_MODALITIES: | |||||
abb.append('M') | |||||
if 'cell_methy' in DATA_MODALITIES: | |||||
abb.append('T') | |||||
if 'drug_DT' in DATA_MODALITIES: | |||||
abb.append('D') | |||||
if 'drug_comp' in DATA_MODALITIES: | |||||
abb.append('P') | |||||
return ''.join(abb) | |||||
""" TRAIN_INTEGRATION_METHOD used for each cell's and drug_data's data definitions: | |||||
SIMILARITY: A kernel based integration method in which based on the similarity of each cell's data with the training cell's | |||||
data the input features for the multi layer perceptron (MLP) is constructed. The similarity function used could be different for | |||||
each data modality (euclidean, jaccard,l1_norm, or ...) | |||||
AUTO_ENCODER_V1: In this version of integrating multi-omics, for each data modality an autoencoder is trained to reduce the | |||||
dimension of the features and finally a concatenation of each autoencoder's latent space builds up the input layer of the MLP. | |||||
AUTO_ENCODER_V2: In this version of integrating multi-omics data, we train a big autoencoder which reduces the dimension of | |||||
all the different data modalities features at the same time to a smaller feature space. This version of integrating could | |||||
take a lot of memory and time to integrate the data and might be computationally expensive. | |||||
AUTO_ENCODER_V3: IN this version of integrating multi-omics data, we train an autoencoder for all the modalities kinda same as | |||||
the autoencoder version 2 but with this difference that the encoder and decoder layers are separate from each other and | |||||
just the latent layer is shared among different data modalities. | |||||
""" |