# Backward pass and optimization | # Backward pass and optimization | ||||
total_loss.backward() | total_loss.backward() | ||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1) | |||||
mlp_optimizer.step() | mlp_optimizer.step() | ||||
total_train_loss += total_loss.item() | total_train_loss += total_loss.item() | ||||
class RawDataLoader: | class RawDataLoader: | ||||
@staticmethod | @staticmethod | ||||
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep): | |||||
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER): | |||||
""" | """ | ||||
Load raw data and screening data, perform intersection, and adjust screening data. | Load raw data and screening data, perform intersection, and adjust screening data. | ||||
# Step 2: Load drug data files for specified data modalities | # Step 2: Load drug data files for specified data modalities | ||||
drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | ||||
raw_file_directory=DRUG_DATA_FOLDER) | |||||
raw_file_directory=drug_directory) | |||||
# Step 3: Update the 'data' dictionary with drug data | # Step 3: Update the 'data' dictionary with drug data | ||||
data.update(drug_data) | data.update(drug_data) | ||||
df.columns = df.columns.str.replace('_cell_CN', '') | df.columns = df.columns.str.replace('_cell_CN', '') | ||||
df.columns = df.columns.str.replace('_cell_exp', '') | df.columns = df.columns.str.replace('_cell_exp', '') | ||||
# Note that drug_comp raw table has some NA values so we should impute it | |||||
if any(df.isna()): | |||||
df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df), | |||||
columns=df.columns).set_index(df.index) | |||||
if file.startswith('drug_comp'): # We need to normalize the drug_data comp dataset | |||||
df = ((df - df.mean()) / df.std()).fillna(0) | |||||
elif file.startswith('drug_desc'): # We need to normalize the drug_data comp dataset | |||||
df = ((df - df.mean()) / df.std()).fillna(0) | |||||
df = (df - df.min()) / (df.max() - df.min()) | |||||
df = df.fillna(0) | |||||
print("has null:") | |||||
print(df.isnull().sum().sum()) | |||||
if intersect: | if intersect: | ||||
if file.startswith('cell'): | if file.startswith('cell'): | ||||
if cell_line_names: | if cell_line_names: |
@staticmethod | @staticmethod | ||||
def evaluate(all_targets, mlp_output, show_plot=True): | |||||
def evaluate(all_targets, mlp_output, show_plot=False): | |||||
""" | """ | ||||
Evaluate model performance based on predictions and targets. | Evaluate model performance based on predictions and targets. | ||||
screen_file_directory=GDSC_SCREENING_DATA_FOLDER, | screen_file_directory=GDSC_SCREENING_DATA_FOLDER, | ||||
sep="\t") | sep="\t") | ||||
# Step 3: Load test data if applicable | # Step 3: Load test data if applicable | ||||
if is_test: | if is_test: | ||||
test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | ||||
torch.manual_seed(RANDOM_SEED) | torch.manual_seed(RANDOM_SEED) | ||||
random.seed(RANDOM_SEED) | random.seed(RANDOM_SEED) | ||||
np.random.seed(RANDOM_SEED) | np.random.seed(RANDOM_SEED) | ||||
run(10, is_test=False) | |||||
run(10, is_test=True) |
nn.Linear(input_dim, 128), | nn.Linear(input_dim, 128), | ||||
nn.ReLU(inplace=True), | nn.ReLU(inplace=True), | ||||
nn.Linear(128, output_dim), | nn.Linear(128, output_dim), | ||||
nn.Hardsigmoid(), | |||||
nn.Sigmoid(), | |||||
) | ) | ||||
def forward(self, x): | def forward(self, x): |
GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') | GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') | ||||
CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data') | CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data') | ||||
CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data') | CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data') | ||||
SIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'similarity_data') | |||||
GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') | GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') | ||||
CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') | CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') |