You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import pandas as pd
  2. import numpy as np
  3. from tqdm import tqdm
  4. from sklearn.impute import SimpleImputer
  5. import os
  6. from utils import DRUG_DATA_FOLDER
  7. class RawDataLoader:
  8. @staticmethod
  9. def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER):
  10. """
  11. Load raw data and screening data, perform intersection, and adjust screening data.
  12. Parameters:
  13. - data_modalities (list): List of data modalities to load.
  14. - raw_file_directory (str): Directory containing raw data files.
  15. - screen_file_directory (str): Directory containing screening data files.
  16. - sep (str): Separator used in the data files.
  17. Returns:
  18. - data (dict): Dictionary containing loaded raw data.
  19. - drug_screen (pd.DataFrame): Adjusted and intersected screening data.
  20. """
  21. # Step 1: Load raw data files for specified data modalities
  22. data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
  23. raw_file_directory=raw_file_directory)
  24. # Step 2: Load drug data files for specified data modalities
  25. drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
  26. raw_file_directory=drug_directory)
  27. # Step 3: Update the 'data' dictionary with drug data
  28. data.update(drug_data)
  29. # Step 4: Load and adjust drug screening data
  30. drug_screen = RawDataLoader.load_screening_files(
  31. filename=screen_file_directory,
  32. sep=sep)
  33. drug_screen, data = RawDataLoader.adjust_screening_raw(
  34. drug_screen=drug_screen, data_dict=data)
  35. # Step 5: Return the loaded data and adjusted drug screening data
  36. return data, drug_screen
  37. @staticmethod
  38. def intersect_features(data1, data2):
  39. """
  40. Perform intersection of features between two datasets.
  41. Parameters:
  42. - data1 (pd.DataFrame): First dataset.
  43. - data2 (pd.DataFrame): Second dataset.
  44. Returns:
  45. - data1 (pd.DataFrame): First dataset with common columns.
  46. - data2 (pd.DataFrame): Second dataset with common columns.
  47. """
  48. # Step 1: Find common columns between the two datasets
  49. common_columns = list(set(data1.columns) & set(data2.columns))
  50. # Step 2: Filter data2 to include only common columns
  51. data2 = data2[common_columns]
  52. # Step 3: Filter data1 to include only common columns
  53. data1 = data1[common_columns]
  54. # Step 4: Return the datasets with intersected features
  55. return data1, data2
  56. @staticmethod
  57. def data_features_intersect(data1, data2):
  58. """
  59. Intersect features between two datasets column-wise.
  60. Parameters:
  61. - data1 (dict): Dictionary containing data modalities.
  62. - data2 (dict): Dictionary containing data modalities.
  63. Returns:
  64. - intersected_data1 (dict): Data1 with intersected features.
  65. - intersected_data2 (dict): Data2 with intersected features.
  66. """
  67. # Iterate over each data modality
  68. for i in data1:
  69. # Intersect features for each modality
  70. data1[i], data2[i] = RawDataLoader.intersect_features(data1[i], data2[i])
  71. return data1, data2
  72. @staticmethod
  73. def load_file(address, index_column=None):
  74. """
  75. Load data from a file based on its format.
  76. Parameters:
  77. - address (str): File address.
  78. - index_column (str): Name of the index column.
  79. Returns:
  80. - data (pd.DataFrame): Loaded data from the file.
  81. """
  82. data = []
  83. try:
  84. # Load data based on file format
  85. if address.endswith('.txt') or address.endswith('.tsv'):
  86. data.append(pd.read_csv(address, sep='\t', index_col=index_column), )
  87. elif address.endswith('.csv'):
  88. data.append(pd.read_csv(address))
  89. elif address.endswith('.xlsx'):
  90. data.append(pd.read_excel(address))
  91. except FileNotFoundError:
  92. print(f'File not found at address: {address}')
  93. return data[0]
  94. @staticmethod
  95. def load_raw_files(raw_file_directory, data_modalities, intersect=True):
  96. raw_dict = {}
  97. files = os.listdir(raw_file_directory)
  98. cell_line_names = None
  99. drug_names = None
  100. for file in tqdm(files, 'Reading Raw Data Files...'):
  101. if any([file.startswith(x) for x in data_modalities]):
  102. if file.endswith('_raw.gzip'):
  103. df = pd.read_parquet(os.path.join(raw_file_directory, file))
  104. elif file.endswith('_raw.tsv'):
  105. df = pd.read_csv(os.path.join(raw_file_directory, file), sep='\t', index_col=0)
  106. else:
  107. continue
  108. if df.index.is_numeric():
  109. df = df.set_index(df.columns[0])
  110. df = df.sort_index()
  111. df = df.sort_index(axis=1)
  112. df.columns = df.columns.str.replace('_cell_mut', '')
  113. df.columns = df.columns.str.replace('_cell_CN', '')
  114. df.columns = df.columns.str.replace('_cell_exp', '')
  115. df = (df - df.min()) / (df.max() - df.min())
  116. df = df.fillna(0)
  117. print("has null:")
  118. print(df.isnull().sum().sum())
  119. if intersect:
  120. if file.startswith('cell'):
  121. if cell_line_names:
  122. cell_line_names = cell_line_names.intersection(set(df.index))
  123. else:
  124. cell_line_names = set(df.index)
  125. elif file.startswith('drug'):
  126. if drug_names:
  127. drug_names = drug_names.intersection(set(df.index))
  128. else:
  129. drug_names = set(df.index)
  130. raw_dict[file[:file.find('_raw')]] = df
  131. if intersect:
  132. for key, value in raw_dict.items():
  133. if key.startswith('cell'):
  134. data = value.loc[list(cell_line_names)]
  135. raw_dict[key] = data.loc[~data.index.duplicated()]
  136. elif key.startswith('drug'):
  137. data = value.loc[list(drug_names)]
  138. raw_dict[key] = data.loc[~data.index.duplicated()]
  139. return raw_dict
  140. @staticmethod
  141. def load_screening_files(filename="AUC_matS_comb.tsv", sep=',', ):
  142. df = pd.read_csv(filename, sep=sep, index_col=0)
  143. return df
  144. @staticmethod
  145. def adjust_screening_raw(drug_screen, data_dict):
  146. raw_cell_names = []
  147. raw_drug_names = []
  148. for key, value in data_dict.items():
  149. if 'cell' in key:
  150. if len(raw_cell_names) == 0:
  151. raw_cell_names = value.index
  152. else:
  153. raw_cell_names = raw_cell_names.intersection(value.index)
  154. elif 'drug' in key:
  155. raw_drug_names = value.index
  156. screening_cell_names = drug_screen.index
  157. screening_drug_names = drug_screen.columns
  158. common_cell_names = list(set(raw_cell_names).intersection(set(screening_cell_names)))
  159. common_drug_names = list(set(raw_drug_names).intersection(set(screening_drug_names)))
  160. for key, value in data_dict.items():
  161. if 'cell' in key:
  162. data_dict[key] = value.loc[common_cell_names]
  163. else:
  164. data_dict[key] = value.loc[common_drug_names]
  165. return drug_screen.loc[common_cell_names, common_drug_names], data_dict
  166. @staticmethod
  167. def prepare_input_data(data_dict, screening):
  168. print('Preparing data...')
  169. resistance = np.argwhere((screening.to_numpy() == 1)).tolist()
  170. resistance.sort(key=lambda x: (x[1], x[0]))
  171. resistance = np.array(resistance)
  172. sensitive = np.argwhere((screening.to_numpy() == -1)).tolist()
  173. sensitive.sort(key=lambda x: (x[1], x[0]))
  174. sensitive = np.array(sensitive)
  175. print("sensitive train data len:", len(sensitive))
  176. print("resistance train data len:", len(resistance))
  177. A_train_mask = np.ones(len(resistance), dtype=bool)
  178. B_train_mask = np.ones(len(sensitive), dtype=bool)
  179. resistance = resistance[A_train_mask]
  180. sensitive = sensitive[B_train_mask]
  181. cell_data_types = list(filter(lambda x: x.startswith('cell'), data_dict.keys()))
  182. cell_data_types.sort()
  183. cell_data = pd.concat(
  184. [pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32) for
  185. data_type in cell_data_types], axis=1)
  186. cell_data_sizes = [data_dict[data_type].shape[1] for data_type in cell_data_types]
  187. drug_data_types = list(filter(lambda x: x.startswith('drug'), data_dict.keys()))
  188. drug_data_types.sort()
  189. drug_data = pd.concat(
  190. [pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32, )
  191. for data_type in drug_data_types], axis=1)
  192. drug_data_sizes = [data_dict[data_type].shape[1] for data_type in drug_data_types]
  193. Xp_cell = cell_data.iloc[resistance[:, 0], :]
  194. Xp_drug = drug_data.iloc[resistance[:, 1], :]
  195. Xp_cell = Xp_cell.reset_index(drop=True)
  196. Xp_drug = Xp_drug.reset_index(drop=True)
  197. Xp_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance]
  198. Xp_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance]
  199. Xn_cell = cell_data.iloc[sensitive[:, 0], :]
  200. Xn_drug = drug_data.iloc[sensitive[:, 1], :]
  201. Xn_cell = Xn_cell.reset_index(drop=True)
  202. Xn_drug = Xn_drug.reset_index(drop=True)
  203. Xn_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive]
  204. Xn_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive]
  205. X_cell = pd.concat([Xp_cell, Xn_cell])
  206. X_drug = pd.concat([Xp_drug, Xn_drug])
  207. Y = np.append(np.zeros(resistance.shape[0]), np.ones(sensitive.shape[0]))
  208. return X_cell, X_drug, Y, cell_data_sizes, drug_data_sizes