Browse Source

feat: stable

main
Taha Mohammadzadeh 1 year ago
parent
commit
b2bdea1fcc
6 changed files with 14 additions and 13 deletions
  1. 2
    0
      DeepDRA.py
  2. 7
    10
      data_loader.py
  3. 1
    1
      evaluation.py
  4. 2
    1
      main.py
  5. 1
    1
      mlp.py
  6. 1
    0
      utils.py

+ 2
- 0
DeepDRA.py View File

@@ -140,6 +140,8 @@ def train(model, train_loader, val_loader, num_epochs,class_weights):

# Backward pass and optimization
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

mlp_optimizer.step()
total_train_loss += total_loss.item()


+ 7
- 10
data_loader.py View File

@@ -9,7 +9,7 @@ from utils import DRUG_DATA_FOLDER

class RawDataLoader:
@staticmethod
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep):
def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER):
"""
Load raw data and screening data, perform intersection, and adjust screening data.

@@ -29,7 +29,7 @@ class RawDataLoader:

# Step 2: Load drug data files for specified data modalities
drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
raw_file_directory=DRUG_DATA_FOLDER)
raw_file_directory=drug_directory)

# Step 3: Update the 'data' dictionary with drug data
data.update(drug_data)
@@ -135,14 +135,11 @@ class RawDataLoader:
df.columns = df.columns.str.replace('_cell_CN', '')
df.columns = df.columns.str.replace('_cell_exp', '')

# Note that drug_comp raw table has some NA values so we should impute it
if any(df.isna()):
df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df),
columns=df.columns).set_index(df.index)
if file.startswith('drug_comp'): # We need to normalize the drug_data comp dataset
df = ((df - df.mean()) / df.std()).fillna(0)
elif file.startswith('drug_desc'): # We need to normalize the drug_data comp dataset
df = ((df - df.mean()) / df.std()).fillna(0)
df = (df - df.min()) / (df.max() - df.min())
df = df.fillna(0)

print("has null:")
print(df.isnull().sum().sum())
if intersect:
if file.startswith('cell'):
if cell_line_names:

+ 1
- 1
evaluation.py View File

@@ -53,7 +53,7 @@ class Evaluation:


@staticmethod
def evaluate(all_targets, mlp_output, show_plot=True):
def evaluate(all_targets, mlp_output, show_plot=False):
"""
Evaluate model performance based on predictions and targets.


+ 2
- 1
main.py View File

@@ -165,6 +165,7 @@ def run(k, is_test=False ):
screen_file_directory=GDSC_SCREENING_DATA_FOLDER,
sep="\t")


# Step 3: Load test data if applicable
if is_test:
test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
@@ -225,4 +226,4 @@ if __name__ == '__main__':
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
run(10, is_test=False)
run(10, is_test=True)

+ 1
- 1
mlp.py View File

@@ -27,7 +27,7 @@ class MLP(nn.Module):
nn.Linear(input_dim, 128),
nn.ReLU(inplace=True),
nn.Linear(128, output_dim),
nn.Hardsigmoid(),
nn.Sigmoid(),
)

def forward(self, x):

+ 1
- 0
utils.py View File

@@ -8,6 +8,7 @@ DRUG_DATA_FOLDER = os.path.join(DATA_FOLDER, 'drug_data')
GDSC_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'GDSC_data')
CCLE_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CCLE_data')
CTRP_RAW_DATA_FOLDER = os.path.join(DATA_FOLDER, 'CTRP_data')
SIM_DATA_FOLDER = os.path.join(DATA_FOLDER, 'similarity_data')

GDSC_SCREENING_DATA_FOLDER = os.path.join(GDSC_RAW_DATA_FOLDER, 'drug_screening_matrix_GDSC.tsv')
CCLE_SCREENING_DATA_FOLDER = os.path.join(CCLE_RAW_DATA_FOLDER, 'drug_screening_matrix_ccle.tsv')

Loading…
Cancel
Save