import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Feed-forward scoring head for a (drug, drug, cell-line) triple.

    The three input feature vectors are concatenated along the feature
    dimension and passed through two ReLU + BatchNorm hidden layers to a
    single scalar output per sample.

    Args:
        input_size: Total feature dimension, i.e. the summed widths of
            ``drug1_feat``, ``drug2_feat`` and ``cell_feat``.
        hidden_size: Width of the first hidden layer; the second hidden
            layer uses ``hidden_size // 2``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        # Zero-argument super() is the Python 3 idiom; behavior is identical
        # to the legacy super(MLP, self) form.
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size // 2),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor,
                cell_feat: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 1)`` score tensor.

        All three inputs must share the same batch dimension; their
        feature widths must sum to ``input_size``.
        """
        # Concatenate along the feature axis (dim=1), not the batch axis.
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], dim=1)
        return self.layers(feat)


# other PRODeepSyn models have been deleted for now