Browse Source

feat: add device

main
taha 11 months ago
parent
commit
5896d7461a
1 changed files with 10 additions and 6 deletions
  1. 10
    6
      main.py

+ 10
- 6
main.py View File

import pandas as pd import pandas as pd




def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes):
def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes,device):
""" """


Train and evaluate the DeepDRA model. Train and evaluate the DeepDRA model.
mlp_output_dim = 1 mlp_output_dim = 1
num_epochs = 25 num_epochs = 25
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim) model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim)
model.to(device)
# Step 3: Convert your training data to PyTorch tensors # Step 3: Convert your training data to PyTorch tensors
x_cell_train_tensor = torch.Tensor(x_cell_train.values) x_cell_train_tensor = torch.Tensor(x_cell_train.values)
x_drug_train_tensor = torch.Tensor(x_drug_train.values) x_drug_train_tensor = torch.Tensor(x_drug_train.values)
y_train_tensor = torch.Tensor(y_train) y_train_tensor = torch.Tensor(y_train)
y_train_tensor = y_train_tensor.unsqueeze(1) y_train_tensor = y_train_tensor.unsqueeze(1)


x_cell_train_tensor.to(device)
x_drug_train_tensor.to(device)
y_train_tensor.to(device)
# Compute class weights # Compute class weights
classes = [0, 1] # Assuming binary classification classes = [0, 1] # Assuming binary classification
class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train), class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train),
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)



# Step 6: Train the model # Step 6: Train the model
train(model, train_loader, val_loader, num_epochs,class_weights) train(model, train_loader, val_loader, num_epochs,class_weights)


Returns: Returns:
- history (dict): Dictionary containing evaluation metrics for each run. - history (dict): Dictionary containing evaluation metrics for each run.
""" """

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
torch.zeros(1).cuda()
# Step 1: Initialize a dictionary to store evaluation metrics # Step 1: Initialize a dictionary to store evaluation metrics
history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []} history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []}




# Step 7: Train and evaluate the DeepDRA model on test data # Step 7: Train and evaluate the DeepDRA model on test data
results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes, results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes,
drug_sizes)
drug_sizes, device)
else: else:
# Step 8: Split the data into training and validation sets # Step 8: Split the data into training and validation sets
X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test = train_test_split(X_cell_train, X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test = train_test_split(X_cell_train,
shuffle=True) shuffle=True)
# Step 9: Train and evaluate the DeepDRA model on the split data # Step 9: Train and evaluate the DeepDRA model on the split data
results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes, results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes,
drug_sizes)
drug_sizes, device)


# Step 10: Add results to the history dictionary # Step 10: Add results to the history dictionary
Evaluation.add_results(history, results) Evaluation.add_results(history, results)

Loading…
Cancel
Save