DeepTraCDR: Predicting Cancer Drug Response using multimodal deep learning with Transformers
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_sampler.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. import torch
  2. import numpy as np
  3. import scipy.sparse as sp
  4. from typing import Tuple, Optional
  5. from utils import to_coo_matrix, to_tensor, mask
  6. class RandomSampler:
  7. """
  8. Samples edges from an adjacency matrix to create train/test sets.
  9. Converts the training set into torch.Tensor format.
  10. """
  11. def __init__(
  12. self,
  13. adj_mat_original: np.ndarray,
  14. train_index: np.ndarray,
  15. test_index: np.ndarray,
  16. null_mask: np.ndarray
  17. ) -> None:
  18. self.adj_mat = to_coo_matrix(adj_mat_original)
  19. self.train_index = train_index
  20. self.test_index = test_index
  21. self.null_mask = null_mask
  22. # Sample positive edges
  23. self.train_pos = self._sample_edges(train_index)
  24. self.test_pos = self._sample_edges(test_index)
  25. # Sample negative edges
  26. self.train_neg, self.test_neg = self._sample_negative_edges()
  27. # Create masks
  28. self.train_mask = mask(self.train_pos, self.train_neg, dtype=int)
  29. self.test_mask = mask(self.test_pos, self.test_neg, dtype=bool)
  30. # Convert to tensors
  31. self.train_data = to_tensor(self.train_pos)
  32. self.test_data = to_tensor(self.test_pos)
  33. def _sample_edges(self, index: np.ndarray) -> sp.coo_matrix:
  34. """Samples edges from the adjacency matrix based on provided indices."""
  35. row = self.adj_mat.row[index]
  36. col = self.adj_mat.col[index]
  37. data = self.adj_mat.data[index]
  38. return sp.coo_matrix(
  39. (data, (row, col)),
  40. shape=self.adj_mat.shape
  41. )
  42. def _sample_negative_edges(self) -> Tuple[sp.coo_matrix, sp.coo_matrix]:
  43. """
  44. Samples negative edges for training and testing.
  45. Negative edges are those not present in the adjacency matrix.
  46. """
  47. pos_adj_mat = self.null_mask + self.adj_mat.toarray()
  48. neg_adj_mat = sp.coo_matrix(np.abs(pos_adj_mat - 1))
  49. all_row, all_col, all_data = neg_adj_mat.row, neg_adj_mat.col, neg_adj_mat.data
  50. indices = np.arange(all_data.shape[0])
  51. # Sample negative test edges
  52. test_n = self.test_index.shape[0]
  53. test_neg_indices = np.random.choice(indices, test_n, replace=False)
  54. test_row, test_col, test_data = (
  55. all_row[test_neg_indices],
  56. all_col[test_neg_indices],
  57. all_data[test_neg_indices]
  58. )
  59. test_neg = sp.coo_matrix(
  60. (test_data, (test_row, test_col)),
  61. shape=self.adj_mat.shape
  62. )
  63. # Sample negative train edges
  64. train_neg_indices = np.delete(indices, test_neg_indices)
  65. train_row, train_col, train_data = (
  66. all_row[train_neg_indices],
  67. all_col[train_neg_indices],
  68. all_data[train_neg_indices]
  69. )
  70. train_neg = sp.coo_matrix(
  71. (train_data, (train_row, train_col)),
  72. shape=self.adj_mat.shape
  73. )
  74. return train_neg, test_neg
  75. class NewSampler:
  76. """
  77. Samples train/test data and masks for a specific target dimension/index.
  78. """
  79. def __init__(
  80. self,
  81. original_adj_mat: np.ndarray,
  82. null_mask: np.ndarray,
  83. target_dim: Optional[int],
  84. target_index: int
  85. ) -> None:
  86. self.adj_mat = original_adj_mat
  87. self.null_mask = null_mask
  88. self.dim = target_dim
  89. self.target_index = target_index
  90. self.train_data, self.test_data = self._sample_train_test_data()
  91. self.train_mask, self.test_mask = self._sample_train_test_mask()
  92. def _sample_target_test_index(self) -> np.ndarray:
  93. """Samples indices for positive test edges based on target dimension."""
  94. if self.dim:
  95. return np.where(self.adj_mat[:, self.target_index] == 1)[0]
  96. return np.where(self.adj_mat[self.target_index, :] == 1)[0]
  97. def _sample_train_test_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
  98. """Samples train and test data based on target indices."""
  99. test_data = np.zeros(self.adj_mat.shape, dtype=np.float32)
  100. test_index = self._sample_target_test_index()
  101. if self.dim:
  102. test_data[test_index, self.target_index] = 1
  103. else:
  104. test_data[self.target_index, test_index] = 1
  105. train_data = self.adj_mat - test_data
  106. return torch.from_numpy(train_data), torch.from_numpy(test_data)
  107. def _sample_train_test_mask(self) -> Tuple[torch.Tensor, torch.Tensor]:
  108. """Creates train and test masks, including negative sampling."""
  109. test_index = self._sample_target_test_index()
  110. neg_value = np.ones(self.adj_mat.shape, dtype=np.float32) - self.adj_mat - self.null_mask
  111. neg_test_mask = np.zeros(self.adj_mat.shape, dtype=np.float32)
  112. if self.dim:
  113. target_neg_index = np.where(neg_value[:, self.target_index] == 1)[0]
  114. else:
  115. target_neg_index = np.where(neg_value[self.target_index, :] == 1)[0]
  116. target_neg_test_index = (
  117. np.random.choice(target_neg_index, len(test_index), replace=False)
  118. if len(test_index) < len(target_neg_index)
  119. else target_neg_index
  120. )
  121. if self.dim:
  122. neg_test_mask[target_neg_test_index, self.target_index] = 1
  123. neg_value[:, self.target_index] = 0
  124. else:
  125. neg_test_mask[self.target_index, target_neg_test_index] = 1
  126. neg_value[self.target_index, :] = 0
  127. train_mask = (self.train_data.numpy() + neg_value).astype(bool)
  128. test_mask = (self.test_data.numpy() + neg_test_mask).astype(bool)
  129. return torch.from_numpy(train_mask), torch.from_numpy(test_mask)
  130. class SingleSampler:
  131. """
  132. Samples train/test data and masks for a specific target index.
  133. Returns results as torch.Tensor.
  134. """
  135. def __init__(
  136. self,
  137. origin_adj_mat: np.ndarray,
  138. null_mask: np.ndarray,
  139. target_index: int,
  140. train_index: np.ndarray,
  141. test_index: np.ndarray
  142. ) -> None:
  143. self.adj_mat = origin_adj_mat
  144. self.null_mask = null_mask
  145. self.target_index = target_index
  146. self.train_index = train_index
  147. self.test_index = test_index
  148. self.train_data, self.test_data = self._sample_train_test_data()
  149. self.train_mask, self.test_mask = self._sample_train_test_mask()
  150. def _sample_train_test_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
  151. """Samples train and test data for the target index."""
  152. test_data = np.zeros(self.adj_mat.shape, dtype=np.float32)
  153. test_data[self.test_index, self.target_index] = 1
  154. train_data = self.adj_mat - test_data
  155. return torch.from_numpy(train_data), torch.from_numpy(test_data)
  156. def _sample_train_test_mask(self) -> Tuple[torch.Tensor, torch.Tensor]:
  157. """Creates train and test masks with negative sampling."""
  158. neg_value = np.ones(self.adj_mat.shape, dtype=np.float32) - self.adj_mat - self.null_mask
  159. neg_test_mask = np.zeros(self.adj_mat.shape, dtype=np.float32)
  160. target_neg_index = np.where(neg_value[:, self.target_index] == 1)[0]
  161. target_neg_test_index = np.random.choice(target_neg_index, len(self.test_index), replace=False)
  162. neg_test_mask[target_neg_test_index, self.target_index] = 1
  163. neg_value[target_neg_test_index, self.target_index] = 0
  164. train_mask = (self.train_data.numpy() + neg_value).astype(bool)
  165. test_mask = (self.test_data.numpy() + neg_test_mask).astype(bool)
  166. return torch.from_numpy(train_mask), torch.from_numpy(test_mask)
  167. class TargetSampler(object):
  168. """
  169. Samples train/test data and masks for multiple target indices.
  170. """
  171. def __init__(self, response_mat: np.ndarray, null_mask: np.ndarray, target_indexes: np.ndarray,
  172. pos_train_index: np.ndarray, pos_test_index: np.ndarray):
  173. self.response_mat = response_mat
  174. self.null_mask = null_mask
  175. self.target_indexes = target_indexes
  176. self.pos_train_index = pos_train_index
  177. self.pos_test_index = pos_test_index
  178. self.train_data, self.test_data = self.sample_train_test_data()
  179. self.train_mask, self.test_mask = self.sample_train_test_mask()
  180. def sample_train_test_data(self):
  181. n_target = self.target_indexes.shape[0]
  182. target_response = self.response_mat[:, self.target_indexes].reshape((-1, n_target))
  183. train_data = self.response_mat.copy()
  184. train_data[:, self.target_indexes] = 0
  185. target_pos_value = sp.coo_matrix(target_response)
  186. target_train_data = sp.coo_matrix((target_pos_value.data[self.pos_train_index],
  187. (target_pos_value.row[self.pos_train_index],
  188. target_pos_value.col[self.pos_train_index])),
  189. shape=target_response.shape).toarray()
  190. target_test_data = sp.coo_matrix((target_pos_value.data[self.pos_test_index],
  191. (target_pos_value.row[self.pos_test_index],
  192. target_pos_value.col[self.pos_test_index])),
  193. shape=target_response.shape).toarray()
  194. test_data = np.zeros(self.response_mat.shape, dtype=np.float32)
  195. for i, value in enumerate(self.target_indexes):
  196. train_data[:, value] = target_train_data[:, i]
  197. test_data[:, value] = target_test_data[:, i]
  198. train_data = torch.from_numpy(train_data)
  199. test_data = torch.from_numpy(test_data)
  200. return train_data, test_data
  201. def sample_train_test_mask(self):
  202. target_response = self.response_mat[:, self.target_indexes]
  203. target_ones = np.ones(target_response.shape, dtype=np.float32)
  204. target_neg_value = target_ones - target_response - self.null_mask[:, self.target_indexes]
  205. target_neg_value = sp.coo_matrix(target_neg_value)
  206. ids = np.arange(target_neg_value.data.shape[0])
  207. target_neg_test_index = np.random.choice(ids, self.pos_test_index.shape[0], replace=False)
  208. target_neg_test_mask = sp.coo_matrix((target_neg_value.data[target_neg_test_index],
  209. (target_neg_value.row[target_neg_test_index],
  210. target_neg_value.col[target_neg_test_index])),
  211. shape=target_response.shape).toarray()
  212. neg_test_mask = np.zeros(self.response_mat.shape, dtype=np.float32)
  213. for i, value in enumerate(self.target_indexes):
  214. neg_test_mask[:, value] = target_neg_test_mask[:, i]
  215. other_neg_value = np.ones(self.response_mat.shape,
  216. dtype=np.float32) - neg_test_mask - self.response_mat - self.null_mask
  217. test_mask = (self.test_data.numpy() + neg_test_mask).astype(bool)
  218. train_mask = (self.train_data.numpy() + other_neg_value).astype(bool)
  219. test_mask = torch.from_numpy(test_mask)
  220. train_mask = torch.from_numpy(train_mask)
  221. return train_mask, test_mask
  222. class ExterSampler:
  223. """
  224. Samples train/test data and masks based on row indices.
  225. """
  226. def __init__(
  227. self,
  228. original_adj_mat: np.ndarray,
  229. null_mask: np.ndarray,
  230. train_index: np.ndarray,
  231. test_index: np.ndarray
  232. ) -> None:
  233. self.adj_mat = original_adj_mat
  234. self.null_mask = null_mask
  235. self.train_index = train_index
  236. self.test_index = test_index
  237. self.train_data, self.test_data = self._sample_train_test_data()
  238. self.train_mask, self.test_mask = self._sample_train_test_mask()
  239. def _sample_train_test_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
  240. """Samples train and test data based on row indices."""
  241. test_data = self.adj_mat.copy()
  242. test_data[self.train_index, :] = 0
  243. train_data = self.adj_mat - test_data
  244. return torch.from_numpy(train_data), torch.from_numpy(test_data)
  245. def _sample_train_test_mask(self) -> Tuple[torch.Tensor, torch.Tensor]:
  246. """Creates train and test masks with negative sampling."""
  247. neg_value = np.ones(self.adj_mat.shape, dtype=np.float32) - self.adj_mat - self.null_mask
  248. neg_train = neg_value.copy()
  249. neg_train[self.test_index, :] = 0
  250. neg_test = neg_value.copy()
  251. neg_test[self.train_index, :] = 0
  252. train_mask = (self.train_data.numpy() + neg_train).astype(bool)
  253. test_mask = (self.test_data.numpy() + neg_test).astype(bool)
  254. return torch.from_numpy(train_mask), torch.from_numpy(test_mask)
  255. class RegressionSampler(object):
  256. def __init__(self, adj_mat_original, train_index, test_index, null_mask):
  257. super(RegressionSampler, self).__init__()
  258. if isinstance(adj_mat_original, torch.Tensor):
  259. adj_mat_np = adj_mat_original.cpu().numpy()
  260. else:
  261. adj_mat_np = adj_mat_original.copy()
  262. self.full_data = torch.FloatTensor(adj_mat_np)
  263. rows, cols = adj_mat_np.shape
  264. train_mask = np.zeros((rows, cols), dtype=bool)
  265. test_mask = np.zeros((rows, cols), dtype=bool)
  266. for idx in train_index:
  267. row = idx // cols
  268. col = idx % cols
  269. if not null_mask[row, col]:
  270. train_mask[row, col] = True
  271. for idx in test_index:
  272. row = idx // cols
  273. col = idx % cols
  274. if not null_mask[row, col]:
  275. test_mask[row, col] = True
  276. self.train_mask = torch.BoolTensor(train_mask)
  277. self.test_mask = torch.BoolTensor(test_mask)
  278. self.train_data = self.full_data.clone()
  279. self.test_data = self.full_data.clone()
  280. assert not torch.any(self.train_mask & self.test_mask), "Train and test masks have overlap!"
  281. def get_train_indices(self):
  282. indices = torch.nonzero(self.train_mask)
  283. return indices
  284. def get_test_indices(self):
  285. indices = torch.nonzero(self.test_mask)
  286. return indices