
models.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class PPI2Cell(nn.Module):
    # Maps a cell-line index to per-protein scores against a fixed PPI embedding.
    def __init__(self, n_cell: int, ppi_emb: torch.Tensor, bias=True):
        super(PPI2Cell, self).__init__()
        self.n_cell = n_cell
        self.cell_emb = nn.Embedding(n_cell, ppi_emb.shape[1], max_norm=1.0, norm_type=2.0)
        if bias:
            self.bias = nn.Parameter(torch.randn((1, ppi_emb.shape[0])), requires_grad=True)
        else:
            # Plain zero so that `y += self.bias` below is a no-op when bias is disabled.
            self.bias = 0
        # Stored as a plain tensor (not a parameter or buffer), so it stays frozen
        # and must be moved to the right device by the caller.
        self.ppi_emb = ppi_emb.permute(1, 0)

    def forward(self, x: torch.Tensor):
        x = x.squeeze(dim=1)
        emb = self.cell_emb(x)
        y = emb.mm(self.ppi_emb)
        y += self.bias
        return y
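
# Hedged usage sketch (not part of the original file; all sizes are made up for
# illustration). With ppi_emb of shape (n_protein, emb_dim) and integer cell
# indices of shape (batch, 1), the output has shape (batch, n_protein):
#
#   ppi_emb = torch.randn(4000, 64)                  # assumed dimensions
#   model = PPI2Cell(n_cell=30, ppi_emb=ppi_emb)
#   cells = torch.randint(0, 30, (8, 1))
#   out = model(cells)                               # -> (8, 4000)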

class PPI2CellV2(nn.Module):
    # Variant of PPI2Cell that first projects the PPI embedding into a shared
    # hidden space before matching it against the cell-line embedding.
    def __init__(self, n_cell: int, ppi_emb: torch.Tensor, hidden_dim: int, bias=True):
        super(PPI2CellV2, self).__init__()
        self.n_cell = n_cell
        self.projector = nn.Sequential(
            nn.Linear(ppi_emb.shape[1], hidden_dim, bias=bias),
            nn.LeakyReLU()
        )
        self.cell_emb = nn.Embedding(n_cell, hidden_dim, max_norm=1.0, norm_type=2.0)
        self.ppi_emb = ppi_emb

    def forward(self, x: torch.Tensor):
        x = x.squeeze(dim=1)
        proj = self.projector(self.ppi_emb).permute(1, 0)
        emb = self.cell_emb(x)
        y = emb.mm(proj)
        return y

class SynEmb(nn.Module):
    # Pure-embedding synergy model: looks up drug and cell embeddings and feeds
    # their concatenation to the DNN head defined later in this module.
    def __init__(self, n_drug: int, drug_dim: int, n_cell: int, cell_dim: int, hidden_dim: int):
        super(SynEmb, self).__init__()
        self.drug_emb = nn.Embedding(n_drug, drug_dim, max_norm=1)
        self.cell_emb = nn.Embedding(n_cell, cell_dim, max_norm=1)
        self.network = DNN(2 * drug_dim + cell_dim, hidden_dim)

    def forward(self, drug1, drug2, cell):
        d1 = self.drug_emb(drug1).squeeze(1)
        d2 = self.drug_emb(drug2).squeeze(1)
        c = self.cell_emb(cell).squeeze(1)
        return self.network(d1, d2, c)
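
# Hedged usage sketch (illustrative sizes only): drug and cell indices of shape
# (batch, 1) produce one synergy score per pair. Note that DNN uses BatchNorm1d,
# so training-mode batches must contain more than one sample:
#
#   model = SynEmb(n_drug=50, drug_dim=32, n_cell=20, cell_dim=16, hidden_dim=128)
#   d1 = torch.randint(0, 50, (8, 1))
#   d2 = torch.randint(0, 50, (8, 1))
#   c = torch.randint(0, 20, (8, 1))
#   score = model(d1, d2, c)                         # -> (8, 1)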

class AutoEncoder(nn.Module):
    def __init__(self, input_size: int, latent_size: int):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size // 4),
            nn.ReLU(),
            nn.Linear(input_size // 4, latent_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, input_size // 4),
            nn.ReLU(),
            nn.Linear(input_size // 4, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size)
        )

    def forward(self, x: torch.Tensor):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
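
# Hedged training sketch (assumptions: plain MSE reconstruction loss and Adam
# with a made-up learning rate; neither is specified in the original file):
#
#   ae = AutoEncoder(input_size=1024, latent_size=64)
#   opt = torch.optim.Adam(ae.parameters(), lr=1e-3)
#   x = torch.randn(32, 1024)
#   encoded, decoded = ae(x)
#   loss = F.mse_loss(decoded, x)
#   loss.backward()
#   opt.step()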

class GeneExpressionAE(nn.Module):
    # Autoencoder for gene-expression profiles. The final Tanh bounds the
    # reconstruction to (-1, 1), so inputs are presumably scaled accordingly.
    def __init__(self, input_size: int, latent_size: int):
        super(GeneExpressionAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 2048),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.Tanh(),
            nn.Linear(1024, latent_size),
            nn.Tanh()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 1024),
            nn.Tanh(),
            nn.Linear(1024, 2048),
            nn.Tanh(),
            nn.Linear(2048, input_size),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

class DrugFeatAE(nn.Module):
    # Autoencoder for drug feature vectors. The final Sigmoid bounds the
    # reconstruction to (0, 1), which suits binary fingerprint-style inputs.
    def __init__(self, input_size: int, latent_size: int):
        super(DrugFeatAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, latent_size),
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            nn.Linear(128, input_size),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

class DSDNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(DSDNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, feat: torch.Tensor):
        out = self.network(feat)
        return out

class DNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(DNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
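
# Hedged usage sketch: DNN is the generic late-fusion head used elsewhere in
# this module (e.g. by SynEmb). Feature sizes below are illustrative:
#
#   net = DNN(input_size=2 * 256 + 512, hidden_size=1024)
#   d1, d2 = torch.randn(8, 256), torch.randn(8, 256)
#   cell = torch.randn(8, 512)
#   score = net(d1, d2, cell)                        # -> (8, 1)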

class BottleneckLayer(nn.Module):
    # 1x1 convolution over the channel dimension, i.e. a per-position linear
    # mixing of channels, followed by batch norm and LeakyReLU.
    def __init__(self, in_channels: int, out_channels: int):
        super(BottleneckLayer, self).__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.net(x)

class PatchySan(nn.Module):
    def __init__(self, drug_size: int, cell_size: int, hidden_size: int, field_size: int):
        super(PatchySan, self).__init__()
        # self.drug_proj = nn.Linear(drug_size, hidden_size, bias=False)
        # self.cell_proj = nn.Linear(cell_size, hidden_size, bias=False)
        self.conv = nn.Sequential(
            BottleneckLayer(field_size, 16),
            BottleneckLayer(16, 32),
            BottleneckLayer(32, 16),
            BottleneckLayer(16, 1),
        )
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        cell_feat = cell_feat.permute(0, 2, 1)
        cell_feat = self.conv(cell_feat).squeeze(1)
        # drug1_feat = self.drug_proj(drug1_feat)
        # drug2_feat = self.drug_proj(drug2_feat)
        # express = self.cell_proj(cell_feat)
        # feat = torch.cat([drug1_feat, drug2_feat, express], 1)
        # drug_feat = (drug1_feat + drug2_feat) / 2
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
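
# Hedged shape walkthrough (sizes assumed for illustration): cell_feat is
# expected as (batch, n_gene, field_size); after the permute it becomes
# (batch, field_size, n_gene), the bottleneck stack maps the channel dimension
# field_size -> 16 -> 32 -> 16 -> 1, and squeeze(1) leaves (batch, n_gene),
# so cell_size passed to __init__ must equal n_gene:
#
#   model = PatchySan(drug_size=256, cell_size=700, hidden_size=1024, field_size=8)
#   d1, d2 = torch.randn(8, 256), torch.randn(8, 256)
#   cell = torch.randn(8, 700, 8)
#   score = model(d1, d2, cell)                      # -> (8, 1)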

class SynSyn(nn.Module):
    # Combines the two drugs by an element-wise product of a shared projection,
    # which makes the model symmetric in the drug pair by construction.
    def __init__(self, drug_size: int, cell_size: int, hidden_size: int):
        super(SynSyn, self).__init__()
        self.drug_proj = nn.Linear(drug_size, drug_size)
        self.cell_proj = nn.Linear(cell_size, cell_size)
        self.network = nn.Sequential(
            nn.Linear(drug_size + cell_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        d1 = self.drug_proj(drug1_feat)
        d2 = self.drug_proj(drug2_feat)
        d = d1.mul(d2)
        c = self.cell_proj(cell_feat)
        feat = torch.cat([d, c], 1)
        out = self.network(feat)
        return out

class PPIDNN(nn.Module):
    def __init__(self, drug_size: int, cell_size: int, hidden_size: int, emb_size: int):
        super(PPIDNN, self).__init__()
        self.conv = nn.Sequential(
            BottleneckLayer(emb_size, 64),
            BottleneckLayer(64, 128),
            BottleneckLayer(128, 64),
            BottleneckLayer(64, 1),
        )
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        cell_feat = cell_feat.permute(0, 2, 1)
        cell_feat = self.conv(cell_feat).squeeze(1)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out

class StackLinearDNN(nn.Module):
    def __init__(self, input_size: int, stack_size: int, hidden_size: int):
        super(StackLinearDNN, self).__init__()
        self.compress = nn.Parameter(torch.zeros(size=(1, stack_size)))
        nn.init.xavier_uniform_(self.compress.data, gain=1.414)
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        cell_feat = torch.matmul(self.compress, cell_feat).squeeze(1)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
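
# Hedged usage sketch: self.compress is a learned (1, stack_size) row vector,
# so the matmul collapses a stacked cell representation of shape
# (batch, stack_size, feat) into a weighted sum over the stack. Sizes below
# are illustrative:
#
#   model = StackLinearDNN(input_size=2 * 256 + 512, stack_size=5, hidden_size=1024)
#   d1, d2 = torch.randn(8, 256), torch.randn(8, 256)
#   cell = torch.randn(8, 5, 512)
#   score = model(d1, d2, cell)                      # -> (8, 1)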

class InteractionNet(nn.Module):
    def __init__(self, drug_size: int, cell_size: int, hidden_size: int):
        super(InteractionNet, self).__init__()
        # self.compress = nn.Parameter(torch.ones(size=(1, stack_size)))
        # self.drug_proj = nn.Sequential(
        #     nn.Linear(drug_size, hidden_size),
        #     nn.LeakyReLU(),
        #     nn.BatchNorm1d(hidden_size)
        # )
        self.inter_net = nn.Sequential(
            nn.Linear(drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size)
        )
        self.network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        # cell_feat = torch.matmul(self.compress, cell_feat).squeeze(1)
        # d1 = self.drug_proj(drug1_feat)
        # d2 = self.drug_proj(drug2_feat)
        dc1 = torch.cat([drug1_feat, cell_feat], 1)
        dc2 = torch.cat([drug2_feat, cell_feat], 1)
        inter1 = self.inter_net(dc1)
        inter2 = self.inter_net(dc2)
        inter3 = inter1 + inter2
        out = self.network(inter3)
        return out
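
# Hedged note (illustrative sizes): because the same inter_net scores both
# (drug, cell) pairs and the two hidden vectors are summed, the prediction is
# invariant to the order of the two drugs:
#
#   model = InteractionNet(drug_size=256, cell_size=512, hidden_size=1024)
#   model.eval()  # fix batch-norm statistics so the check is deterministic
#   d1, d2 = torch.randn(8, 256), torch.randn(8, 256)
#   cell = torch.randn(8, 512)
#   assert torch.allclose(model(d1, d2, cell), model(d2, d1, cell))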

class StackProjDNN(nn.Module):
    # Learns one (cell_size x cell_size) projection matrix per stack slice and
    # sums the projected slices into a single cell representation.
    def __init__(self, drug_size: int, cell_size: int, stack_size: int, hidden_size: int):
        super(StackProjDNN, self).__init__()
        self.projectors = nn.Parameter(torch.zeros(size=(stack_size, cell_size, cell_size)))
        nn.init.xavier_uniform_(self.projectors.data, gain=1.414)
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        cell_feat = cell_feat.unsqueeze(-1)
        cell_feats = torch.matmul(self.projectors, cell_feat).squeeze(-1)
        cell_feat = torch.sum(cell_feats, 1)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
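
# Hedged smoke test (not in the original file; all dimensions are assumed).
# Running `python models.py` shape-checks a few of the heads defined above:
if __name__ == '__main__':
    d1, d2 = torch.randn(8, 256), torch.randn(8, 256)
    cell = torch.randn(8, 512)
    for head in (DNN(2 * 256 + 512, 1024), SynSyn(256, 512, 1024), InteractionNet(256, 512, 1024)):
        assert head(d1, d2, cell).shape == (8, 1)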