You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 10KB

  1. import pandas as pd
  2. import numpy as np
  3. from tqdm import tqdm
  4. from sklearn.impute import SimpleImputer
  5. import os
  6. from utils import DRUG_DATA_FOLDER
  class RawDataLoader:
      """
      Namespace of static helpers that load raw feature tables and drug
      screening data and turn them into paired cell/drug model inputs.

      Every method is a ``@staticmethod``; the class holds no state.
      """

      @staticmethod
      def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=None):
          """
          Load raw data and screening data, perform intersection, and adjust screening data.

          Parameters:
          - data_modalities (list): File-name prefixes of the data modalities to load.
          - raw_file_directory (str): Directory containing raw data files.
          - screen_file_directory (str): Path of the screening file; it is passed
            unchanged as ``filename`` to ``load_screening_files``.
          - sep (str): Separator used in the screening file.
          - drug_directory (str | None): Directory containing the drug data files.
            ``None`` (the default) resolves to ``DRUG_DATA_FOLDER`` at call time.

          Returns:
          - data (dict): Dictionary containing loaded raw data (cell and drug tables).
          - drug_screen (pd.DataFrame): Adjusted and intersected screening data.
          """
          if drug_directory is None:
              drug_directory = DRUG_DATA_FOLDER

          # Step 1: Load raw data files for specified data modalities
          data = RawDataLoader.load_raw_files(
              intersect=True,
              data_modalities=data_modalities,
              raw_file_directory=raw_file_directory,
          )

          # Step 2: Load drug data files and merge them into the same dictionary
          # (entries with an identical key are overwritten by the drug tables).
          data.update(
              RawDataLoader.load_raw_files(
                  intersect=True,
                  data_modalities=data_modalities,
                  raw_file_directory=drug_directory,
              )
          )

          # Step 3: Load the screening matrix and restrict everything to shared labels
          drug_screen = RawDataLoader.load_screening_files(filename=screen_file_directory, sep=sep)
          drug_screen, data = RawDataLoader.adjust_screening_raw(drug_screen=drug_screen, data_dict=data)

          # Step 4: Return the loaded data and adjusted drug screening data
          return data, drug_screen

      @staticmethod
      def intersect_features(data1, data2):
          """
          Perform intersection of features between two datasets.

          The shared columns are returned in the order in which they appear in
          ``data1`` so the result does not depend on set iteration order.

          Parameters:
          - data1 (pd.DataFrame): First dataset.
          - data2 (pd.DataFrame): Second dataset.

          Returns:
          - data1 (pd.DataFrame): First dataset with common columns.
          - data2 (pd.DataFrame): Second dataset with common columns.
          """
          columns_in_second = set(data2.columns)
          # dict.fromkeys drops repeated column names while keeping first-seen order.
          common_columns = [col for col in dict.fromkeys(data1.columns) if col in columns_in_second]
          return data1[common_columns], data2[common_columns]

      @staticmethod
      def data_features_intersect(data1, data2):
          """
          Intersect features between two datasets column-wise.

          Both dictionaries are updated in place and also returned.

          Parameters:
          - data1 (dict): Dictionary containing data modalities.
          - data2 (dict): Dictionary containing data modalities.

          Returns:
          - data1 (dict): Data1 with intersected features.
          - data2 (dict): Data2 with intersected features.

          Raises:
          - KeyError: If a modality of ``data1`` is missing from ``data2``.
          """
          for modality in data1:
              data1[modality], data2[modality] = RawDataLoader.intersect_features(
                  data1[modality], data2[modality]
              )
          return data1, data2

      @staticmethod
      def load_file(address, index_column=None):
          """
          Load data from a file based on its format.

          Supported extensions: ``.txt``/``.tsv`` (tab separated), ``.csv`` and ``.xlsx``.

          Parameters:
          - address (str): File address.
          - index_column (str): Name of the index column; applied to every
            supported format (``None`` keeps the default integer index).

          Returns:
          - data (pd.DataFrame): Loaded data from the file.

          Raises:
          - FileNotFoundError: If no file exists at ``address`` (a message is printed first).
          - ValueError: If the file extension is not supported.
          """
          try:
              if address.endswith(('.txt', '.tsv')):
                  return pd.read_csv(address, sep='\t', index_col=index_column)
              if address.endswith('.csv'):
                  return pd.read_csv(address, index_col=index_column)
              if address.endswith('.xlsx'):
                  return pd.read_excel(address, index_col=index_column)
          except FileNotFoundError:
              print(f'File not found at address: {address}')
              # Re-raise: there is nothing to return, and falling through used to
              # end in an unrelated IndexError.
              raise
          raise ValueError(f'Unsupported file format: {address}')

      @staticmethod
      def _has_numeric_index(df):
          """Return True when the index of ``df`` has a real numeric (non-boolean) dtype."""
          # Index.is_numeric() is deprecated since pandas 2.0; use the dtype predicates.
          index_dtype = df.index.dtype
          return pd.api.types.is_numeric_dtype(index_dtype) and not pd.api.types.is_bool_dtype(index_dtype)

      @staticmethod
      def _entity_prefix(name):
          """Return 'cell' or 'drug' when ``name`` starts with it, otherwise None."""
          for prefix in ('cell', 'drug'):
              if name.startswith(prefix):
                  return prefix
          return None

      @staticmethod
      def _preprocess_frame(df):
          """Re-index, sort, strip modality suffixes from and min-max scale one raw table."""
          if RawDataLoader._has_numeric_index(df):
              # Presumably the labels live in the first column when the index is
              # just numeric positions; promote that column to the index.
              df = df.set_index(df.columns[0])
          df = df.sort_index().sort_index(axis=1)
          for suffix in ('_cell_mut', '_cell_CN', '_cell_exp'):
              df.columns = df.columns.str.replace(suffix, '', regex=False)
          # Column-wise min-max scaling; constant columns give 0/0 = NaN and are
          # mapped to 0 together with any missing values.
          column_min = df.min()
          return ((df - column_min) / (df.max() - column_min)).fillna(0)

      @staticmethod
      def load_raw_files(raw_file_directory, data_modalities, intersect=True):
          """
          Load, normalise and optionally row-intersect the raw tables of a directory.

          A file is read when its name starts with one of ``data_modalities`` and
          ends with ``_raw.gzip`` (read with ``pd.read_parquet``) or ``_raw.tsv``
          (tab separated, first column as index). Each table is sorted on both
          axes, min-max scaled per column and NaN-filled with 0.

          Parameters:
          - raw_file_directory (str): Directory to scan.
          - data_modalities (iterable of str): Accepted file-name prefixes.
          - intersect (bool): When True, tables whose key starts with 'cell' are
            restricted to the row labels shared by all 'cell' tables (likewise for
            'drug'), and repeated row labels are reduced to their first occurrence.

          Returns:
          - raw_dict (dict): Maps the file name up to ``_raw`` to its DataFrame.
          """
          raw_dict = {}
          prefixes = tuple(data_modalities)
          # None means "no table of this kind seen yet"; an empty set is a valid
          # (empty) intersection and must not be reset by the next file.
          shared_labels = {'cell': None, 'drug': None}

          # Sorted listing keeps the dictionary order independent of the file system.
          for file in tqdm(sorted(os.listdir(raw_file_directory)), 'Reading Raw Data Files...'):
              if not file.startswith(prefixes):
                  continue
              path = os.path.join(raw_file_directory, file)
              if file.endswith('_raw.gzip'):
                  df = pd.read_parquet(path)
              elif file.endswith('_raw.tsv'):
                  df = pd.read_csv(path, sep='\t', index_col=0)
              else:
                  continue

              df = RawDataLoader._preprocess_frame(df)

              kind = RawDataLoader._entity_prefix(file)
              if intersect and kind is not None:
                  labels = set(df.index)
                  seen = shared_labels[kind]
                  shared_labels[kind] = labels if seen is None else seen & labels

              raw_dict[file[:file.find('_raw')]] = df

          if intersect:
              for key, frame in raw_dict.items():
                  kind = RawDataLoader._entity_prefix(key)
                  if kind is None:
                      continue
                  # Boolean mask keeps the (sorted) row order of the table, so the
                  # result is reproducible and identical across modalities.
                  kept = frame.loc[frame.index.isin(list(shared_labels[kind]))]
                  raw_dict[key] = kept.loc[~kept.index.duplicated()]
          return raw_dict

      @staticmethod
      def load_screening_files(filename="AUC_matS_comb.tsv", sep=','):
          """
          Read the screening matrix from a delimited text file.

          Parameters:
          - filename (str): Path of the screening file.
          - sep (str): Column separator of the file.

          Returns:
          - df (pd.DataFrame): File contents with the first column used as index.
          """
          return pd.read_csv(filename, sep=sep, index_col=0)

      @staticmethod
      def adjust_screening_raw(drug_screen, data_dict):
          """
          Restrict the screening matrix and the raw tables to their shared labels.

          Row labels shared by every table whose key contains 'cell' are matched
          against ``drug_screen.index``; row labels shared by every table whose
          key contains 'drug' are matched against ``drug_screen.columns``. The
          surviving labels keep the order they have in ``drug_screen``.

          Parameters:
          - drug_screen (pd.DataFrame): Screening matrix (rows: cells, columns: drugs).
          - data_dict (dict): Raw tables; it is updated in place.

          Returns:
          - drug_screen (pd.DataFrame): Screening matrix limited to the common labels.
          - data_dict (dict): The same dictionary with every table re-indexed to
            the common labels.
          """
          raw_cell_names = None
          raw_drug_names = None
          for key, table in data_dict.items():
              labels = set(table.index)
              if 'cell' in key:
                  raw_cell_names = labels if raw_cell_names is None else raw_cell_names & labels
              elif 'drug' in key:
                  # Intersect (instead of overwrite) so that every drug table is
                  # guaranteed to contain the labels selected below.
                  raw_drug_names = labels if raw_drug_names is None else raw_drug_names & labels
          if raw_cell_names is None:
              raw_cell_names = set()
          if raw_drug_names is None:
              raw_drug_names = set()

          common_cell_names = [name for name in dict.fromkeys(drug_screen.index) if name in raw_cell_names]
          common_drug_names = [name for name in dict.fromkeys(drug_screen.columns) if name in raw_drug_names]

          for key, table in data_dict.items():
              if 'cell' in key:
                  data_dict[key] = table.loc[common_cell_names]
              else:
                  # As before, every non-cell table is indexed by the drug labels.
                  data_dict[key] = table.loc[common_drug_names]
          return drug_screen.loc[common_cell_names, common_drug_names], data_dict

      @staticmethod
      def _sorted_positions(mask):
          """Return the (row, col) positions where ``mask`` is True, ordered by column then row."""
          positions = np.argwhere(mask)  # shape (k, 2) for a 2-D mask, also when k == 0
          # lexsort uses the last key as primary key: column first, then row.
          return positions[np.lexsort((positions[:, 0], positions[:, 1]))]

      @staticmethod
      def _stack_modalities(data_dict, prefix):
          """Concatenate all tables whose key starts with ``prefix`` side by side as float32."""
          data_types = sorted(key for key in data_dict if key.startswith(prefix))
          frames = [
              pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32)
              for data_type in data_types
          ]
          sizes = [data_dict[data_type].shape[1] for data_type in data_types]
          return pd.concat(frames, axis=1), sizes

      @staticmethod
      def _rows_for_pairs(frame, row_positions, pair_labels):
          """Pick rows of ``frame`` by position and label them with ``pair_labels``."""
          rows = frame.iloc[row_positions, :].reset_index(drop=True)
          rows.index = pair_labels
          return rows

      @staticmethod
      def prepare_input_data(data_dict, screening):
          """
          Build paired cell/drug feature matrices and labels from a screening matrix.

          Entries of ``screening`` equal to 1 ("resistance") receive label 0 and
          entries equal to -1 ("sensitive") receive label 1; all other entries are
          ignored. Rows of the feature tables are picked by position, so the cell
          tables are assumed to follow the row order of ``screening`` and the drug
          tables its column order (as produced by ``adjust_screening_raw``).

          Parameters:
          - data_dict (dict): Tables whose keys start with 'cell' or 'drug'.
          - screening (pd.DataFrame): Matrix of 1 / -1 (and other) entries.

          Returns:
          - X_cell (pd.DataFrame): Cell features per pair, indexed '(cell,drug)'.
          - X_drug (pd.DataFrame): Drug features per pair, indexed '(cell,drug)'.
          - Y (np.ndarray): 0 for resistance pairs followed by 1 for sensitive pairs.
          - cell_data_sizes (list): Number of columns of each cell table (sorted by key).
          - drug_data_sizes (list): Number of columns of each drug table (sorted by key).
          """
          print('Preparing data...')
          values = screening.to_numpy()
          resistance = RawDataLoader._sorted_positions(values == 1)
          sensitive = RawDataLoader._sorted_positions(values == -1)
          print("sensitive train data len:", len(sensitive))
          print("resistance train data len:", len(resistance))

          cell_data, cell_data_sizes = RawDataLoader._stack_modalities(data_dict, 'cell')
          drug_data, drug_data_sizes = RawDataLoader._stack_modalities(data_dict, 'drug')

          cell_parts = []
          drug_parts = []
          # Resistance pairs first, then sensitive pairs, matching the layout of Y.
          for positions in (resistance, sensitive):
              pair_labels = [
                  f'({screening.index[row]},{screening.columns[col]})' for row, col in positions
              ]
              cell_parts.append(RawDataLoader._rows_for_pairs(cell_data, positions[:, 0], pair_labels))
              drug_parts.append(RawDataLoader._rows_for_pairs(drug_data, positions[:, 1], pair_labels))

          X_cell = pd.concat(cell_parts)
          X_drug = pd.concat(drug_parts)
          Y = np.append(np.zeros(resistance.shape[0]), np.ones(sensitive.shape[0]))
          return X_cell, X_drug, Y, cell_data_sizes, drug_data_sizes