@@ -41,7 +41,7 @@ def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, | |||
""" | |||
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim) | |||
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim) | |||
model= model.to(device) | |||
# Step 3: Convert your training data to PyTorch tensors |