@@ -140,6 +140,8 @@ def train(model, train_loader, val_loader, num_epochs,class_weights): | |||
# Backward pass and optimization | |||
total_loss.backward() | |||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1) | |||
mlp_optimizer.step() | |||
total_train_loss += total_loss.item() | |||
@@ -9,7 +9,7 @@ from utils import DRUG_DATA_FOLDER | |||
class RawDataLoader: | |||
@staticmethod | |||
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep): | |||
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER): | |||
""" | |||
Load raw data and screening data, perform intersection, and adjust screening data. | |||
@@ -29,7 +29,7 @@ class RawDataLoader: | |||
# Step 2: Load drug data files for specified data modalities | |||
drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities, | |||
raw_file_directory=DRUG_DATA_FOLDER) | |||
raw_file_directory=drug_directory) | |||
# Step 3: Update the 'data' dictionary with drug data | |||
data.update(drug_data) | |||
@@ -135,14 +135,11 @@ class RawDataLoader: | |||
df.columns = df.columns.str.replace('_cell_CN', '') | |||
df.columns = df.columns.str.replace('_cell_exp', '') | |||
# Note that drug_comp raw table has some NA values so we should impute it | |||
if any(df.isna()): | |||
df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df), | |||
columns=df.columns).set_index(df.index) | |||
if file.startswith('drug_comp'): # We need to normalize the drug_data comp dataset | |||
df = ((df - df.mean()) / df.std()).fillna(0) | |||
elif file.startswith('drug_desc'): # We need to normalize the drug_data comp dataset | |||
df = ((df - df.mean()) / df.std()).fillna(0) | |||
df = (df - df.min()) / (df.max() - df.min()) | |||
df = df.fillna(0) | |||
print("has null:") | |||
print(df.isnull().sum().sum()) | |||
if intersect: | |||
if file.startswith('cell'): | |||
if cell_line_names: |
@@ -53,7 +53,7 @@ class Evaluation: | |||
@staticmethod | |||
def evaluate(all_targets, mlp_output, show_plot=True): | |||
def evaluate(all_targets, mlp_output, show_plot=False): | |||
""" | |||
Evaluate model performance based on predictions and targets. | |||
@@ -165,6 +165,7 @@ def run(k, is_test=False ): | |||
screen_file_directory=GDSC_SCREENING_DATA_FOLDER, | |||
sep="\t") | |||
# Step 3: Load test data if applicable | |||
if is_test: | |||
test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES, | |||
@@ -225,4 +226,4 @@ if __name__ == '__main__': | |||
torch.manual_seed(RANDOM_SEED) | |||
random.seed(RANDOM_SEED) | |||
np.random.seed(RANDOM_SEED) | |||
run(10, is_test=False) | |||
run(10, is_test=True) |
@@ -27,7 +27,7 @@ class MLP(nn.Module): | |||
nn.Linear(input_dim, 128), | |||
nn.ReLU(inplace=True), | |||
nn.Linear(128, output_dim), | |||
nn.Hardsigmoid(), | |||
nn.Sigmoid(), | |||
) | |||
def forward(self, x): |
@@ -8,6 +8,7 @@ DRUG_DATA_FOLDER = os.path.join(DATA_FOLDER, 'drug_data') | |||
GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data') | |||
CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data') | |||
CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data') | |||
SIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'similarity_data') | |||
GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv') | |||
CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv') |