Browse Source

feat: stable

main
Taha Mohammadzadeh 11 months ago
parent
commit
439b0c1aa7
3 changed files with 8 additions and 11 deletions
  1. 3
    4
      DeepDRA.py
  2. 4
    6
      main.py
  3. 1
    1
      utils.py

+ 3
- 4
DeepDRA.py View File

- mlp_output_dim (int): Output dimension for the MLP. - mlp_output_dim (int): Output dimension for the MLP.
""" """


def __init__(self, cell_modality_sizes, drug_modality_sizes, cell_ae_latent_dim, drug_ae_latent_dim, mlp_input_dim,
mlp_output_dim):
def __init__(self, cell_modality_sizes, drug_modality_sizes, cell_ae_latent_dim, drug_ae_latent_dim):
super(DeepDRA, self).__init__() super(DeepDRA, self).__init__()


# Initialize cell and drug autoencoders # Initialize cell and drug autoencoders
self.drug_modality_sizes = drug_modality_sizes self.drug_modality_sizes = drug_modality_sizes


# Initialize MLP # Initialize MLP
self.mlp = MLP(mlp_input_dim, mlp_output_dim)
self.mlp = MLP(cell_ae_latent_dim+drug_ae_latent_dim, 1)


def forward(self, cell_x, drug_x): def forward(self, cell_x, drug_x):
""" """
return torch.square(w).sum() return torch.square(w).sum()




def train(model, train_loader, val_loader, num_epochs, class_weights):
def train(model, train_loader, val_loader, num_epochs,class_weights):
""" """
Trains the DeepDRA (Deep Drug Response Anticipation) model. Trains the DeepDRA (Deep Drug Response Anticipation) model.



+ 4
- 6
main.py View File



# Step 2: Instantiate the combined model # Step 2: Instantiate the combined model
ae_latent_dim = 50 ae_latent_dim = 50
mlp_input_dim = 2 * ae_latent_dim
mlp_output_dim = 1
num_epochs = 25 num_epochs = 25


def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes,device): def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes,device):
""" """




model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim)
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim)
model.to(device) model.to(device)
# Step 3: Convert your training data to PyTorch tensors # Step 3: Convert your training data to PyTorch tensors
x_cell_train_tensor = torch.Tensor(x_cell_train.values) x_cell_train_tensor = torch.Tensor(x_cell_train.values)


train_sampler = SubsetRandomSampler(train_idx) train_sampler = SubsetRandomSampler(train_idx)
test_sampler = SubsetRandomSampler(val_idx) test_sampler = SubsetRandomSampler(val_idx)
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim)
model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim)
# Convert your training data to PyTorch tensors # Convert your training data to PyTorch tensors
x_cell_train_tensor = torch.Tensor(x_cell_train.values) x_cell_train_tensor = torch.Tensor(x_cell_train.values)
x_drug_train_tensor = torch.Tensor(x_drug_train.values) x_drug_train_tensor = torch.Tensor(x_drug_train.values)


# Step 2: Load training data # Step 2: Load training data
train_data, train_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, train_data, train_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
raw_file_directory=RAW_BOTH_DATA_FOLDER,
screen_file_directory=BOTH_SCREENING_DATA_FOLDER,
raw_file_directory=GDSC_RAW_DATA_FOLDER,
screen_file_directory=GDSC_SCREENING_DATA_FOLDER,
sep="\t") sep="\t")


# Step 3: Load test data if applicable # Step 3: Load test data if applicable

+ 1
- 1
utils.py View File

SAVE_MODEL = False # Change it to True to save the trained model SAVE_MODEL = False # Change it to True to save the trained model
VARIATIONAL_AUTOENCODERS = False VARIATIONAL_AUTOENCODERS = False
# DATA_MODALITIES=['cell_CN','cell_exp','cell_methy','cell_mut','drug_comp','drug_DT'] # Change this list to only consider specific data modalities # DATA_MODALITIES=['cell_CN','cell_exp','cell_methy','cell_mut','drug_comp','drug_DT'] # Change this list to only consider specific data modalities
DATA_MODALITIES = ['cell_CN','cell_exp','cell_mut', 'drug_desc']
DATA_MODALITIES = ['cell_exp', 'drug_desc']
RANDOM_SEED = 42 # Must be used wherever can be used RANDOM_SEED = 42 # Must be used wherever can be used





Loading…
Cancel
Save