| # Backward pass and optimization | # Backward pass and optimization | ||||
| total_loss.backward() | total_loss.backward() | ||||
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1) | |||||
| mlp_optimizer.step() | mlp_optimizer.step() | ||||
| total_train_loss += total_loss.item() | total_train_loss += total_loss.item() | ||||
| class RawDataLoader: | class RawDataLoader: | ||||
| @staticmethod | @staticmethod | ||||
| def load_data(data_modalities, raw_file_directory, screen_file_directory, sep): | |||||
| def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER): | |||||
| """ | """ | ||||
| Load raw data and screening data, perform intersection, and adjust screening data. | Load raw data and screening data, perform intersection, and adjust screening data. | ||||
| # Step 2: Load drug data files for specified data modalities | # Step 2: Load drug data files for specified data modalities | ||||
| drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | ||||
| raw_file_directory=DRUG_DATA_FOLDER) | |||||
| raw_file_directory=drug_directory) | |||||
| # Step 3: Update the 'data' dictionary with drug data | # Step 3: Update the 'data' dictionary with drug data | ||||
| data.update(drug_data) | data.update(drug_data) | ||||
| df.columns = df.columns.str.replace('_cell_CN', '') | df.columns = df.columns.str.replace('_cell_CN', '') | ||||
| df.columns = df.columns.str.replace('_cell_exp', '') | df.columns = df.columns.str.replace('_cell_exp', '') | ||||
| # Note that drug_comp raw table has some NA values so we should impute it | |||||
| if any(df.isna()): | |||||
| df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df), | |||||
| columns=df.columns).set_index(df.index) | |||||
| if file.startswith('drug_comp'): # We need to normalize the drug_data comp dataset | |||||
| df = ((df - df.mean()) / df.std()).fillna(0) | |||||
| elif file.startswith('drug_desc'): # We need to normalize the drug_data comp dataset | |||||
| df = ((df - df.mean()) / df.std()).fillna(0) | |||||
| df = (df - df.min()) / (df.max() - df.min()) | |||||
| df = df.fillna(0) | |||||
| print("has null:") | |||||
| print(df.isnull().sum().sum()) | |||||
| if intersect: | if intersect: | ||||
| if file.startswith('cell'): | if file.startswith('cell'): | ||||
| if cell_line_names: | if cell_line_names: |
| @staticmethod | @staticmethod | ||||
| def evaluate(all_targets, mlp_output, show_plot=True): | |||||
| def evaluate(all_targets, mlp_output, show_plot=False): | |||||
| """ | """ | ||||
| Evaluate model performance based on predictions and targets. | Evaluate model performance based on predictions and targets. | ||||
| screen_file_directory=GDSC_SCREENING_DATA_FOLDER, | screen_file_directory=GDSC_SCREENING_DATA_FOLDER, | ||||
| sep="\t") | sep="\t") | ||||
| # Step 3: Load test data if applicable | # Step 3: Load test data if applicable | ||||
| if is_test: | if is_test: | ||||
| test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | ||||
| torch.manual_seed(RANDOM_SEED) | torch.manual_seed(RANDOM_SEED) | ||||
| random.seed(RANDOM_SEED) | random.seed(RANDOM_SEED) | ||||
| np.random.seed(RANDOM_SEED) | np.random.seed(RANDOM_SEED) | ||||
| run(10, is_test=False) | |||||
| run(10, is_test=True) |
| nn.Linear(input_dim, 128), | nn.Linear(input_dim, 128), | ||||
| nn.ReLU(inplace=True), | nn.ReLU(inplace=True), | ||||
| nn.Linear(128, output_dim), | nn.Linear(128, output_dim), | ||||
| nn.Hardsigmoid(), | |||||
| nn.Sigmoid(), | |||||
| ) | ) | ||||
| def forward(self, x): | def forward(self, x): |
| GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') | GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') | ||||
| CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data') | CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data') | ||||
| CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data') | CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data') | ||||
| SIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'similarity_data') | |||||
| GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') | GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') | ||||
| CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') | CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') |