import torch
import torch.nn as nn
import torch.nn.functional as F


class PPI2Cell(nn.Module):

    def __init__(self, n_cell: int, ppi_emb: torch.Tensor, bias=True):
        super(PPI2Cell, self).__init__()
        self.n_cell = n_cell
        # Cell-line embeddings, kept inside the L2 unit ball via max_norm.
        self.cell_emb = nn.Embedding(n_cell, ppi_emb.shape[1], max_norm=1.0, norm_type=2.0)
        if bias:
            # One learnable bias per protein.
            self.bias = nn.Parameter(torch.randn((1, ppi_emb.shape[0])), requires_grad=True)
        else:
            self.bias = 0
        # (n_protein, emb_dim) -> (emb_dim, n_protein) so forward is a single mm.
        self.ppi_emb = ppi_emb.permute(1, 0)

    def forward(self, x: torch.Tensor):
        x = x.squeeze(dim=1)        # (batch, 1) index tensor -> (batch,)
        emb = self.cell_emb(x)      # (batch, emb_dim)
        y = emb.mm(self.ppi_emb)    # (batch, n_protein)
        y += self.bias
        return y
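

# Hedged usage sketch (illustrative addition, not part of the original file):
# all sizes below (500 proteins, 64-dim embeddings, 10 cell lines) are
# arbitrary placeholders.
def _ppi2cell_example():
    ppi_emb = torch.randn(500, 64)                 # (n_protein, emb_dim)
    model = PPI2Cell(n_cell=10, ppi_emb=ppi_emb)
    cell_idx = torch.randint(0, 10, (4, 1))        # batch of 4 cell-line indices
    return model(cell_idx)                         # (4, 500): one score per protein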


class PPI2CellV2(nn.Module):

    def __init__(self, n_cell: int, ppi_emb: torch.Tensor, hidden_dim: int, bias=True):
        super(PPI2CellV2, self).__init__()
        self.n_cell = n_cell
        # Projects protein embeddings into the hidden space shared with cell embeddings.
        self.projector = nn.Sequential(
            nn.Linear(ppi_emb.shape[1], hidden_dim, bias=bias),
            nn.LeakyReLU()
        )
        self.cell_emb = nn.Embedding(n_cell, hidden_dim, max_norm=1.0, norm_type=2.0)
        self.ppi_emb = ppi_emb

    def forward(self, x: torch.Tensor):
        x = x.squeeze(dim=1)
        proj = self.projector(self.ppi_emb).permute(1, 0)  # (hidden_dim, n_protein)
        emb = self.cell_emb(x)                              # (batch, hidden_dim)
        y = emb.mm(proj)                                    # (batch, n_protein)
        return y


class SynEmb(nn.Module):

    def __init__(self, n_drug: int, drug_dim: int, n_cell: int, cell_dim: int, hidden_dim: int):
        super(SynEmb, self).__init__()
        self.drug_emb = nn.Embedding(n_drug, drug_dim, max_norm=1)
        self.cell_emb = nn.Embedding(n_cell, cell_dim, max_norm=1)
        # DNN (defined below) consumes the concatenated drug and cell embeddings.
        self.network = DNN(2 * drug_dim + cell_dim, hidden_dim)

    def forward(self, drug1, drug2, cell):
        d1 = self.drug_emb(drug1).squeeze(1)
        d2 = self.drug_emb(drug2).squeeze(1)
        c = self.cell_emb(cell).squeeze(1)
        return self.network(d1, d2, c)
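

# Hedged usage sketch (illustrative addition): SynEmb looks up drug and cell
# embeddings by index and scores them with the DNN head defined further below.
# All sizes are arbitrary placeholders.
def _synemb_example():
    model = SynEmb(n_drug=5, drug_dim=8, n_cell=3, cell_dim=16, hidden_dim=32)
    drug1 = torch.randint(0, 5, (4, 1))
    drug2 = torch.randint(0, 5, (4, 1))
    cell = torch.randint(0, 3, (4, 1))
    return model(drug1, drug2, cell)               # (4, 1) synergy scores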


class AutoEncoder(nn.Module):

    def __init__(self, input_size: int, latent_size: int):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size // 4),
            nn.ReLU(),
            nn.Linear(input_size // 4, latent_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, input_size // 4),
            nn.ReLU(),
            nn.Linear(input_size // 4, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size)
        )

    def forward(self, x: torch.Tensor):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
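

# Hedged usage sketch (illustrative addition): the encoder halves the width
# twice before the latent layer, so input_size should be well above
# 4 * latent_size. A reconstruction loss is the usual training signal.
def _autoencoder_example():
    model = AutoEncoder(input_size=256, latent_size=32)
    x = torch.randn(4, 256)
    encoded, decoded = model(x)                    # (4, 32), (4, 256)
    return F.mse_loss(decoded, x)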


class GeneExpressionAE(nn.Module):

    def __init__(self, input_size: int, latent_size: int):
        super(GeneExpressionAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 2048),
            nn.Tanh(),
            nn.Linear(2048, 1024),
            nn.Tanh(),
            nn.Linear(1024, latent_size),
            nn.Tanh()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 1024),
            nn.Tanh(),
            nn.Linear(1024, 2048),
            nn.Tanh(),
            nn.Linear(2048, input_size),
            nn.Tanh()
        )

    def forward(self, x: torch.Tensor):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


class DrugFeatAE(nn.Module):

    def __init__(self, input_size: int, latent_size: int):
        super(DrugFeatAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, latent_size),
            nn.Sigmoid()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            nn.Linear(128, input_size),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


class DSDNN(nn.Module):

    def __init__(self, input_size: int, hidden_size: int):
        super(DSDNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, feat: torch.Tensor):
        out = self.network(feat)
        return out


class DNN(nn.Module):

    def __init__(self, input_size: int, hidden_size: int):
        super(DNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
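

# Hedged usage sketch (illustrative addition): DNN concatenates two drug
# vectors and one cell vector. Batch size must exceed 1 in training mode
# because of BatchNorm1d.
def _dnn_example():
    model = DNN(input_size=2 * 8 + 16, hidden_size=32)
    d1, d2 = torch.randn(4, 8), torch.randn(4, 8)
    cell = torch.randn(4, 16)
    return model(d1, d2, cell)                     # (4, 1)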


class BottleneckLayer(nn.Module):

    def __init__(self, in_channels: int, out_channels: int):
        super(BottleneckLayer, self).__init__()
        # 1x1 convolution: remaps channels without changing the length dimension.
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.net(x)
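

# Hedged shape sketch (illustrative addition): sizes are placeholders.
def _bottleneck_example():
    layer = BottleneckLayer(in_channels=16, out_channels=4)
    x = torch.randn(4, 16, 100)                    # (batch, channels, length)
    return layer(x)                                # (4, 4, 100)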


class PatchySan(nn.Module):

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int, field_size: int):
        super(PatchySan, self).__init__()
        # self.drug_proj = nn.Linear(drug_size, hidden_size, bias=False)
        # self.cell_proj = nn.Linear(cell_size, hidden_size, bias=False)
        self.conv = nn.Sequential(
            BottleneckLayer(field_size, 16),
            BottleneckLayer(16, 32),
            BottleneckLayer(32, 16),
            BottleneckLayer(16, 1)
        )
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        # (batch, cell_size, field_size) -> (batch, field_size, cell_size)
        cell_feat = cell_feat.permute(0, 2, 1)
        cell_feat = self.conv(cell_feat).squeeze(1)    # (batch, cell_size)
        # drug1_feat = self.drug_proj(drug1_feat)
        # drug2_feat = self.drug_proj(drug2_feat)
        # express = self.cell_proj(cell_feat)
        # feat = torch.cat([drug1_feat, drug2_feat, express], 1)
        # drug_feat = (drug1_feat + drug2_feat) / 2
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
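

# Hedged usage sketch (illustrative addition): PatchySan expects cell features
# as (batch, cell_size, field_size); forward permutes them so the bottleneck
# stack convolves over field_size channels. Sizes are placeholders.
def _patchysan_example():
    model = PatchySan(drug_size=8, cell_size=100, hidden_size=32, field_size=10)
    d1, d2 = torch.randn(4, 8), torch.randn(4, 8)
    cell = torch.randn(4, 100, 10)
    return model(d1, d2, cell)                     # (4, 1)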


class SynSyn(nn.Module):

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int):
        super(SynSyn, self).__init__()
        self.drug_proj = nn.Linear(drug_size, drug_size)
        self.cell_proj = nn.Linear(cell_size, cell_size)
        self.network = nn.Sequential(
            nn.Linear(drug_size + cell_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        d1 = self.drug_proj(drug1_feat)
        d2 = self.drug_proj(drug2_feat)
        d = d1.mul(d2)    # element-wise fusion of the two projected drugs
        c = self.cell_proj(cell_feat)
        feat = torch.cat([d, c], 1)
        out = self.network(feat)
        return out
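

# Hedged usage sketch (illustrative addition): the two drugs share one
# projection and fuse by element-wise product, so the prediction is symmetric
# in drug order.
def _synsyn_example():
    model = SynSyn(drug_size=8, cell_size=16, hidden_size=32)
    d1, d2 = torch.randn(4, 8), torch.randn(4, 8)
    cell = torch.randn(4, 16)
    return model(d1, d2, cell)                     # (4, 1)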


class PPIDNN(nn.Module):

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int, emb_size: int):
        super(PPIDNN, self).__init__()
        self.conv = nn.Sequential(
            BottleneckLayer(emb_size, 64),
            BottleneckLayer(64, 128),
            BottleneckLayer(128, 64),
            BottleneckLayer(64, 1)
        )
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        cell_feat = cell_feat.permute(0, 2, 1)
        cell_feat = self.conv(cell_feat).squeeze(1)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out


class StackLinearDNN(nn.Module):

    def __init__(self, input_size: int, stack_size: int, hidden_size: int):
        super(StackLinearDNN, self).__init__()
        # Learned weights that compress a stack of cell feature vectors into one.
        self.compress = nn.Parameter(torch.zeros(size=(1, stack_size)))
        nn.init.xavier_uniform_(self.compress.data, gain=1.414)
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        # cell_feat: (batch, stack_size, cell_dim) -> (batch, cell_dim)
        cell_feat = torch.matmul(self.compress, cell_feat).squeeze(1)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
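

# Hedged usage sketch (illustrative addition): cell features arrive as a stack
# of vectors per sample; `compress` collapses the stack with learned weights.
def _stack_linear_example():
    model = StackLinearDNN(input_size=2 * 8 + 16, stack_size=5, hidden_size=32)
    d1, d2 = torch.randn(4, 8), torch.randn(4, 8)
    cell = torch.randn(4, 5, 16)                   # (batch, stack_size, cell_dim)
    return model(d1, d2, cell)                     # (4, 1)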


class InteractionNet(nn.Module):

    def __init__(self, drug_size: int, cell_size: int, hidden_size: int):
        super(InteractionNet, self).__init__()
        # self.compress = nn.Parameter(torch.ones(size=(1, stack_size)))
        # self.drug_proj = nn.Sequential(
        #     nn.Linear(drug_size, hidden_size),
        #     nn.LeakyReLU(),
        #     nn.BatchNorm1d(hidden_size)
        # )
        # Shared network that scores each (drug, cell) pair.
        self.inter_net = nn.Sequential(
            nn.Linear(drug_size + cell_size, hidden_size),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size)
        )
        self.network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        # cell_feat = torch.matmul(self.compress, cell_feat).squeeze(1)
        # d1 = self.drug_proj(drug1_feat)
        # d2 = self.drug_proj(drug2_feat)
        dc1 = torch.cat([drug1_feat, cell_feat], 1)
        dc2 = torch.cat([drug2_feat, cell_feat], 1)
        inter1 = self.inter_net(dc1)
        inter2 = self.inter_net(dc2)
        inter3 = inter1 + inter2    # symmetric in drug1/drug2
        out = self.network(inter3)
        return out
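

# Hedged usage sketch (illustrative addition): the shared inter_net scores
# each (drug, cell) pair and the pair representations are summed, making the
# model symmetric in the two drugs.
def _interaction_net_example():
    model = InteractionNet(drug_size=8, cell_size=16, hidden_size=32)
    d1, d2 = torch.randn(4, 8), torch.randn(4, 8)
    cell = torch.randn(4, 16)
    return model(d1, d2, cell)                     # (4, 1)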


class StackProjDNN(nn.Module):

    def __init__(self, drug_size: int, cell_size: int, stack_size: int, hidden_size: int):
        super(StackProjDNN, self).__init__()
        # One (cell_size x cell_size) projection matrix per stacked cell feature.
        self.projectors = nn.Parameter(torch.zeros(size=(stack_size, cell_size, cell_size)))
        nn.init.xavier_uniform_(self.projectors.data, gain=1.414)
        self.network = nn.Sequential(
            nn.Linear(2 * drug_size + cell_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
        cell_feat = cell_feat.unsqueeze(-1)                                 # (batch, stack, cell, 1)
        cell_feats = torch.matmul(self.projectors, cell_feat).squeeze(-1)   # (batch, stack, cell)
        cell_feat = torch.sum(cell_feats, 1)                                # (batch, cell)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        out = self.network(feat)
        return out
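

# Hedged usage sketch (illustrative addition): each of the stack_size learned
# (cell_size x cell_size) projectors transforms its matching cell vector, and
# the projected vectors are summed over the stack.
def _stack_proj_example():
    model = StackProjDNN(drug_size=8, cell_size=16, stack_size=5, hidden_size=32)
    d1, d2 = torch.randn(4, 8), torch.randn(4, 8)
    cell = torch.randn(4, 5, 16)                   # (batch, stack_size, cell_size)
    return model(d1, d2, cell)                     # (4, 1)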


# other PRODeepSyn models have been deleted for now