You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

data_loader.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. from utils import *
  2. import os
  3. class RawDataLoader:
  4. @staticmethod
  5. def load_data(data_modalities, raw_file_directory, screen_file_directory, sep):
  6. """
  7. Load raw data and screening data, perform intersection, and adjust screening data.
  8. Parameters:
  9. - data_modalities (list): List of data modalities to load.
  10. - raw_file_directory (str): Directory containing raw data files.
  11. - screen_file_directory (str): Directory containing screening data files.
  12. - sep (str): Separator used in the data files.
  13. Returns:
  14. - data (dict): Dictionary containing loaded raw data.
  15. - drug_screen (pd.DataFrame): Adjusted and intersected screening data.
  16. """
  17. # Step 1: Load raw data files for specified data modalities
  18. data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
  19. raw_file_directory=raw_file_directory)
  20. # Step 2: Load drug data files for specified data modalities
  21. drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
  22. raw_file_directory=DRUG_DATA_FOLDER)
  23. # Step 3: Update the 'data' dictionary with drug data
  24. data.update(drug_data)
  25. # Step 4: Load and adjust drug screening data
  26. drug_screen = RawDataLoader.load_screening_files(
  27. filename=screen_file_directory,
  28. sep=sep)
  29. drug_screen, data = RawDataLoader.adjust_screening_raw(
  30. drug_screen=drug_screen, data_dict=data)
  31. # Step 5: Return the loaded data and adjusted drug screening data
  32. return data, drug_screen
  33. @staticmethod
  34. def intersect_features(data1, data2):
  35. """
  36. Perform intersection of features between two datasets.
  37. Parameters:
  38. - data1 (pd.DataFrame): First dataset.
  39. - data2 (pd.DataFrame): Second dataset.
  40. Returns:
  41. - data1 (pd.DataFrame): First dataset with common columns.
  42. - data2 (pd.DataFrame): Second dataset with common columns.
  43. """
  44. # Step 1: Find common columns between the two datasets
  45. common_columns = list(set(data1.columns) & set(data2.columns))
  46. # Step 2: Filter data2 to include only common columns
  47. data2 = data2[common_columns]
  48. # Step 3: Filter data1 to include only common columns
  49. data1 = data1[common_columns]
  50. # Step 4: Return the datasets with intersected features
  51. return data1, data2
  52. @staticmethod
  53. def data_features_intersect(data1, data2):
  54. """
  55. Intersect features between two datasets column-wise.
  56. Parameters:
  57. - data1 (dict): Dictionary containing data modalities.
  58. - data2 (dict): Dictionary containing data modalities.
  59. Returns:
  60. - intersected_data1 (dict): Data1 with intersected features.
  61. - intersected_data2 (dict): Data2 with intersected features.
  62. """
  63. # Iterate over each data modality
  64. for i in data1:
  65. # Intersect features for each modality
  66. data1[i], data2[i] = RawDataLoader.intersect_features(data1[i], data2[i])
  67. return data1, data2
  68. @staticmethod
  69. def load_file(address, index_column=None):
  70. """
  71. Load data from a file based on its format.
  72. Parameters:
  73. - address (str): File address.
  74. - index_column (str): Name of the index column.
  75. Returns:
  76. - data (pd.DataFrame): Loaded data from the file.
  77. """
  78. data = []
  79. try:
  80. # Load data based on file format
  81. if address.endswith('.txt') or address.endswith('.tsv'):
  82. data.append(pd.read_csv(address, sep='\t', index_col=index_column), )
  83. elif address.endswith('.csv'):
  84. data.append(pd.read_csv(address))
  85. elif address.endswith('.xlsx'):
  86. data.append(pd.read_excel(address))
  87. except FileNotFoundError:
  88. print(f'File not found at address: {address}')
  89. return data[0]
  90. @staticmethod
  91. def load_raw_files(raw_file_directory, data_modalities, intersect=True):
  92. raw_dict = {}
  93. files = os.listdir(raw_file_directory)
  94. cell_line_names = None
  95. drug_names = None
  96. for file in tqdm(files, 'Reading Raw Data Files...'):
  97. if any([file.startswith(x) for x in data_modalities]):
  98. if file.endswith('_raw.gzip'):
  99. df = pd.read_parquet(os.path.join(raw_file_directory, file))
  100. elif file.endswith('_raw.tsv'):
  101. df = pd.read_csv(os.path.join(raw_file_directory, file), sep='\t', index_col=0)
  102. else:
  103. continue
  104. if df.index.is_numeric():
  105. df = df.set_index(df.columns[0])
  106. df = df.sort_index()
  107. df = df.sort_index(axis=1)
  108. df.columns = df.columns.str.replace('_cell_mut', '')
  109. df.columns = df.columns.str.replace('_cell_CN', '')
  110. df.columns = df.columns.str.replace('_cell_exp', '')
  111. # Note that drug_comp raw table has some NA values so we should impute it
  112. if any(df.isna()):
  113. df = pd.DataFrame(SimpleImputer(strategy='mean').fit_transform(df),
  114. columns=df.columns).set_index(df.index)
  115. if file.startswith('drug_comp'): # We need to normalize the drug_data comp dataset
  116. df = ((df - df.mean()) / df.std()).fillna(0)
  117. elif file.startswith('drug_desc'): # We need to normalize the drug_data comp dataset
  118. df = ((df - df.mean()) / df.std()).fillna(0)
  119. if intersect:
  120. if file.startswith('cell'):
  121. if cell_line_names:
  122. cell_line_names = cell_line_names.intersection(set(df.index))
  123. else:
  124. cell_line_names = set(df.index)
  125. elif file.startswith('drug'):
  126. if drug_names:
  127. drug_names = drug_names.intersection(set(df.index))
  128. else:
  129. drug_names = set(df.index)
  130. raw_dict[file[:file.find('_raw')]] = df
  131. if intersect:
  132. for key, value in raw_dict.items():
  133. if key.startswith('cell'):
  134. data = value.loc[list(cell_line_names)]
  135. raw_dict[key] = data.loc[~data.index.duplicated()]
  136. elif key.startswith('drug'):
  137. data = value.loc[list(drug_names)]
  138. raw_dict[key] = data.loc[~data.index.duplicated()]
  139. return raw_dict
  140. @staticmethod
  141. def load_screening_files(filename="AUC_matS_comb.tsv", sep=',', ):
  142. df = pd.read_csv(filename, sep=sep, index_col=0)
  143. # df = df.drop(['Erlotinib','17-AAG','PD-0325901','PHA-665752','PHA-665752','TAE684','Sorafenib','PLX4720','selumetinib','PD-0332991','Paclitaxel','Nilotinib','Saracatinib'],axis=1)
  144. return df
  145. # return pd.read_csv(os.path.join(DATA_FOLDER, "drug_screening_matrix_GDSC.tsv"), sep='\t', index_col=0)
  146. @staticmethod
  147. def adjust_screening_raw(drug_screen, data_dict):
  148. raw_cell_names = []
  149. for key, value in data_dict.items():
  150. if 'cell' in key:
  151. if len(raw_cell_names) == 0:
  152. raw_cell_names = value.index
  153. else:
  154. raw_cell_names = raw_cell_names.intersection(value.index)
  155. elif 'drug' in key:
  156. raw_drug_names = value.index
  157. screening_cell_names = drug_screen.index
  158. screening_drug_names = drug_screen.columns
  159. common_cell_names = list(set(raw_cell_names).intersection(set(screening_cell_names)))
  160. common_drug_names = list(set(raw_drug_names).intersection(set(screening_drug_names)))
  161. for key, value in data_dict.items():
  162. if 'cell' in key:
  163. data_dict[key] = value.loc[common_cell_names]
  164. else:
  165. data_dict[key] = value.loc[common_drug_names]
  166. return drug_screen.loc[common_cell_names, common_drug_names], data_dict
  167. @staticmethod
  168. def prepare_input_data(data_dict, screening):
  169. print('Preparing data...')
  170. resistance = np.argwhere((screening.to_numpy() == 1)).tolist()
  171. resistance.sort(key=lambda x: (x[1], x[0]))
  172. resistance = np.array(resistance)
  173. sensitive = np.argwhere((screening.to_numpy() == -1)).tolist()
  174. sensitive.sort(key=lambda x: (x[1], x[0]))
  175. sensitive = np.array(sensitive)
  176. print("sensitive train data len:", len(sensitive))
  177. print("resistance train data len:", len(resistance))
  178. A_train_mask = np.ones(len(resistance), dtype=bool)
  179. B_train_mask = np.ones(len(sensitive), dtype=bool)
  180. resistance = resistance[A_train_mask]
  181. sensitive = sensitive[B_train_mask]
  182. cell_data_types = list(filter(lambda x: x.startswith('cell'), data_dict.keys()))
  183. cell_data_types.sort()
  184. cell_data = pd.concat(
  185. [pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32) for
  186. data_type in cell_data_types], axis=1)
  187. cell_data_sizes = [data_dict[data_type].shape[1] for data_type in cell_data_types]
  188. drug_data_types = list(filter(lambda x: x.startswith('drug'), data_dict.keys()))
  189. drug_data_types.sort()
  190. drug_data = pd.concat(
  191. [pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32, )
  192. for data_type in drug_data_types], axis=1)
  193. drug_data_sizes = [data_dict[data_type].shape[1] for data_type in drug_data_types]
  194. Xp_cell = cell_data.iloc[resistance[:, 0], :]
  195. Xp_drug = drug_data.iloc[resistance[:, 1], :]
  196. Xp_cell = Xp_cell.reset_index(drop=True)
  197. Xp_drug = Xp_drug.reset_index(drop=True)
  198. Xp_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance]
  199. Xp_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance]
  200. Xn_cell = cell_data.iloc[sensitive[:, 0], :]
  201. Xn_drug = drug_data.iloc[sensitive[:, 1], :]
  202. Xn_cell = Xn_cell.reset_index(drop=True)
  203. Xn_drug = Xn_drug.reset_index(drop=True)
  204. Xn_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive]
  205. Xn_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive]
  206. X_cell = pd.concat([Xp_cell, Xn_cell])
  207. X_drug = pd.concat([Xp_drug, Xn_drug])
  208. Y = np.append(np.zeros(resistance.shape[0]), np.ones(sensitive.shape[0]))
  209. return X_cell, X_drug, Y, cell_data_sizes, drug_data_sizes