drug/data/DDI/SNAP Stanford/ChCh-Miner_durgbank-chem-chem.tsv | drug/data/DDI/SNAP Stanford/ChCh-Miner_durgbank-chem-chem.tsv | ||||
cell/data/DTI/SNAP Stanford/ChG-Miner_miner-chem-gene.tsv | cell/data/DTI/SNAP Stanford/ChG-Miner_miner-chem-gene.tsv | ||||
cell/data/DTI/SNAP Stanford/ChG-Miner_miner-chem-gene.tsv.gz | cell/data/DTI/SNAP Stanford/ChG-Miner_miner-chem-gene.tsv.gz | ||||
drug/data/Smiles/drugbank_all_structure_links.csv.zip | |||||
drug/data/Smiles/drugbank_all_structure_links.csv.zip |
ERLOTINIB DB00530 | ERLOTINIB DB00530 | ||||
ETOPOSIDE DB00773 | ETOPOSIDE DB00773 | ||||
GELDANAMYCIN DB02424 | GELDANAMYCIN DB02424 | ||||
GEMCITABINE DB00441 | |||||
L778123 15 ?????????????????????? | |||||
GEMCITABINE DB00441 | |||||
L778123 DB07227 | |||||
LAPATINIB DB01259 | LAPATINIB DB01259 | ||||
METFORMIN DB00331 | METFORMIN DB00331 | ||||
METHOTREXATE DB00563 | METHOTREXATE DB00563 | ||||
MRK-003 DB17015 | MRK-003 DB17015 | ||||
OXALIPLATIN DB00526 | OXALIPLATIN DB00526 | ||||
PACLITAXEL DB01229 | PACLITAXEL DB01229 | ||||
PD325901 29 ????????????????????? | |||||
PD325901 DB07101 | |||||
SN-38 DB05482 | SN-38 DB05482 | ||||
SORAFENIB DB00398 | SORAFENIB DB00398 | ||||
SUNITINIB DB01268 | SUNITINIB DB01268 |
import os | |||||
PROJ_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) | |||||
SUB_PROJ_DIR = os.path.join(PROJ_DIR, 'predictor') | |||||
DATA_DIR = os.path.join(SUB_PROJ_DIR, 'data') | |||||
DRUG_DATA_DIR = os.path.join(PROJ_DIR, 'drug', 'data') | |||||
CELL_DATA_DIR = os.path.join(PROJ_DIR, 'cell', 'data') | |||||
OUTPUT_DIR = 'output' | |||||
if not os.path.exists(OUTPUT_DIR): | |||||
os.makedirs(OUTPUT_DIR) | |||||
SYNERGY_FILE = os.path.join(DATA_DIR, 'synergy.tsv') | |||||
DRUG_FEAT_FILE = os.path.join(DRUG_DATA_DIR, 'drug_feat.npy') | |||||
DRUG2ID_FILE = os.path.join(DRUG_DATA_DIR, 'drug2id.tsv') | |||||
CELL_FEAT_FILE = os.path.join(CELL_DATA_DIR, 'cell_feat.npy') | |||||
CELL2ID_FILE = os.path.join(CELL_DATA_DIR, 'cell2id.tsv') | |||||
import argparse | |||||
import os | |||||
import logging | |||||
import time | |||||
import pickle | |||||
import torch | |||||
import torch.nn as nn | |||||
from datetime import datetime | |||||
time_str = str(datetime.now().strftime('%y%m%d%H%M')) | |||||
from model.datasets import FastSynergyDataset, FastTensorDataLoader | |||||
from model.models import DNN | |||||
from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv | |||||
from const import SYNERGY_FILE, DRUG2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR | |||||
def eval_model(model, optimizer, loss_func, train_data, test_data, | |||||
batch_size, n_epoch, patience, gpu_id, mdl_dir): | |||||
tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1) | |||||
train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True) | |||||
valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4) | |||||
test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4) | |||||
train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, | |||||
sl=True, mdl_dir=mdl_dir) | |||||
test_loss = eval_epoch(model, test_loader, loss_func, gpu_id) | |||||
test_loss /= len(test_data) | |||||
return test_loss | |||||
def step_batch(model, batch, loss_func, gpu_id=None, train=True): | |||||
drug1_feats, drug2_feats, cell_feats, y_true = batch | |||||
if gpu_id is not None: | |||||
drug1_feats, drug2_feats, cell_feats, y_true = drug1_feats.cuda(gpu_id), drug2_feats.cuda(gpu_id), \ | |||||
cell_feats.cuda(gpu_id), y_true.cuda(gpu_id) | |||||
if train: | |||||
y_pred = model(drug1_feats, drug2_feats, cell_feats) | |||||
else: | |||||
yp1 = model(drug1_feats, drug2_feats, cell_feats) | |||||
yp2 = model(drug2_feats, drug1_feats, cell_feats) | |||||
y_pred = (yp1 + yp2) / 2 | |||||
loss = loss_func(y_pred, y_true) | |||||
return loss | |||||
def train_epoch(model, loader, loss_func, optimizer, gpu_id=None): | |||||
model.train() | |||||
epoch_loss = 0 | |||||
for _, batch in enumerate(loader): | |||||
optimizer.zero_grad() | |||||
loss = step_batch(model, batch, loss_func, gpu_id) | |||||
loss.backward() | |||||
optimizer.step() | |||||
epoch_loss += loss.item() | |||||
return epoch_loss | |||||
def eval_epoch(model, loader, loss_func, gpu_id=None): | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
epoch_loss = 0 | |||||
for batch in loader: | |||||
loss = step_batch(model, batch, loss_func, gpu_id, train=False) | |||||
epoch_loss += loss.item() | |||||
return epoch_loss | |||||
def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, | |||||
sl=False, mdl_dir=None): | |||||
min_loss = float('inf') | |||||
angry = 0 | |||||
for epoch in range(1, n_epoch + 1): | |||||
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id) | |||||
trn_loss /= train_loader.dataset_len | |||||
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) | |||||
val_loss /= valid_loader.dataset_len | |||||
if val_loss < min_loss: | |||||
angry = 0 | |||||
min_loss = val_loss | |||||
if sl: | |||||
# TODO: save best model in case of need | |||||
save_best_model(model.state_dict(), mdl_dir, epoch, keep=1) | |||||
pass | |||||
else: | |||||
angry += 1 | |||||
if angry >= patience: | |||||
break | |||||
if sl: | |||||
model.load_state_dict(torch.load(find_best_model(mdl_dir))) | |||||
return min_loss | |||||
def create_model(data, hidden_size, gpu_id=None): | |||||
# TODO: use our own MLP model | |||||
model = DNN(data.cell_feat_len() + 2 * data.drug_feat_len(), hidden_size) | |||||
if gpu_id is not None: | |||||
model = model.cuda(gpu_id) | |||||
return model | |||||
def cv(args, out_dir): | |||||
save_args(args, os.path.join(out_dir, 'args.json')) | |||||
test_loss_file = os.path.join(out_dir, 'test_loss.pkl') | |||||
if torch.cuda.is_available() and (args.gpu is not None): | |||||
gpu_id = args.gpu | |||||
else: | |||||
gpu_id = None | |||||
n_folds = 5 | |||||
n_delimiter = 60 | |||||
loss_func = nn.MSELoss(reduction='sum') | |||||
test_losses = [] | |||||
for test_fold in range(n_folds): | |||||
outer_trn_folds = [x for x in range(n_folds) if x != test_fold] | |||||
logging.info("Outer: train folds {}, test folds {}".format(outer_trn_folds, test_fold)) | |||||
logging.info("-" * n_delimiter) | |||||
param = [] | |||||
losses = [] | |||||
for hs in args.hidden: | |||||
for lr in args.lr: | |||||
param.append((hs, lr)) | |||||
logging.info("Hidden size: {} | Learning rate: {}".format(hs, lr)) | |||||
ret_vals = [] | |||||
for valid_fold in outer_trn_folds: | |||||
inner_trn_folds = [x for x in outer_trn_folds if x != valid_fold] | |||||
valid_folds = [valid_fold] | |||||
train_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE, | |||||
SYNERGY_FILE, use_folds=inner_trn_folds) | |||||
valid_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE, | |||||
SYNERGY_FILE, use_folds=valid_folds, train=False) | |||||
train_loader = FastTensorDataLoader(*train_data.tensor_samples(), batch_size=args.batch, | |||||
shuffle=True) | |||||
valid_loader = FastTensorDataLoader(*valid_data.tensor_samples(), batch_size=len(valid_data) // 4) | |||||
model = create_model(train_data, hs, gpu_id) | |||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr) | |||||
logging.info( | |||||
"Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds)) | |||||
ret = train_model(model, optimizer, loss_func, train_loader, valid_loader, | |||||
args.epoch, args.patience, gpu_id, sl=False) | |||||
ret_vals.append(ret) | |||||
del model | |||||
inner_loss = sum(ret_vals) / len(ret_vals) | |||||
logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss)) | |||||
logging.info("-" * n_delimiter) | |||||
losses.append(inner_loss) | |||||
torch.cuda.empty_cache() | |||||
time.sleep(10) | |||||
min_ls, min_idx = arg_min(losses) | |||||
best_hs, best_lr = param[min_idx] | |||||
train_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE, | |||||
SYNERGY_FILE, use_folds=outer_trn_folds) | |||||
test_data = FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, DRUG_FEAT_FILE, CELL_FEAT_FILE, | |||||
SYNERGY_FILE, use_folds=[test_fold], train=False) | |||||
model = create_model(train_data, best_hs, gpu_id) | |||||
optimizer = torch.optim.Adam(model.parameters(), lr=best_lr) | |||||
logging.info("Best hidden size: {} | Best learning rate: {}".format(best_hs, best_lr)) | |||||
logging.info("Start test on fold {}.".format([test_fold])) | |||||
test_mdl_dir = os.path.join(out_dir, str(test_fold)) | |||||
if not os.path.exists(test_mdl_dir): | |||||
os.makedirs(test_mdl_dir) | |||||
test_loss = eval_model(model, optimizer, loss_func, train_data, test_data, | |||||
args.batch, args.epoch, args.patience, gpu_id, test_mdl_dir) | |||||
test_losses.append(test_loss) | |||||
logging.info("Test loss: {:.4f}".format(test_loss)) | |||||
logging.info("*" * n_delimiter + '\n') | |||||
logging.info("CV completed") | |||||
with open(test_loss_file, 'wb') as f: | |||||
pickle.dump(test_losses, f) | |||||
mu, sigma = calc_stat(test_losses) | |||||
logging.info("MSE: {:.4f} ± {:.4f}".format(mu, sigma)) | |||||
lo, hi = conf_inv(mu, sigma, len(test_losses)) | |||||
logging.info("Confidence interval: [{:.4f}, {:.4f}]".format(lo, hi)) | |||||
rmse_loss = [x ** 0.5 for x in test_losses] | |||||
mu, sigma = calc_stat(rmse_loss) | |||||
logging.info("RMSE: {:.4f} ± {:.4f}".format(mu, sigma)) | |||||
def main(): | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--epoch', type=int, default=500, help="n epoch") | |||||
parser.add_argument('--batch', type=int, default=256, help="batch size") | |||||
parser.add_argument('--gpu', type=int, default=None, help="cuda device") | |||||
parser.add_argument('--patience', type=int, default=100, help='patience for early stop') | |||||
parser.add_argument('--suffix', type=str, default=time_str, help="model dir suffix") | |||||
parser.add_argument('--hidden', type=int, nargs='+', default=[2048, 4096, 8192], help="hidden size") | |||||
parser.add_argument('--lr', type=float, nargs='+', default=[1e-3, 1e-4, 1e-5], help="learning rate") | |||||
args = parser.parse_args() | |||||
out_dir = os.path.join(OUTPUT_DIR, 'cv_{}'.format(args.suffix)) | |||||
if not os.path.exists(out_dir): | |||||
os.makedirs(out_dir) | |||||
log_file = os.path.join(out_dir, 'cv.log') | |||||
logging.basicConfig(filename=log_file, | |||||
format='%(asctime)s %(message)s', | |||||
datefmt='[%Y-%m-%d %H:%M:%S]', | |||||
level=logging.INFO) | |||||
cv(args, out_dir) | |||||
if __name__ == '__main__': | |||||
main() |
import numpy as np | |||||
import torch | |||||
import random | |||||
from torch.utils.data import Dataset | |||||
from .utils import read_map | |||||
class FastTensorDataLoader: | |||||
""" | |||||
A DataLoader-like object for a set of tensors that can be much faster than | |||||
TensorDataset + DataLoader because dataloader grabs individual indices of | |||||
the dataset and calls cat (slow). | |||||
Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 | |||||
""" | |||||
def __init__(self, *tensors, batch_size=32, shuffle=False): | |||||
""" | |||||
Initialize a FastTensorDataLoader. | |||||
:param *tensors: tensors to store. Must have the same length @ dim 0. | |||||
:param batch_size: batch size to load. | |||||
:param shuffle: if True, shuffle the data *in-place* whenever an | |||||
iterator is created out of this object. | |||||
:returns: A FastTensorDataLoader. | |||||
""" | |||||
assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) | |||||
self.tensors = tensors | |||||
self.dataset_len = self.tensors[0].shape[0] | |||||
self.batch_size = batch_size | |||||
self.shuffle = shuffle | |||||
# Calculate # batches | |||||
n_batches, remainder = divmod(self.dataset_len, self.batch_size) | |||||
if remainder > 0: | |||||
n_batches += 1 | |||||
self.n_batches = n_batches | |||||
def __iter__(self): | |||||
if self.shuffle: | |||||
r = torch.randperm(self.dataset_len) | |||||
self.tensors = [t[r] for t in self.tensors] | |||||
self.i = 0 | |||||
return self | |||||
def __next__(self): | |||||
if self.i >= self.dataset_len: | |||||
raise StopIteration | |||||
batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors) | |||||
self.i += self.batch_size | |||||
return batch | |||||
def __len__(self): | |||||
return self.n_batches | |||||
class FastSynergyDataset(Dataset): | |||||
def __init__(self, drug2id_file, cell2id_file, drug_feat_file, cell_feat_file, synergy_score_file, use_folds, | |||||
train=True): | |||||
self.drug2id = read_map(drug2id_file) | |||||
self.cell2id = read_map(cell2id_file) | |||||
self.drug_feat = np.load(drug_feat_file) | |||||
self.cell_feat = np.load(cell_feat_file) | |||||
self.samples = [] | |||||
self.raw_samples = [] | |||||
self.train = train | |||||
valid_drugs = set(self.drug2id.keys()) | |||||
valid_cells = set(self.cell2id.keys()) | |||||
with open(synergy_score_file, 'r') as f: | |||||
f.readline() | |||||
for line in f: | |||||
drug1, drug2, cellname, score, fold = line.rstrip().split('\t') | |||||
if drug1 in valid_drugs and drug2 in valid_drugs and cellname in valid_cells: | |||||
if int(fold) in use_folds: | |||||
sample = [ | |||||
# TODO: specify drug_feat | |||||
torch.from_numpy(self.drug_feat[self.drug2id[drug1]]).float(), | |||||
torch.from_numpy(self.drug_feat[self.drug2id[drug2]]).float(), | |||||
torch.from_numpy(self.cell_feat[self.cell2id[cellname]]).float(), | |||||
torch.FloatTensor([float(score)]), | |||||
] | |||||
self.samples.append(sample) | |||||
raw_sample = [self.drug2id[drug1], self.drug2id[drug2], self.cell2id[cellname], score] | |||||
self.raw_samples.append(raw_sample) | |||||
if train: | |||||
sample = [ | |||||
torch.from_numpy(self.drug_feat[self.drug2id[drug2]]).float(), | |||||
torch.from_numpy(self.drug_feat[self.drug2id[drug1]]).float(), | |||||
torch.from_numpy(self.cell_feat[self.cell2id[cellname]]).float(), | |||||
torch.FloatTensor([float(score)]), | |||||
] | |||||
self.samples.append(sample) | |||||
raw_sample = [self.drug2id[drug2], self.drug2id[drug1], self.cell2id[cellname], score] | |||||
self.raw_samples.append(raw_sample) | |||||
def __len__(self): | |||||
return len(self.samples) | |||||
def __getitem__(self, item): | |||||
return self.samples[item] | |||||
def drug_feat_len(self): | |||||
return self.drug_feat.shape[-1] | |||||
def cell_feat_len(self): | |||||
return self.cell_feat.shape[-1] | |||||
def tensor_samples(self, indices=None): | |||||
if indices is None: | |||||
indices = list(range(len(self))) | |||||
d1 = torch.cat([torch.unsqueeze(self.samples[i][0], 0) for i in indices], dim=0) | |||||
d2 = torch.cat([torch.unsqueeze(self.samples[i][1], 0) for i in indices], dim=0) | |||||
c = torch.cat([torch.unsqueeze(self.samples[i][2], 0) for i in indices], dim=0) | |||||
y = torch.cat([torch.unsqueeze(self.samples[i][3], 0) for i in indices], dim=0) | |||||
return d1, d2, c, y |
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class PPI2Cell(nn.Module): | |||||
def __init__(self, n_cell: int, ppi_emb: torch.Tensor, bias=True): | |||||
super(PPI2Cell, self).__init__() | |||||
self.n_cell = n_cell | |||||
self.cell_emb = nn.Embedding(n_cell, ppi_emb.shape[1], max_norm=1.0, norm_type=2.0) | |||||
if bias: | |||||
self.bias = nn.Parameter(torch.randn((1, ppi_emb.shape[0])), requires_grad=True) | |||||
else: | |||||
self.bias = 0 | |||||
self.ppi_emb = ppi_emb.permute(1, 0) | |||||
def forward(self, x: torch.Tensor): | |||||
x = x.squeeze(dim=1) | |||||
emb = self.cell_emb(x) | |||||
y = emb.mm(self.ppi_emb) | |||||
y += self.bias | |||||
return y | |||||
class PPI2CellV2(nn.Module): | |||||
def __init__(self, n_cell: int, ppi_emb: torch.Tensor, hidden_dim: int, bias=True): | |||||
super(PPI2CellV2, self).__init__() | |||||
self.n_cell = n_cell | |||||
self.projector = nn.Sequential( | |||||
nn.Linear(ppi_emb.shape[1], hidden_dim, bias=bias), | |||||
nn.LeakyReLU() | |||||
) | |||||
self.cell_emb = nn.Embedding(n_cell, hidden_dim, max_norm=1.0, norm_type=2.0) | |||||
self.ppi_emb = ppi_emb | |||||
def forward(self, x: torch.Tensor): | |||||
x = x.squeeze(dim=1) | |||||
proj = self.projector(self.ppi_emb).permute(1, 0) | |||||
emb = self.cell_emb(x) | |||||
y = emb.mm(proj) | |||||
return y | |||||
class SynEmb(nn.Module): | |||||
def __init__(self, n_drug: int, drug_dim: int, n_cell: int, cell_dim: int, hidden_dim: int): | |||||
super(SynEmb, self).__init__() | |||||
self.drug_emb = nn.Embedding(n_drug, drug_dim, max_norm=1) | |||||
self.cell_emb = nn.Embedding(n_cell, cell_dim, max_norm=1) | |||||
self.network = DNN(2 * drug_dim + cell_dim, hidden_dim) | |||||
def forward(self, drug1, drug2, cell): | |||||
d1 = self.drug_emb(drug1).squeeze(1) | |||||
d2 = self.drug_emb(drug2).squeeze(1) | |||||
c = self.cell_emb(cell).squeeze(1) | |||||
return self.network(d1, d2, c) | |||||
class AutoEncoder(nn.Module): | |||||
def __init__(self, input_size: int, latent_size: int): | |||||
super(AutoEncoder, self).__init__() | |||||
self.encoder = nn.Sequential( | |||||
nn.Linear(input_size, input_size // 2), | |||||
nn.ReLU(), | |||||
nn.Linear(input_size // 2, input_size // 4), | |||||
nn.ReLU(), | |||||
nn.Linear(input_size // 4, latent_size) | |||||
) | |||||
self.decoder = nn.Sequential( | |||||
nn.Linear(latent_size, input_size // 4), | |||||
nn.ReLU(), | |||||
nn.Linear(input_size // 4, input_size // 2), | |||||
nn.ReLU(), | |||||
nn.Linear(input_size // 2, input_size) | |||||
) | |||||
def forward(self, x: torch.Tensor): | |||||
encoded = self.encoder(x) | |||||
decoded = self.decoder(encoded) | |||||
return encoded, decoded | |||||
class GeneExpressionAE(nn.Module): | |||||
def __init__(self, input_size: int, latent_size: int): | |||||
super(GeneExpressionAE, self).__init__() | |||||
self.encoder = nn.Sequential( | |||||
nn.Linear(input_size, 2048), | |||||
nn.Tanh(), | |||||
nn.Linear(2048, 1024), | |||||
nn.Tanh(), | |||||
nn.Linear(1024, latent_size), | |||||
nn.Tanh() | |||||
) | |||||
self.decoder = nn.Sequential( | |||||
nn.Linear(latent_size, 1024), | |||||
nn.Tanh(), | |||||
nn.Linear(1024, 2048), | |||||
nn.Tanh(), | |||||
nn.Linear(2048, input_size), | |||||
nn.Tanh() | |||||
) | |||||
def forward(self, x: torch.Tensor): | |||||
encoded = self.encoder(x) | |||||
decoded = self.decoder(encoded) | |||||
return encoded, decoded | |||||
class DrugFeatAE(nn.Module): | |||||
def __init__(self, input_size: int, latent_size: int): | |||||
super(DrugFeatAE, self).__init__() | |||||
self.encoder = nn.Sequential( | |||||
nn.Linear(input_size, 128), | |||||
nn.ReLU(), | |||||
nn.Linear(128, latent_size), | |||||
nn.Sigmoid(), | |||||
) | |||||
self.decoder = nn.Sequential( | |||||
nn.Linear(latent_size, 128), | |||||
nn.ReLU(), | |||||
nn.Linear(128, input_size), | |||||
nn.Sigmoid() | |||||
) | |||||
def forward(self, x: torch.Tensor): | |||||
encoded = self.encoder(x) | |||||
decoded = self.decoder(encoded) | |||||
return encoded, decoded | |||||
class DSDNN(nn.Module): | |||||
def __init__(self, input_size: int, hidden_size: int): | |||||
super(DSDNN, self).__init__() | |||||
self.network = nn.Sequential( | |||||
nn.Linear(input_size, hidden_size), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, feat: torch.Tensor): | |||||
out = self.network(feat) | |||||
return out | |||||
class DNN(nn.Module): | |||||
def __init__(self, input_size: int, hidden_size: int): | |||||
super(DNN, self).__init__() | |||||
self.network = nn.Sequential( | |||||
nn.Linear(input_size, hidden_size), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||||
out = self.network(feat) | |||||
return out | |||||
class BottleneckLayer(nn.Module): | |||||
def __init__(self, in_channels: int, out_channels: int): | |||||
super(BottleneckLayer, self).__init__() | |||||
self.net = nn.Sequential( | |||||
nn.Conv1d(in_channels, out_channels, 1), | |||||
nn.BatchNorm1d(out_channels), | |||||
nn.LeakyReLU() | |||||
) | |||||
def forward(self, x): | |||||
return self.net(x) | |||||
class PatchySan(nn.Module): | |||||
def __init__(self, drug_size: int, cell_size: int, hidden_size: int, field_size: int): | |||||
super(PatchySan, self).__init__() | |||||
# self.drug_proj = nn.Linear(drug_size, hidden_size, bias=False) | |||||
# self.cell_proj = nn.Linear(cell_size, hidden_size, bias=False) | |||||
self.conv = nn.Sequential( | |||||
BottleneckLayer(field_size, 16), | |||||
BottleneckLayer(16, 32), | |||||
BottleneckLayer(32, 16), | |||||
BottleneckLayer(16, 1), | |||||
) | |||||
self.network = nn.Sequential( | |||||
nn.Linear(2 * drug_size + cell_size, hidden_size), | |||||
nn.LeakyReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.LeakyReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
cell_feat = cell_feat.permute(0, 2, 1) | |||||
cell_feat = self.conv(cell_feat).squeeze(1) | |||||
# drug1_feat = self.drug_proj(drug1_feat) | |||||
# drug2_feat = self.drug_proj(drug2_feat) | |||||
# express = self.cell_proj(cell_feat) | |||||
# feat = torch.cat([drug1_feat, drug2_feat, express], 1) | |||||
# drug_feat = (drug1_feat + drug2_feat) / 2 | |||||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||||
out = self.network(feat) | |||||
return out | |||||
class SynSyn(nn.Module): | |||||
def __init__(self, drug_size: int, cell_size: int, hidden_size: int): | |||||
super(SynSyn, self).__init__() | |||||
self.drug_proj = nn.Linear(drug_size, drug_size) | |||||
self.cell_proj = nn.Linear(cell_size, cell_size) | |||||
self.network = nn.Sequential( | |||||
nn.Linear(drug_size + cell_size, hidden_size), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
d1 = self.drug_proj(drug1_feat) | |||||
d2 = self.drug_proj(drug2_feat) | |||||
d = d1.mul(d2) | |||||
c = self.cell_proj(cell_feat) | |||||
feat = torch.cat([d, c], 1) | |||||
out = self.network(feat) | |||||
return out | |||||
class PPIDNN(nn.Module): | |||||
def __init__(self, drug_size: int, cell_size: int, hidden_size: int, emb_size: int): | |||||
super(PPIDNN, self).__init__() | |||||
self.conv = nn.Sequential( | |||||
BottleneckLayer(emb_size, 64), | |||||
BottleneckLayer(64, 128), | |||||
BottleneckLayer(128, 64), | |||||
BottleneckLayer(64, 1), | |||||
) | |||||
self.network = nn.Sequential( | |||||
nn.Linear(2 * drug_size + cell_size, hidden_size), | |||||
nn.LeakyReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.LeakyReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
cell_feat = cell_feat.permute(0, 2, 1) | |||||
cell_feat = self.conv(cell_feat).squeeze(1) | |||||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||||
out = self.network(feat) | |||||
return out | |||||
class StackLinearDNN(nn.Module): | |||||
def __init__(self, input_size: int, stack_size: int, hidden_size: int): | |||||
super(StackLinearDNN, self).__init__() | |||||
self.compress = nn.Parameter(torch.zeros(size=(1, stack_size))) | |||||
nn.init.xavier_uniform_(self.compress.data, gain=1.414) | |||||
self.network = nn.Sequential( | |||||
nn.Linear(input_size, hidden_size), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
cell_feat = torch.matmul(self.compress, cell_feat).squeeze(1) | |||||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||||
out = self.network(feat) | |||||
return out | |||||
class InteractionNet(nn.Module): | |||||
def __init__(self, drug_size: int, cell_size: int, hidden_size: int): | |||||
super(InteractionNet, self).__init__() | |||||
# self.compress = nn.Parameter(torch.ones(size=(1, stack_size))) | |||||
# self.drug_proj = nn.Sequential( | |||||
# nn.Linear(drug_size, hidden_size), | |||||
# nn.LeakyReLU(), | |||||
# nn.BatchNorm1d(hidden_size) | |||||
# ) | |||||
self.inter_net = nn.Sequential( | |||||
nn.Linear(drug_size + cell_size, hidden_size), | |||||
nn.LeakyReLU(), | |||||
nn.BatchNorm1d(hidden_size) | |||||
) | |||||
self.network = nn.Sequential( | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.LeakyReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
# cell_feat = torch.mat | |||||
# mul(self.compress, cell_feat).squeeze(1) | |||||
# d1 = self.drug_proj(drug1_feat) | |||||
# d2 = self.drug_proj(drug2_feat) | |||||
dc1 = torch.cat([drug1_feat, cell_feat], 1) | |||||
dc2 = torch.cat([drug2_feat, cell_feat], 1) | |||||
inter1 = self.inter_net(dc1) | |||||
inter2 = self.inter_net(dc2) | |||||
inter3 = inter1 + inter2 | |||||
out = self.network(inter3) | |||||
return out | |||||
class StackProjDNN(nn.Module): | |||||
def __init__(self, drug_size: int, cell_size: int, stack_size: int, hidden_size: int): | |||||
super(StackProjDNN, self).__init__() | |||||
self.projectors = nn.Parameter(torch.zeros(size=(stack_size, cell_size, cell_size))) | |||||
nn.init.xavier_uniform_(self.projectors.data, gain=1.414) | |||||
self.network = nn.Sequential( | |||||
nn.Linear(2 * drug_size + cell_size, hidden_size), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size), | |||||
nn.Linear(hidden_size, hidden_size // 2), | |||||
nn.ReLU(), | |||||
nn.BatchNorm1d(hidden_size // 2), | |||||
nn.Linear(hidden_size // 2, 1) | |||||
) | |||||
def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor): | |||||
cell_feat = cell_feat.unsqueeze(-1) | |||||
cell_feats = torch.matmul(self.projectors, cell_feat).squeeze(-1) | |||||
cell_feat = torch.sum(cell_feats, 1) | |||||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||||
out = self.network(feat) | |||||
return out |
import os | |||||
import pickle | |||||
import matplotlib.pyplot as plt | |||||
import torch | |||||
import json | |||||
import random | |||||
def arg_min(lst): | |||||
m = float('inf') | |||||
idx = 0 | |||||
for i, v in enumerate(lst): | |||||
if v < m: | |||||
m = v | |||||
idx = i | |||||
return m, idx | |||||
def save_best_model(state_dict, model_dir: str, best_epoch: int, keep: int): | |||||
save_to = os.path.join(model_dir, '{}.pkl'.format(best_epoch)) | |||||
torch.save(state_dict, save_to) | |||||
model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl'] | |||||
epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])] | |||||
outdated = sorted(epochs, reverse=True)[keep:] | |||||
for n in outdated: | |||||
os.remove(os.path.join(model_dir, '{}.pkl'.format(n))) | |||||
def find_best_model(model_dir: str): | |||||
model_files = [f for f in os.listdir(model_dir) if os.path.splitext(f)[-1] == '.pkl'] | |||||
epochs = [int(os.path.splitext(f)[0]) for f in model_files if str.isdigit(f[0])] | |||||
best_epoch = max(epochs) | |||||
return os.path.join(model_dir, '{}.pkl'.format(best_epoch)) | |||||
def save_args(args, save_to: str): | |||||
args_dict = args.__dict__ | |||||
with open(save_to, 'w') as f: | |||||
json.dump(args_dict, f, indent=2) | |||||
def read_map(map_file): | |||||
d = {} | |||||
with open(map_file, 'r') as f: | |||||
f.readline() | |||||
for line in f: | |||||
k, v = line.rstrip().split('\t') | |||||
d[k] = int(v) | |||||
return d | |||||
def random_split_indices(n_samples, train_rate: float = None, test_rate: float = None): | |||||
if train_rate is not None and (train_rate < 0 or train_rate > 1): | |||||
raise ValueError("train rate should be in [0, 1], found {}".format(train_rate)) | |||||
elif test_rate is not None: | |||||
if test_rate < 0 or test_rate > 1: | |||||
raise ValueError("test rate should be in [0, 1], found {}".format(test_rate)) | |||||
train_rate = 1 - test_rate | |||||
elif train_rate is None and test_rate is None: | |||||
raise ValueError("Either train_rate or test_rate should be given.") | |||||
evidence = list(range(n_samples)) | |||||
train_size = int(len(evidence) * train_rate) | |||||
random.shuffle(evidence) | |||||
train_indices = evidence[:train_size] | |||||
test_indices = evidence[train_size:] | |||||
return train_indices, test_indices |