123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
class PPI2Cell(nn.Module):
    """Score every PPI node for a cell line from a learned cell embedding.

    Each cell index is looked up in a trainable embedding table whose width
    equals the PPI embedding width; the score of a node is the dot product of
    the cell embedding with that node's (fixed) PPI embedding, plus an
    optional per-node bias.

    Args:
        n_cell: number of rows in the cell embedding table.
        ppi_emb: 2-D tensor indexed as ``(n_node, emb_dim)``; it is stored
            transposed and is not registered as a trainable parameter.
        bias: when true, a learnable ``(1, n_node)`` bias is added to the
            scores; otherwise ``self.bias`` is the integer ``0``.
    """

    def __init__(self, n_cell: int, ppi_emb: torch.Tensor, bias=True):
        super().__init__()
        self.n_cell = n_cell
        n_node, emb_dim = ppi_emb.shape[0], ppi_emb.shape[1]
        # max_norm=1.0 makes the lookup renormalise rows whose L2 norm exceeds 1.
        self.cell_emb = nn.Embedding(n_cell, emb_dim, max_norm=1.0, norm_type=2.0)
        if bias:
            self.bias = nn.Parameter(torch.randn((1, n_node)), requires_grad=True)
        else:
            self.bias = 0
        # Register the transposed PPI matrix as a buffer so that
        # ``model.to(...)``, ``.cuda()`` and ``.double()`` move/cast it together
        # with the parameters.  ``persistent=False`` keeps it out of the
        # state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer("ppi_emb", ppi_emb.permute(1, 0), persistent=False)

    def forward(self, x: torch.Tensor):
        """Return a ``(batch, n_node)`` score matrix.

        Args:
            x: integer cell indices; ``mm`` below needs a 2-D embedding
                result, so ``x`` is expected to be ``(batch, 1)``.
        """
        x = x.squeeze(dim=1)
        emb = self.cell_emb(x)
        y = emb.mm(self.ppi_emb)
        y += self.bias
        return y
-
-
class PPI2CellV2(nn.Module):
    """Score PPI nodes for a cell line through a learned projection.

    The fixed PPI embeddings are passed through ``Linear -> LeakyReLU`` to a
    ``hidden_dim``-wide space on every forward call, and each cell embedding
    (also ``hidden_dim`` wide) is dotted with every projected node.

    Args:
        n_cell: number of rows in the cell embedding table.
        ppi_emb: 2-D tensor indexed as ``(n_node, emb_dim)``.
        hidden_dim: width of the shared projection / cell embedding space.
        bias: whether the projecting ``Linear`` layer has a bias term.
    """

    def __init__(self, n_cell: int, ppi_emb: torch.Tensor, hidden_dim: int, bias=True):
        super().__init__()
        self.n_cell = n_cell
        self.projector = nn.Sequential(
            nn.Linear(ppi_emb.shape[1], hidden_dim, bias=bias),
            nn.LeakyReLU()
        )
        # max_norm=1.0 makes the lookup renormalise rows whose L2 norm exceeds 1.
        self.cell_emb = nn.Embedding(n_cell, hidden_dim, max_norm=1.0, norm_type=2.0)
        if isinstance(ppi_emb, nn.Parameter):
            # Plain assignment already registers a Parameter with the module;
            # keep that behaviour for callers that pass one.
            self.ppi_emb = ppi_emb
        else:
            # A bare tensor attribute is ignored by ``model.to(...)`` /
            # ``.cuda()`` / ``.double()``; a buffer follows the module.
            # ``persistent=False`` keeps the state_dict keys unchanged.
            self.register_buffer("ppi_emb", ppi_emb, persistent=False)

    def forward(self, x: torch.Tensor):
        """Return a ``(batch, n_node)`` score matrix.

        Args:
            x: integer cell indices; ``mm`` below needs a 2-D embedding
                result, so ``x`` is expected to be ``(batch, 1)``.
        """
        x = x.squeeze(dim=1)
        # (n_node, hidden_dim) -> (hidden_dim, n_node) for the matrix product.
        proj = self.projector(self.ppi_emb).permute(1, 0)
        emb = self.cell_emb(x)
        y = emb.mm(proj)
        return y
-
-
class SynEmb(nn.Module):
    """Predict from learned drug and cell embeddings.

    Both drugs share one embedding table; the two drug vectors and the cell
    vector are handed to a ``DNN`` head.
    """

    def __init__(self, n_drug: int, drug_dim: int, n_cell: int, cell_dim: int, hidden_dim: int):
        super().__init__()
        # max_norm=1 makes the lookup renormalise rows whose norm exceeds 1.
        self.drug_emb = nn.Embedding(n_drug, drug_dim, max_norm=1)
        self.cell_emb = nn.Embedding(n_cell, cell_dim, max_norm=1)
        joint_dim = 2 * drug_dim + cell_dim
        self.network = DNN(joint_dim, hidden_dim)

    def forward(self, drug1, drug2, cell):
        """Look up the three index tensors and run the prediction head.

        Each looked-up embedding has its dimension 1 squeezed away (when it
        has size 1) before being passed on.
        """
        first_drug = self.drug_emb(drug1)
        first_drug = first_drug.squeeze(1)
        second_drug = self.drug_emb(drug2)
        second_drug = second_drug.squeeze(1)
        cell_vec = self.cell_emb(cell)
        cell_vec = cell_vec.squeeze(1)
        return self.network(first_drug, second_drug, cell_vec)
-
-
class AutoEncoder(nn.Module):
    """Symmetric MLP auto-encoder with ReLU between the linear layers.

    The encoder narrows ``input_size -> input_size // 2 -> input_size // 4
    -> latent_size``; the decoder mirrors those widths back to
    ``input_size``.  Neither half has an activation after its last layer.
    """

    def __init__(self, input_size: int, latent_size: int):
        super().__init__()
        half, quarter = input_size // 2, input_size // 4
        self.encoder = self._relu_stack([input_size, half, quarter, latent_size])
        self.decoder = self._relu_stack([latent_size, quarter, half, input_size])

    @staticmethod
    def _relu_stack(widths):
        """Chain ``Linear`` layers over consecutive widths, ReLU in between."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the final linear layer
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return ``(latent code, reconstruction)`` for ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)
-
-
class GeneExpressionAE(nn.Module):
    """Tanh auto-encoder with fixed 2048- and 1024-wide hidden layers.

    Every linear layer, including the last one of both the encoder and the
    decoder, is followed by ``Tanh``, so the code and the reconstruction both
    lie in (-1, 1).
    """

    def __init__(self, input_size: int, latent_size: int):
        super().__init__()
        self.encoder = self._tanh_stack([input_size, 2048, 1024, latent_size])
        self.decoder = self._tanh_stack([latent_size, 1024, 2048, input_size])

    @staticmethod
    def _tanh_stack(widths):
        """Chain ``Linear -> Tanh`` pairs over consecutive widths."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.Tanh()]
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return ``(latent code, reconstruction)`` for ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)
-
-
class DrugFeatAE(nn.Module):
    """Two-layer auto-encoder with a 128-wide hidden layer on each side.

    Both the encoder and the decoder are ``Linear -> ReLU -> Linear ->
    Sigmoid``, so the code and the reconstruction both lie in (0, 1).
    """

    def __init__(self, input_size: int, latent_size: int):
        super().__init__()
        hidden = 128
        encoder_layers = [
            nn.Linear(input_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, latent_size),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        decoder_layers = [
            nn.Linear(latent_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, input_size),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """Return ``(latent code, reconstruction)`` for ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)
-
-
class DSDNN(nn.Module):
    """MLP mapping one feature vector per sample to a single output value.

    Layout: ``Linear -> ReLU -> BatchNorm1d`` twice (widths ``hidden_size``
    then ``hidden_size // 2``), followed by a final ``Linear`` to width 1.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        narrow = hidden_size // 2
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.ReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, feat: torch.Tensor):
        """Run ``feat`` through the network and return the result."""
        return self.network(feat)
-
-
class DNN(nn.Module):
    """MLP over the concatenation of two drug vectors and one cell vector.

    Layout: ``Linear -> ReLU -> BatchNorm1d`` twice (widths ``hidden_size``
    then ``hidden_size // 2``), followed by a final ``Linear`` to width 1.
    ``input_size`` must equal the summed widths of the three inputs.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        narrow = hidden_size // 2
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.ReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Concatenate the three inputs along dim 1 and run the network."""
        joined = torch.cat((drug1_feat, drug2_feat, cell_feat), dim=1)
        return self.network(joined)
-
-
class BottleneckLayer(nn.Module):
    """Kernel-size-1 ``Conv1d`` followed by ``BatchNorm1d`` and ``LeakyReLU``.

    With a kernel of size 1 the convolution mixes channels independently at
    every position, changing the channel count from ``in_channels`` to
    ``out_channels`` while leaving the length unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        pointwise = nn.Conv1d(in_channels, out_channels, 1)
        norm = nn.BatchNorm1d(out_channels)
        activation = nn.LeakyReLU()
        self.net = nn.Sequential(pointwise, norm, activation)

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        out = self.net(x)
        return out
-
-
class PatchySan(nn.Module):
    """Collapse a stacked cell feature map with 1x1 convs, then run an MLP.

    The cell input is permuted so that its last axis becomes the channel
    axis (``field_size`` channels), squeezed to one channel by four
    ``BottleneckLayer`` stages (``field_size -> 16 -> 32 -> 16 -> 1``), and
    concatenated with the two drug vectors for the final LeakyReLU MLP.
    """

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int, field_size: int):
        super().__init__()
        channel_plan = [field_size, 16, 32, 16, 1]
        self.conv = nn.Sequential(
            *[BottleneckLayer(c_in, c_out) for c_in, c_out in zip(channel_plan, channel_plan[1:])]
        )
        narrow = hidden_size // 2
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.LeakyReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Return the network output for the given drug pair and cell.

        ``cell_feat`` must be 3-D; its axes 1 and 2 are swapped so that the
        last axis feeds the convolution as channels.
        """
        channels_first = cell_feat.permute(0, 2, 1)
        # The conv stack ends with one channel; drop that axis.
        cell_vec = self.conv(channels_first).squeeze(1)
        joined = torch.cat((drug1_feat, drug2_feat, cell_vec), dim=1)
        return self.network(joined)
-
-
class SynSyn(nn.Module):
    """Combine two drugs by an element-wise product of projected features.

    Both drugs go through the same square linear projection; their product
    is concatenated with a linearly projected cell vector and fed to a
    ReLU / BatchNorm MLP ending in a single output unit.
    """

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int):
        super().__init__()
        self.drug_proj = nn.Linear(drug_size, drug_size)
        self.cell_proj = nn.Linear(cell_size, cell_size)
        narrow = hidden_size // 2
        head = [
            nn.Linear(drug_size + cell_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.ReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        ]
        self.network = nn.Sequential(*head)

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Return the network output for the given drug pair and cell."""
        # The product is symmetric, so swapping the two drugs gives the
        # same pair representation.
        pair = self.drug_proj(drug1_feat) * self.drug_proj(drug2_feat)
        context = self.cell_proj(cell_feat)
        joined = torch.cat((pair, context), dim=1)
        return self.network(joined)
-
-
class PPIDNN(nn.Module):
    """Collapse per-position cell embeddings with 1x1 convs, then run an MLP.

    The cell input is permuted so that its last axis becomes the channel
    axis (``emb_size`` channels), reduced to one channel by four
    ``BottleneckLayer`` stages (``emb_size -> 64 -> 128 -> 64 -> 1``), and
    concatenated with the two drug vectors for the final LeakyReLU MLP.
    """

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int, emb_size: int):
        super().__init__()
        channel_plan = [emb_size, 64, 128, 64, 1]
        self.conv = nn.Sequential(
            *[BottleneckLayer(c_in, c_out) for c_in, c_out in zip(channel_plan, channel_plan[1:])]
        )
        narrow = hidden_size // 2
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.LeakyReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Return the network output for the given drug pair and cell.

        ``cell_feat`` must be 3-D; its axes 1 and 2 are swapped so that the
        last axis feeds the convolution as channels.
        """
        channels_first = cell_feat.permute(0, 2, 1)
        # The conv stack ends with one channel; drop that axis.
        cell_vec = self.conv(channels_first).squeeze(1)
        joined = torch.cat((drug1_feat, drug2_feat, cell_vec), dim=1)
        return self.network(joined)
-
-
class StackLinearDNN(nn.Module):
    """Blend a stack of cell feature vectors with learned weights, then MLP.

    A learnable ``(1, stack_size)`` row vector is matrix-multiplied with the
    cell input, producing one weighted combination of the stacked rows per
    sample.  That vector is concatenated with the two drug vectors and fed
    to a ReLU / BatchNorm MLP ending in a single output unit.
    ``input_size`` is the width of that concatenation.
    """

    def __init__(self, input_size: int, stack_size: int, hidden_size: int):
        super().__init__()

        mixing = torch.zeros(size=(1, stack_size))
        nn.init.xavier_uniform_(mixing, gain=1.414)
        self.compress = nn.Parameter(mixing)

        narrow = hidden_size // 2
        head = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.ReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        ]
        self.network = nn.Sequential(*head)

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Return the network output for the given drug pair and cell stack.

        ``cell_feat`` is presumably ``(batch, stack_size, feat)``; the
        product with ``compress`` leaves a size-1 axis at dim 1, which is
        squeezed away.
        """
        blended = torch.matmul(self.compress, cell_feat)
        blended = blended.squeeze(1)
        joined = torch.cat((drug1_feat, drug2_feat, blended), dim=1)
        return self.network(joined)
-
-
class InteractionNet(nn.Module):
    """Model each drug-cell interaction separately and sum the two.

    A shared ``Linear -> LeakyReLU -> BatchNorm1d`` block is applied to
    ``[drug1, cell]`` and to ``[drug2, cell]``; the two results are added
    and passed to a small LeakyReLU / BatchNorm MLP ending in one output
    unit.
    """

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int):
        super().__init__()

        interaction_layers = [
            nn.Linear(drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
        ]
        self.inter_net = nn.Sequential(*interaction_layers)

        narrow = hidden_size // 2
        head = [
            nn.Linear(hidden_size, narrow),
            nn.LeakyReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        ]
        self.network = nn.Sequential(*head)

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Return the network output for the given drug pair and cell."""
        # Drug 1 is processed before drug 2; the order only matters for the
        # BatchNorm running statistics updated in training mode.
        first = self.inter_net(torch.cat((drug1_feat, cell_feat), dim=1))
        second = self.inter_net(torch.cat((drug2_feat, cell_feat), dim=1))
        combined = first + second
        return self.network(combined)
-
-
class StackProjDNN(nn.Module):
    """Project each stacked cell vector with its own matrix, sum, then MLP.

    ``projectors`` holds ``stack_size`` learnable ``cell_size x cell_size``
    matrices.  The projected vectors are summed over dim 1 and the result is
    concatenated with the two drug vectors for a ReLU / BatchNorm MLP ending
    in a single output unit.
    """

    def __init__(self, drug_size: int, cell_size: int, stack_size: int, hidden_size: int):
        super().__init__()

        matrices = torch.zeros(size=(stack_size, cell_size, cell_size))
        nn.init.xavier_uniform_(matrices, gain=1.414)
        self.projectors = nn.Parameter(matrices)

        narrow = hidden_size // 2
        head = [
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, narrow),
            nn.ReLU(),
            nn.BatchNorm1d(narrow),
            nn.Linear(narrow, 1),
        ]
        self.network = nn.Sequential(*head)

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        """Return the network output for the given drug pair and cell stack.

        ``cell_feat`` is presumably ``(batch, stack_size, cell_size)`` --
        confirm against the caller.
        """
        # Add a trailing axis so every vector becomes a column for matmul.
        columns = cell_feat.unsqueeze(-1)
        projected = torch.matmul(self.projectors, columns)
        projected = projected.squeeze(-1)
        pooled = projected.sum(1)
        joined = torch.cat((drug1_feat, drug2_feat, pooled), dim=1)
        return self.network(joined)
|