You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 10.0KB

11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
11 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. from imblearn.under_sampling import RandomUnderSampler
  2. from sklearn.model_selection import train_test_split
  3. from sklearn.utils.class_weight import compute_class_weight
  4. from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
  5. from sklearn.model_selection import KFold
  6. from DeepDRA import DeepDRA, train, test
  7. from data_loader import RawDataLoader
  8. from evaluation import Evaluation
  9. from utils import *
  10. import random
  11. import torch
  12. import numpy as np
  13. import pandas as pd
  14. # Step 1: Define the batch size for training
  15. batch_size = 64
  16. # Step 2: Instantiate the combined model
  17. ae_latent_dim = 50
  18. num_epochs = 25
  19. def train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes, drug_sizes,device):
  20. """
  21. Train and evaluate the DeepDRA model.
  22. Parameters:
  23. - X_cell_train (pd.DataFrame): Training data for the cell modality.
  24. - X_cell_test (pd.DataFrame): Test data for the cell modality.
  25. - X_drug_train (pd.DataFrame): Training data for the drug modality.
  26. - X_drug_test (pd.DataFrame): Test data for the drug modality.
  27. - y_train (pd.Series): Training labels.
  28. - y_test (pd.Series): Test labels.
  29. - cell_sizes (list): Sizes of the cell modality features.
  30. - drug_sizes (list): Sizes of the drug modality features.
  31. Returns:
  32. - result: Evaluation result on the test set.
  33. """
  34. model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim)
  35. model.to(device)
  36. # Step 3: Convert your training data to PyTorch tensors
  37. x_cell_train_tensor = torch.Tensor(x_cell_train.values)
  38. x_drug_train_tensor = torch.Tensor(x_drug_train.values)
  39. x_cell_train_tensor = torch.nn.functional.normalize(x_cell_train_tensor, dim=0)
  40. x_drug_train_tensor = torch.nn.functional.normalize(x_drug_train_tensor, dim=0)
  41. y_train_tensor = torch.Tensor(y_train)
  42. y_train_tensor = y_train_tensor.unsqueeze(1)
  43. x_cell_train_tensor.to(device)
  44. x_drug_train_tensor.to(device)
  45. y_train_tensor.to(device)
  46. # Compute class weights
  47. classes = [0, 1] # Assuming binary classification
  48. class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train),
  49. dtype=torch.float32)
  50. x_cell_train_tensor, x_cell_val_tensor, x_drug_train_tensor, x_drug_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
  51. x_cell_train_tensor, x_drug_train_tensor, y_train_tensor, test_size=0.1,
  52. random_state=RANDOM_SEED,
  53. shuffle=True)
  54. # Step 4: Create a TensorDataset with the input features and target labels
  55. train_dataset = TensorDataset(x_cell_train_tensor, x_drug_train_tensor, y_train_tensor)
  56. val_dataset = TensorDataset(x_cell_val_tensor, x_drug_val_tensor, y_val_tensor)
  57. # Step 5: Create the train_loader
  58. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  59. val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
  60. # Step 6: Train the model
  61. train(model, train_loader, val_loader, num_epochs,class_weights)
  62. # Step 7: Save the trained model
  63. torch.save(model, 'DeepDRA.pth')
  64. # Step 8: Load the saved model
  65. model = torch.load('DeepDRA.pth')
  66. # Step 9: Convert your test data to PyTorch tensors
  67. x_cell_test_tensor = torch.Tensor(x_cell_test.values)
  68. x_drug_test_tensor = torch.Tensor(x_drug_test.values)
  69. y_test_tensor = torch.Tensor(y_test)
  70. # normalize data
  71. x_cell_test_tensor = torch.nn.functional.normalize(x_cell_test_tensor, dim=0)
  72. x_drug_test_tensor = torch.nn.functional.normalize(x_drug_test_tensor, dim=0)
  73. # Step 10: Create a TensorDataset with the input features and target labels for testing
  74. test_dataset = TensorDataset(x_cell_test_tensor, x_drug_test_tensor, y_test_tensor)
  75. test_loader = DataLoader(test_dataset, batch_size=len(x_cell_test))
  76. # Step 11: Test the model
  77. return test(model, test_loader)
  78. def cv_train(x_cell_train, x_drug_train, y_train, cell_sizes,
  79. drug_sizes, device, k=5, ):
  80. splits = KFold(n_splits=k, shuffle=True, random_state=RANDOM_SEED)
  81. history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []}
  82. for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(x_cell_train)))):
  83. print('Fold {}'.format(fold + 1))
  84. train_sampler = SubsetRandomSampler(train_idx)
  85. test_sampler = SubsetRandomSampler(val_idx)
  86. model = DeepDRA(cell_sizes, drug_sizes, ae_latent_dim, ae_latent_dim)
  87. # Convert your training data to PyTorch tensors
  88. x_cell_train_tensor = torch.Tensor(x_cell_train.values)
  89. x_drug_train_tensor = torch.Tensor(x_drug_train.values)
  90. y_train_tensor = torch.Tensor(y_train)
  91. y_train_tensor = y_train_tensor.unsqueeze(1)
  92. # Compute class weights
  93. classes = [0, 1] # Assuming binary classification
  94. class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=classes, y=y_train),
  95. dtype=torch.float32)
  96. # Create a TensorDataset with the input features and target labels
  97. train_dataset = TensorDataset(x_cell_train_tensor, x_drug_train_tensor, y_train_tensor)
  98. # Create the train_loader
  99. train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
  100. # Train the model
  101. train(model, train_loader,train_loader, num_epochs, class_weights)
  102. # Create a TensorDataset with the input features and target labels
  103. test_loader = DataLoader(train_dataset, batch_size=len(x_cell_train), sampler=test_sampler)
  104. # Test the model
  105. results = test(model, test_loader)
  106. # Step 10: Add results to the history dictionary
  107. Evaluation.add_results(history, results)
  108. return Evaluation.show_final_results(history)
  109. def run(k, is_test=False ):
  110. """
  111. Run the training and evaluation process k times.
  112. Parameters:
  113. - k (int): Number of times to run the process.
  114. - is_test (bool): If True, run on test data; otherwise, perform train-validation split.
  115. Returns:
  116. - history (dict): Dictionary containing evaluation metrics for each run.
  117. """
  118. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  119. # Step 1: Initialize a dictionary to store evaluation metrics
  120. history = {'AUC': [], 'AUPRC': [], "Accuracy": [], "Precision": [], "Recall": [], "F1 score": []}
  121. # Step 2: Load training data
  122. train_data, train_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
  123. raw_file_directory=RAW_BOTH_DATA_FOLDER,
  124. screen_file_directory=BOTH_SCREENING_DATA_FOLDER,
  125. sep="\t")
  126. # Step 3: Load test data if applicable
  127. if is_test:
  128. test_data, test_drug_screen = RawDataLoader.load_data(data_modalities=DATA_MODALITIES,
  129. raw_file_directory=CCLE_RAW_DATA_FOLDER,
  130. screen_file_directory=CCLE_SCREENING_DATA_FOLDER,
  131. sep="\t")
  132. train_data, test_data = RawDataLoader.data_features_intersect(train_data, test_data)
  133. # Step 4: Prepare input data for training
  134. x_cell_train, x_drug_train, y_train, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(train_data,
  135. train_drug_screen)
  136. if is_test:
  137. x_cell_test, x_drug_test, y_test, cell_sizes, drug_sizes = RawDataLoader.prepare_input_data(test_data,
  138. test_drug_screen)
  139. rus = RandomUnderSampler(sampling_strategy="majority", random_state=RANDOM_SEED)
  140. dataset = pd.concat([x_cell_train, x_drug_train], axis=1)
  141. dataset.index = x_cell_train.index
  142. dataset, y_train = rus.fit_resample(dataset, y_train)
  143. x_cell_train = dataset.iloc[:, :sum(cell_sizes)]
  144. x_drug_train = dataset.iloc[:, sum(cell_sizes):]
  145. # Step 5: Loop over k runs
  146. for i in range(k):
  147. print('Run {}'.format(i))
  148. # Step 6: If is_test is True, perform random under-sampling on the training data
  149. if is_test:
  150. # Step 7: Train and evaluate the DeepDRA model on test data
  151. results = train_DeepDRA(x_cell_train, x_cell_test, x_drug_train, x_drug_test, y_train, y_test, cell_sizes,
  152. drug_sizes, device)
  153. else:
  154. # # Step 8: Split the data into training and validation sets
  155. # X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test = train_test_split(X_cell_train,
  156. # X_drug_train, y_train,
  157. # test_size=0.2,
  158. # random_state=RANDOM_SEED,
  159. # shuffle=True)
  160. # # Step 9: Train and evaluate the DeepDRA model on the split data
  161. # results = train_DeepDRA(X_cell_train, X_cell_test, X_drug_train, X_drug_test, y_train, y_test, cell_sizes,
  162. # drug_sizes, device)
  163. results = cv_train(x_cell_train, x_drug_train, y_train, cell_sizes, drug_sizes, device, k=5)
  164. # Step 10: Add results to the history dictionary
  165. Evaluation.add_results(history, results)
  166. # Step 11: Display final results
  167. Evaluation.show_final_results(history)
  168. return history
  169. if __name__ == '__main__':
  170. torch.manual_seed(RANDOM_SEED)
  171. random.seed(RANDOM_SEED)
  172. np.random.seed(RANDOM_SEED)
  173. run(10, is_test=True)