You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

data_loader.py 10KB

11 months ago
11 months ago
11 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import pandas as pd
  2. import numpy as np
  3. from tqdm import tqdm
  4. from sklearn.impute import SimpleImputer
  5. import os
  6. from utils import DRUG_DATA_FOLDER
  7. class RawDataLoader:
  8. @staticmethod
  9. def load_data(data_modalities, raw_file_directory, screen_file_directory, sep, drug_directory=DRUG_DATA_FOLDER):
  10. """
  11. Load raw data and screening data, perform intersection, and adjust screening data.
  12. Parameters:
  13. - data_modalities (list): List of data modalities to load.
  14. - raw_file_directory (str): Directory containing raw data files.
  15. - screen_file_directory (str): Directory containing screening data files.
  16. - sep (str): Separator used in the data files.
  17. Returns:
  18. - data (dict): Dictionary containing loaded raw data.
  19. - drug_screen (pd.DataFrame): Adjusted and intersected screening data.
  20. """
  21. # Step 1: Load raw data files for specified data modalities
  22. data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
  23. raw_file_directory=raw_file_directory)
  24. # Step 2: Load drug data files for specified data modalities
  25. drug_data = RawDataLoader.load_raw_files(intersect=True, data_modalities=data_modalities,
  26. raw_file_directory=drug_directory)
  27. # Step 3: Update the 'data' dictionary with drug data
  28. data.update(drug_data)
  29. # Step 4: Load and adjust drug screening data
  30. drug_screen = RawDataLoader.load_screening_files(
  31. filename=screen_file_directory,
  32. sep=sep)
  33. drug_screen, data = RawDataLoader.adjust_screening_raw(
  34. drug_screen=drug_screen, data_dict=data)
  35. # Step 5: Return the loaded data and adjusted drug screening data
  36. return data, drug_screen
  37. @staticmethod
  38. def intersect_features(data1, data2):
  39. """
  40. Perform intersection of features between two datasets.
  41. Parameters:
  42. - data1 (pd.DataFrame): First dataset.
  43. - data2 (pd.DataFrame): Second dataset.
  44. Returns:
  45. - data1 (pd.DataFrame): First dataset with common columns.
  46. - data2 (pd.DataFrame): Second dataset with common columns.
  47. """
  48. # Step 1: Find common columns between the two datasets
  49. common_columns = list(set(data1.columns) & set(data2.columns))
  50. # Step 2: Filter data2 to include only common columns
  51. data2 = data2[common_columns]
  52. # Step 3: Filter data1 to include only common columns
  53. data1 = data1[common_columns]
  54. # Step 4: Return the datasets with intersected features
  55. return data1, data2
  56. @staticmethod
  57. def data_features_intersect(data1, data2):
  58. """
  59. Intersect features between two datasets column-wise.
  60. Parameters:
  61. - data1 (dict): Dictionary containing data modalities.
  62. - data2 (dict): Dictionary containing data modalities.
  63. Returns:
  64. - intersected_data1 (dict): Data1 with intersected features.
  65. - intersected_data2 (dict): Data2 with intersected features.
  66. """
  67. # Iterate over each data modality
  68. for i in data1:
  69. # Intersect features for each modality
  70. data1[i], data2[i] = RawDataLoader.intersect_features(data1[i], data2[i])
  71. return data1, data2
  72. @staticmethod
  73. def load_file(address, index_column=None):
  74. """
  75. Load data from a file based on its format.
  76. Parameters:
  77. - address (str): File address.
  78. - index_column (str): Name of the index column.
  79. Returns:
  80. - data (pd.DataFrame): Loaded data from the file.
  81. """
  82. data = []
  83. try:
  84. # Load data based on file format
  85. if address.endswith('.txt') or address.endswith('.tsv'):
  86. data.append(pd.read_csv(address, sep='\t', index_col=index_column), )
  87. elif address.endswith('.csv'):
  88. data.append(pd.read_csv(address))
  89. elif address.endswith('.xlsx'):
  90. data.append(pd.read_excel(address))
  91. except FileNotFoundError:
  92. print(f'File not found at address: {address}')
  93. return data[0]
  94. @staticmethod
  95. def load_raw_files(raw_file_directory, data_modalities, intersect=True):
  96. raw_dict = {}
  97. files = os.listdir(raw_file_directory)
  98. cell_line_names = None
  99. drug_names = None
  100. for file in tqdm(files, 'Reading Raw Data Files...'):
  101. if any([file.startswith(x) for x in data_modalities]):
  102. if file.endswith('_raw.gzip'):
  103. df = pd.read_parquet(os.path.join(raw_file_directory, file))
  104. elif file.endswith('_raw.tsv'):
  105. df = pd.read_csv(os.path.join(raw_file_directory, file), sep='\t', index_col=0)
  106. else:
  107. continue
  108. if df.index.is_numeric():
  109. df = df.set_index(df.columns[0])
  110. df = df.sort_index()
  111. df = df.sort_index(axis=1)
  112. df.columns = df.columns.str.replace('_cell_mut', '')
  113. df.columns = df.columns.str.replace('_cell_CN', '')
  114. df.columns = df.columns.str.replace('_cell_exp', '')
  115. df = (df - df.min()) / (df.max() - df.min())
  116. df = df.fillna(0)
  117. if intersect:
  118. if file.startswith('cell'):
  119. if cell_line_names:
  120. cell_line_names = cell_line_names.intersection(set(df.index))
  121. else:
  122. cell_line_names = set(df.index)
  123. elif file.startswith('drug'):
  124. if drug_names:
  125. drug_names = drug_names.intersection(set(df.index))
  126. else:
  127. drug_names = set(df.index)
  128. raw_dict[file[:file.find('_raw')]] = df
  129. if intersect:
  130. for key, value in raw_dict.items():
  131. if key.startswith('cell'):
  132. data = value.loc[list(cell_line_names)]
  133. raw_dict[key] = data.loc[~data.index.duplicated()]
  134. elif key.startswith('drug'):
  135. data = value.loc[list(drug_names)]
  136. raw_dict[key] = data.loc[~data.index.duplicated()]
  137. return raw_dict
  138. @staticmethod
  139. def load_screening_files(filename="AUC_matS_comb.tsv", sep=',', ):
  140. df = pd.read_csv(filename, sep=sep, index_col=0)
  141. return df
  142. @staticmethod
  143. def adjust_screening_raw(drug_screen, data_dict):
  144. raw_cell_names = []
  145. raw_drug_names = []
  146. for key, value in data_dict.items():
  147. if 'cell' in key:
  148. if len(raw_cell_names) == 0:
  149. raw_cell_names = value.index
  150. else:
  151. raw_cell_names = raw_cell_names.intersection(value.index)
  152. elif 'drug' in key:
  153. raw_drug_names = value.index
  154. screening_cell_names = drug_screen.index
  155. screening_drug_names = drug_screen.columns
  156. common_cell_names = list(set(raw_cell_names).intersection(set(screening_cell_names)))
  157. common_drug_names = list(set(raw_drug_names).intersection(set(screening_drug_names)))
  158. for key, value in data_dict.items():
  159. if 'cell' in key:
  160. data_dict[key] = value.loc[common_cell_names]
  161. else:
  162. data_dict[key] = value.loc[common_drug_names]
  163. return drug_screen.loc[common_cell_names, common_drug_names], data_dict
  164. @staticmethod
  165. def prepare_input_data(data_dict, screening):
  166. print('Preparing data...')
  167. resistance = np.argwhere((screening.to_numpy() == 1)).tolist()
  168. resistance.sort(key=lambda x: (x[1], x[0]))
  169. resistance = np.array(resistance)
  170. sensitive = np.argwhere((screening.to_numpy() == -1)).tolist()
  171. sensitive.sort(key=lambda x: (x[1], x[0]))
  172. sensitive = np.array(sensitive)
  173. print("sensitive train data len:", len(sensitive))
  174. print("resistance train data len:", len(resistance))
  175. A_train_mask = np.ones(len(resistance), dtype=bool)
  176. B_train_mask = np.ones(len(sensitive), dtype=bool)
  177. resistance = resistance[A_train_mask]
  178. sensitive = sensitive[B_train_mask]
  179. cell_data_types = list(filter(lambda x: x.startswith('cell'), data_dict.keys()))
  180. cell_data_types.sort()
  181. cell_data = pd.concat(
  182. [pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32) for
  183. data_type in cell_data_types], axis=1)
  184. cell_data_sizes = [data_dict[data_type].shape[1] for data_type in cell_data_types]
  185. drug_data_types = list(filter(lambda x: x.startswith('drug'), data_dict.keys()))
  186. drug_data_types.sort()
  187. drug_data = pd.concat(
  188. [pd.DataFrame(data_dict[data_type].add_suffix(f'_{data_type}'), dtype=np.float32, )
  189. for data_type in drug_data_types], axis=1)
  190. drug_data_sizes = [data_dict[data_type].shape[1] for data_type in drug_data_types]
  191. Xp_cell = cell_data.iloc[resistance[:, 0], :]
  192. Xp_drug = drug_data.iloc[resistance[:, 1], :]
  193. Xp_cell = Xp_cell.reset_index(drop=True)
  194. Xp_drug = Xp_drug.reset_index(drop=True)
  195. Xp_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance]
  196. Xp_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in resistance]
  197. Xn_cell = cell_data.iloc[sensitive[:, 0], :]
  198. Xn_drug = drug_data.iloc[sensitive[:, 1], :]
  199. Xn_cell = Xn_cell.reset_index(drop=True)
  200. Xn_drug = Xn_drug.reset_index(drop=True)
  201. Xn_cell.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive]
  202. Xn_drug.index = [f'({screening.index[x[0]]},{screening.columns[x[1]]})' for x in sensitive]
  203. X_cell = pd.concat([Xp_cell, Xn_cell])
  204. X_drug = pd.concat([Xp_drug, Xn_drug])
  205. Y = np.append(np.zeros(resistance.shape[0]), np.ones(sensitive.shape[0]))
  206. return X_cell, X_drug, Y, cell_data_sizes, drug_data_sizes