You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 10KB

11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from imblearn.under_sampling import RandomUnderSampler
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.utils.class_weight import compute_class_weight
  4. from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
  5. from sklearn.model_selection import KFold
  6. from DeepDRA import DeepDRA, train, test
  7. from data_loader import RawDataLoader
  8. from evaluation import Evaluation
  9. from utils import *
  10. import random
  11. import torch
  12. import numpy as np
  13. import pandas as pd
  14. # Step 1: Define the batch size for training
  15. batch_size = 64
  16. # Step 2: Instantiate the combined model
  17. ae_latent_dim = 50
  18. mlp_input_dim = 2 * ae_latent_dim
  19. mlp_output_dim = 1
  20. num_epochs = 25
  21. def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes,device):
  22. """
  23. Train and evaluate the DeepDRA model.
  24. Parameters:
  25. - X_cell_train (pd.DataFrame): Training data for the cell modality.
  26. - X_cell_test (pd.DataFrame): Test data for the cell modality.
  27. - X_drug_train (pd.DataFrame): Training data for the drug modality.
  28. - X_drug_test (pd.DataFrame): Test data for the drug modality.
  29. - y_train (pd.Series): Training labels.
  30. - y_test (pd.Series): Test labels.
  31. - cell_sizes (list): Sizes of the cell modality features.
  32. - drug_sizes (list): Sizes of the drug modality features.
  33. Returns:
  34. - result: Evaluation result on the test set.
  35. """
  36. model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim)
  37. model.to(device)
  38. # Step 3: Convert your training data to PyTorch tensors
  39. x_cell_train_tensor = torch.Tensor(x_cell_train.values)
  40. x_drug_train_tensor = torch.Tensor(x_drug_train.values)
  41. x_cell_train_tensor = torch.nn.functional.normalize(x_cell_train_tensor, dim=0)
  42. x_drug_train_tensor = torch.nn.functional.normalize(x_drug_train_tensor, dim=0)
  43. y_train_tensor = torch.Tensor(y_train)
  44. y_train_tensor = y_train_tensor.unsqueeze(1)
  45. x_cell_train_tensor.to(device)
  46. x_drug_train_tensor.to(device)
  47. y_train_tensor.to(device)
  48. # Compute class weights
  49. classes = [0, 1] # Assuming binary classification
  50. class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train),
  51. dtype=torch.float32)
  52. x_cell_train_tensor, x_cell_val_tensor, x_drug_train_tensor, x_drug_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
  53. x_cell_train_tensor, x_drug_train_tensor, y_train_tensor, test_size=0.1,
  54. random_state=RANDOM_SEED,
  55. shuffle=True)
  56. # Step 4: Create a TensorDataset with the input features and target labels
  57. train_dataset = TensorDataset(x_cell_train_tensor, x_drug_train_tensor, y_train_tensor)
  58. val_dataset = TensorDataset(x_cell_val_tensor, x_drug_val_tensor, y_val_tensor)
  59. # Step 5: Create the train_loader
  60. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  61. val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
  62. # Step 6: Train the model
  63. train(model, train_loader, val_loader, num_epochs,class_weights)
  64. # Step 7: Save the trained model
  65. torch.save(model, 'DeepDRA.pth')
  66. # Step 8: Load the saved model
  67. model = torch.load('DeepDRA.pth')
  68. # Step 9: Convert your test data to PyTorch tensors
  69. x_cell_test_tensor = torch.Tensor(x_cell_test.values)
  70. x_drug_test_tensor = torch.Tensor(x_drug_test.values)
  71. y_test_tensor = torch.Tensor(y_test)
  72. # normalize data
  73. x_cell_test_tensor = torch.nn.functional.normalize(x_cell_test_tensor, dim=0)
  74. x_drug_test_tensor = torch.nn.functional.normalize(x_drug_test_tensor, dim=0)
  75. # Step 10: Create a TensorDataset with the input features and target labels for testing
  76. test_dataset = TensorDataset(x_cell_test_tensor, x_drug_test_tensor, y_test_tensor)
  77. test_loader = DataLoader(test_dataset, batch_size=len(x_cell_test))
  78. # Step 11: Test the model
  79. return test(model, test_loader)
  80. def cv_train(x_cell_train, x_drug_train, y_train, cell_sizes,
  81. drug_sizes, device, k=5, ):
  82. splits = KFold(n_splits=k, shuffle=True, random_state=RANDOM_SEED)
  83. history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []}
  84. for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(x_cell_train)))):
  85. print('Fold {}'.format(fold + 1))
  86. train_sampler = SubsetRandomSampler(train_idx)
  87. test_sampler = SubsetRandomSampler(val_idx)
  88. model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim, mlp_input_dim, mlp_output_dim)
  89. # Convert your training data to PyTorch tensors
  90. x_cell_train_tensor = torch.Tensor(x_cell_train.values)
  91. x_drug_train_tensor = torch.Tensor(x_drug_train.values)
  92. y_train_tensor = torch.Tensor(y_train)
  93. y_train_tensor = y_train_tensor.unsqueeze(1)
  94. # Compute class weights
  95. classes = [0, 1] # Assuming binary classification
  96. class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train),
  97. dtype=torch.float32)
  98. # Create a TensorDataset with the input features and target labels
  99. train_dataset = TensorDataset(x_cell_train_tensor, x_drug_train_tensor, y_train_tensor)
  100. # Create the train_loader
  101. train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
  102. # Train the model
  103. train(model, train_loader,train_loader, num_epochs, class_weights)
  104. # Create a TensorDataset with the input features and target labels
  105. test_loader = DataLoader(train_dataset, batch_size=len(x_cell_train), sampler=test_sampler)
  106. # Test the model
  107. results = test(model, test_loader)
  108. # Step 10: Add results to the history dictionary
  109. Evaluation.add_results(history, results)
  110. return Evaluation.show_final_results(history)
  111. def run(k, is_test=False ):
  112. """
  113. Run the training and evaluation process k times.
  114. Parameters:
  115. - k (int): Number of times to run the process.
  116. - is_test (bool): If True, run on test data; otherwise, perform train-validation split.
  117. Returns:
  118. - history (dict): Dictionary containing evaluation metrics for each run.
  119. """
  120. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  121. # Step 1: Initialize a dictionary to store evaluation metrics
  122. history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []}
  123. # Step 2: Load training data
  124. train_data, train_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
  125. raw_file_directory=RAW_BOTH_DATA_FOLDER,
  126. screen_file_directory=BOTH_SCREENING_DATA_FOLDER,
  127. sep="\t")
  128. # Step 3: Load test data if applicable
  129. if is_test:
  130. test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
  131. raw_file_directory=CCLE_RAW_DATA_FOLDER,
  132. screen_file_directory=CCLE_SCREENING_DATA_FOLDER,
  133. sep="\t")
  134. train_data, test_data = RawDataLoader.data_features_intersect(train_data, test_data)
  135. # Step 4: Prepare input data for training
  136. x_cell_train, x_drug_train, y_train, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(train_data,
  137. train_drug_screen)
  138. if is_test:
  139. x_cell_test, x_drug_test, y_test, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(test_data,
  140. test_drug_screen)
  141. rus = RandomUnderSampler(sampling_strategy="majority", random_state=RANDOM_SEED)
  142. dataset = pd.concat([x_cell_train, x_drug_train], axis=1)
  143. dataset.index = x_cell_train.index
  144. dataset, y_train = rus.fit_resample(dataset, y_train)
  145. x_cell_train = dataset.iloc[:, :sum(cell_sizes)]
  146. x_drug_train = dataset.iloc[:, sum(cell_sizes):]
  147. # Step 5: Loop over k runs
  148. for i in range(k):
  149. print('Run {}'.format(i))
  150. # Step 6: If is_test is True, perform random under-sampling on the training data
  151. if is_test:
  152. # Step 7: Train and evaluate the DeepDRA model on test data
  153. results = train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes,
  154. drug_sizes, device)
  155. else:
  156. # # Step 8: Split the data into training and validation sets
  157. # X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test = train_test_split(X_cell_train,
  158. # X_drug_train, y_train,
  159. # test_size=0.2,
  160. # random_state=RANDOM_SEED,
  161. # shuffle=True)
  162. # # Step 9: Train and evaluate the DeepDRA model on the split data
  163. # results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes,
  164. # drug_sizes, device)
  165. results = cv_train(x_cell_train, x_drug_train, y_train, cell_sizes, drug_sizes, device, k=5)
  166. # Step 10: Add results to the history dictionary
  167. Evaluation.add_results(history, results)
  168. # Step 11: Display final results
  169. Evaluation.show_final_results(history)
  170. return history
  171. if __name__ == '__main__':
  172. torch.manual_seed(RANDOM_SEED)
  173. random.seed(RANDOM_SEED)
  174. np.random.seed(RANDOM_SEED)
  175. run(10, is_test=False)