Browse Source

feat: add main files

main
Taha Mohammadzadeh 5 months ago
parent
commit
d5ca9f0e77
5 changed files with 122 additions and 26 deletions
  1. 73
    10
      DeepDRA.py
  2. 7
    0
      README.md
  3. 9
    2
      evaluation.py
  4. 29
    13
      main.py
  5. 4
    1
      utils.py

+ 73
- 10
DeepDRA.py View File

@@ -88,7 +88,7 @@ class DeepDRA(nn.Module):
return torch.square(w).sum()


def train(model, train_loader, num_epochs):
def train(model, train_loader, val_loader, num_epochs,class_weights):
"""
Trains the DeepDRA (Deep Drug Response Anticipation) model.

@@ -97,23 +97,44 @@ def train(model, train_loader, num_epochs):
- train_loader (DataLoader): DataLoader for the training dataset.
- num_epochs (int): Number of training epochs.
"""


autoencoder_loss_fn = nn.MSELoss()
mlp_loss_fn = nn.BCELoss()

mlp_optimizer = optim.Adam(model.parameters(), lr=0.0005)
train_accuracies = []
val_accuracies = []

train_loss = []
val_loss = []

mlp_optimizer = optim.Adam(model.parameters(), lr=0.0005,)
scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True)

# Define weight parameters for each loss term
cell_ae_weight = 1.0
drug_ae_weight = 1.0
mlp_weight = 1.0

for epoch in range(num_epochs):
model.train()
total_train_loss = 0.0
train_correct = 0
train_total_samples = 0
for batch_idx, (cell_data, drug_data, target) in enumerate(train_loader):
mlp_optimizer.zero_grad()

# Forward pass
cell_decoded_output, drug_decoded_output, mlp_output = model(cell_data, drug_data)

# Compute class weights for the current batch
# batch_class_weights = class_weights[target.long()]
# mlp_loss_fn = nn.BCEWithLogitsLoss(weight=batch_class_weights)

# Compute losses
cell_ae_loss = autoencoder_loss_fn(cell_decoded_output, cell_data)
drug_ae_loss = autoencoder_loss_fn(drug_decoded_output, drug_data)
mlp_loss = mlp_loss_fn(mlp_output, target)
cell_ae_loss = cell_ae_weight * autoencoder_loss_fn(cell_decoded_output, cell_data)
drug_ae_loss = drug_ae_weight * autoencoder_loss_fn(drug_decoded_output, drug_data)
mlp_loss = mlp_weight * mlp_loss_fn(mlp_output, target)

# Total loss is the sum of autoencoder losses and MLP loss
total_loss = drug_ae_loss + cell_ae_loss + mlp_loss
@@ -121,14 +142,56 @@ def train(model, train_loader, num_epochs):
# Backward pass and optimization
total_loss.backward()
mlp_optimizer.step()
total_train_loss += total_loss.item()

# Calculate accuracy
train_predictions = torch.round(mlp_output)
train_correct += (train_predictions == target).sum().item()
train_total_samples += target.size(0)

avg_train_loss = total_train_loss / len(train_loader)
train_loss.append(avg_train_loss)

# Validation
model.eval()
total_val_loss = 0.0
correct = 0
total_samples = 0
with torch.no_grad():
for val_batch_idx, (cell_data_val, drug_data_val, val_target) in enumerate(val_loader):
cell_decoded_output_val, drug_decoded_output_val, mlp_output_val = model(cell_data_val, drug_data_val)
# batch_class_weights = class_weights[val_target.long()]
# mlp_loss_fn = nn.BCEWithLogitsLoss(weight=batch_class_weights)

# Compute losses
cell_ae_loss_val = cell_ae_weight * autoencoder_loss_fn(cell_decoded_output_val, cell_data_val)
drug_ae_loss_val = drug_ae_weight * autoencoder_loss_fn(drug_decoded_output_val, drug_data_val)
mlp_loss_val = mlp_weight * mlp_loss_fn(mlp_output_val, val_target)

# Total loss is the sum of autoencoder losses and MLP loss
total_val_loss = drug_ae_loss_val + cell_ae_loss_val + mlp_loss_val

# Calculate accuracy
val_predictions = torch.round(mlp_output_val)
correct += (val_predictions == val_target).sum().item()
total_samples += val_target.size(0)

avg_val_loss = total_val_loss / len(val_loader)
val_loss.append(avg_val_loss)


train_accuracy = train_correct / train_total_samples
train_accuracies.append(train_accuracy)
val_accuracy = correct / total_samples
val_accuracies.append(val_accuracy)

# Print progress
if batch_idx % 200 == 0:
print('Epoch [{}/{}], Total Loss: {:.4f}'.format(
epoch + 1, num_epochs, total_loss.item()))
print(
'Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Train Accuracy: {:.4f}, Val Accuracy: {:.4f}'.format(
epoch + 1, num_epochs, avg_train_loss, avg_val_loss, train_accuracy,
val_accuracy))

# Learning rate scheduler step
scheduler.step(total_loss)
scheduler.step(total_train_loss)

# Save the trained model
torch.save(model.state_dict(), MODEL_FOLDER + 'DeepDRA.pth')

+ 7
- 0
README.md View File

@@ -1 +1,8 @@
# DeepDRA
Data

Download data from this link: https://drive.google.com/drive/folders/1-PgwD7KN9ZxCYBhyGAs3ihlbKK7s9jiO?usp=sharing

Run

You can run the main code with different data sets

+ 9
- 2
evaluation.py View File

@@ -72,11 +72,11 @@ class Evaluation:

# Step 2: Calculate and print AUC
fpr, tpr, thresholds = metrics.roc_curve(all_targets, mlp_output)
auc = np.round(metrics.auc(fpr, tpr), 2)
auc = np.round(metrics.auc(fpr, tpr), 3)

# Step 3: Calculate and print AUPRC
precision, recall, thresholds = metrics.precision_recall_curve(all_targets, mlp_output)
auprc = np.round(metrics.auc(recall, precision), 2)
auprc = np.round(metrics.auc(recall, precision), 3)

# Step 4: Print accuracy, AUC, AUPRC, and confusion matrix
accuracy = accuracy_score(all_targets, all_predictions)
@@ -162,4 +162,11 @@ class Evaluation:
avg_auc = np.mean(result_list['AUC'])
avg_auprc = np.mean(result_list['AUPRC'])
std_auprc = np.std(result_list['AUPRC'])
avg_accuracy = np.mean(result_list['Accuracy'])
avg_precision = np.mean(result_list['Precision'])
avg_recal = np.mean(result_list['Recall'])
avg_f1score = np.mean(result_list['F1 score'])
print(
f'AVG: Accuracy: {avg_accuracy:.3f}, Precision: {avg_precision:.3f}, Recall: {avg_recal:.3f}, F1 score: {avg_f1score:.3f}, AUC: {avg_auc:.3f}, ,AUPRC: {avg_auprc:.3f}')

print(" Average AUC: {:.3f} \t Average AUPRC: {:.3f} \t Std AUPRC: {:.3f}".format(avg_auc, avg_auprc, std_auprc))

+ 29
- 13
main.py View File

@@ -1,5 +1,6 @@
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

@@ -39,7 +40,7 @@ def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train,
ae_latent_dim = 50
mlp_input_dim = 2 * ae_latent_dim
mlp_output_dim = 1
num_epochs = 20
num_epochs = 25
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim)

# Step 3: Convert your training data to PyTorch tensors
@@ -50,20 +51,34 @@ def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train,
y_train_tensor = torch.Tensor(y_train)
y_train_tensor = y_train_tensor.unsqueeze(1)

# Compute class weights
classes = [0, 1] # Assuming binary classification
class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train),
dtype=torch.float32)


x_cell_train_tensor, x_cell_val_tensor, x_drug_train_tensor, x_drug_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
x_cell_train_tensor, x_drug_train_tensor, y_train_tensor, test_size=0.1,
random_state=RANDOM_SEED,
shuffle=True)

# Step 4: Create a TensorDataset with the input features and target labels
train_dataset = TensorDataset(x_cell_train_tensor, x_drug_train_tensor, y_train_tensor)
val_dataset = TensorDataset(x_cell_val_tensor, x_drug_val_tensor, y_val_tensor)

# Step 5: Create the train_loader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)


# Step 6: Train the model
train(model, train_loader, num_epochs=num_epochs)
train(model, train_loader, val_loader, num_epochs,class_weights)

# Step 7: Save the trained model
torch.save(model, MODEL_FOLDER + 'DeepDRA.pth')
torch.save(model, 'DeepDRA.pth')

# Step 8: Load the saved model
model = torch.load( MODEL_FOLDER + 'DeepDRA.pth')
model = torch.load('DeepDRA.pth')

# Step 9: Convert your test data to PyTorch tensors
x_cell_test_tensor = torch.Tensor(x_cell_test.values)
@@ -99,8 +114,8 @@ def run(k, is_test=False):

# Step 2: Load training data
train_data, train_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
raw_file_directory=GDSC_RAW_DATA_FOLDER,
screen_file_directory=GDSC_SCREENING_DATA_FOLDER,
raw_file_directory=RAW_BOTH_DATA_FOLDER,
screen_file_directory=BOTH_SCREENING_DATA_FOLDER,
sep="\t")

# Step 3: Load test data if applicable
@@ -117,18 +132,19 @@ def run(k, is_test=False):
X_cell_train, X_drug_train, y_train, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(train_data,
train_drug_screen)

rus = RandomUnderSampler(sampling_strategy="majority", random_state=RANDOM_SEED)
dataset = pd.concat([X_cell_train, X_drug_train], axis=1)
dataset.index = X_cell_train.index
dataset, y_train = rus.fit_resample(dataset, y_train)
X_cell_train = dataset.iloc[:, :sum(cell_sizes)]
X_drug_train = dataset.iloc[:, sum(cell_sizes):]

# Step 5: Loop over k runs
for i in range(k):
print('Run {}'.format(i))

# Step 6: If is_test is True, perform random under-sampling on the training data
if is_test:
rus = RandomUnderSampler(sampling_strategy="majority", random_state=RANDOM_SEED)
dataset = pd.concat([X_cell_train, X_drug_train], axis=1)
dataset.index = X_cell_train.index
dataset, y_train = rus.fit_resample(dataset, y_train)
X_cell_train = dataset.iloc[:, :sum(cell_sizes)]
X_drug_train = dataset.iloc[:, sum(cell_sizes):]

# Step 7: Train and evaluate the DeepDRA model on test data
results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes,
@@ -138,7 +154,7 @@ def run(k, is_test=False):
X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test = train_test_split(X_cell_train,
X_drug_train, y_train,
test_size=0.2,
random_state=44,
random_state=RANDOM_SEED,
shuffle=True)
# Step 9: Train and evaluate the DeepDRA model on the split data
results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes,

+ 4
- 1
utils.py View File

@@ -7,8 +7,11 @@ RAW_BOTH_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_GDSC_data')
DRUG_DATA_FOLDER = os.path.join(DATA_FOLDER, 'drug_data')
GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data')
CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data')
CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data')

GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv')
CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv')
CTRP_SCREENING_DATA_FOLDER = os.path.join(CTRP_RAW_DATA_FOLDER, 'drug_screening_matrix_ctrp.tsv')
BOTH_SCREENING_DATA_FOLDER = os.path.join(RAW_BOTH_DATA_FOLDER, 'drug_screening_matrix_gdsc_ctrp.tsv')

CTRP_FOLDER = os.path.join(DATA_FOLDER, 'CTRP')
@@ -27,7 +30,7 @@ SIM_KERNEL = {'cell_CN': ('euclidean', 0.001), 'cell_exp': ('euclidean', 0.01),
SAVE_MODEL = False # Change it to True to save the trained model
VARIATIONAL_AUTOENCODERS = False
# DATA_MODALITIES=['cell_CN','cell_exp','cell_methy','cell_mut','drug_comp','drug_DT'] # Change this list to only consider specific data modalities
DATA_MODALITIES = ['cell_mut', 'drug_desc', 'drug_finger']
DATA_MODALITIES = ['cell_CN','cell_exp','cell_mut', 'drug_desc']
RANDOM_SEED = 42 # Must be used wherever can be used



Loading…
Cancel
Save