Browse Source

feat: stable

main
Taha Mohammadzadeh 11 months ago
parent
commit
b2bdea1fcc
6 changed files with 14 additions and 13 deletions
  1. 2
    0
      DeepDRA.py
  2. 7
    10
      data_loader.py
  3. 1
    1
      evaluation.py
  4. 2
    1
      main.py
  5. 1
    1
      mlp.py
  6. 1
    0
      utils.py

+ 2
- 0
DeepDRA.py View File



# Backward pass and optimization # Backward pass and optimization
total_loss.backward() total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

mlp_optimizer.step() mlp_optimizer.step()
total_train_loss += total_loss.item() total_train_loss += total_loss.item()



+ 7
- 10
data_loader.py View File



class RawDataLoader: class RawDataLoader:
@staticmethod @staticmethod
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep):
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER):
""" """
Load raw data and screening data, perform intersection, and adjust screening data. Load raw data and screening data, perform intersection, and adjust screening data.




# Step 2: Load drug data files for specified data modalities # Step 2: Load drug data files for specified data modalities
drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
raw_file_directory=DRUG_DATA_FOLDER)
raw_file_directory=drug_directory)


# Step 3: Update the 'data' dictionary with drug data # Step 3: Update the 'data' dictionary with drug data
data.update(drug_data) data.update(drug_data)
df.columns = df.columns.str.replace('_cell_CN', '') df.columns = df.columns.str.replace('_cell_CN', '')
df.columns = df.columns.str.replace('_cell_exp', '') df.columns = df.columns.str.replace('_cell_exp', '')


# Note that drug_comp raw table has some NA values so we should impute it
if any(df.isna()):
df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df),
columns=df.columns).set_index(df.index)
if file.startswith('drug_comp'): # We need to normalize the drug_data comp dataset
df = ((df - df.mean()) / df.std()).fillna(0)
elif file.startswith('drug_desc'): # We need to normalize the drug_data comp dataset
df = ((df - df.mean()) / df.std()).fillna(0)
df = (df - df.min()) / (df.max() - df.min())
df = df.fillna(0)

print("has null:")
print(df.isnull().sum().sum())
if intersect: if intersect:
if file.startswith('cell'): if file.startswith('cell'):
if cell_line_names: if cell_line_names:

+ 1
- 1
evaluation.py View File





@staticmethod @staticmethod
def evaluate(all_targets, mlp_output, show_plot=True):
def evaluate(all_targets, mlp_output, show_plot=False):
""" """
Evaluate model performance based on predictions and targets. Evaluate model performance based on predictions and targets.



+ 2
- 1
main.py View File

screen_file_directory=GDSC_SCREENING_DATA_FOLDER, screen_file_directory=GDSC_SCREENING_DATA_FOLDER,
sep="\t") sep="\t")



# Step 3: Load test data if applicable # Step 3: Load test data if applicable
if is_test: if is_test:
test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
torch.manual_seed(RANDOM_SEED) torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED) random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED) np.random.seed(RANDOM_SEED)
run(10, is_test=False)
run(10, is_test=True)

+ 1
- 1
mlp.py View File

nn.Linear(input_dim, 128), nn.Linear(input_dim, 128),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(128, output_dim), nn.Linear(128, output_dim),
nn.Hardsigmoid(),
nn.Sigmoid(),
) )


def forward(self, x): def forward(self, x):

+ 1
- 0
utils.py View File

GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data')
CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data') CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data')
CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data') CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data')
SIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'similarity_data')


GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv')
CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv')

Loading…
Cancel
Save