@@ -13,8 +13,21 @@ time_str = str(datetime.now().strftime('%y%m%d%H%M'))
 from model.datasets import FastSynergyDataset, FastTensorDataLoader
 from model.models import MLP
+from drug.datasets import DDInteractionDataset
 from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
 from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE
+from torch_geometric.loader import NeighborLoader
+
+
+def find_subgraph_mask(ddi_graph, batch):
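+    # collect every drug ID appearing in this synergy batch; the unique IDs seed
+    # the NeighborLoader below (despite the name, this returns node IDs, not a boolean mask)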
+    print('----------------------------------')
+    drug1_id, drug2_id, cell_feat, y_true = batch
+    drug1_id = torch.flatten(drug1_id)
+    drug2_id = torch.flatten(drug2_id)
+    drug_ids = torch.cat((drug1_id, drug2_id), 0)
+    drug_ids, count = drug_ids.unique(return_counts=True)
+    print(drug_ids)
+    return drug_ids
@@ -30,7 +43,7 @@ def eval_model(model, optimizer, loss_func, train_data, test_data,
     return test_loss


-def step_batch(model, batch, loss_func, gpu_id=None, train=True):
+def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True):
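+    # `subgraph` is the PyG Data object sampled by NeighborLoader; it is passed
+    # through to the model along with the drug pair and the cell-line features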
     drug1_id, drug2_id, cell_feat, y_true = batch
     if gpu_id is not None:
         drug1_id = drug1_id.cuda(gpu_id)
@@ -39,24 +52,36 @@ def step_batch(model, batch, loss_func, gpu_id=None, train=True):
         y_true = y_true.cuda(gpu_id)

     if train:
-        y_pred = model(drug1_id, drug2_id, cell_feat)
+        y_pred = model(drug1_id, drug2_id, cell_feat, subgraph)
     else:
-        yp1 = model(drug1_id, drug2_id, cell_feat)
-        yp2 = model(drug2_id, drug1_id, cell_feat)
+        yp1 = model(drug1_id, drug2_id, cell_feat, subgraph)
+        yp2 = model(drug2_id, drug1_id, cell_feat, subgraph)
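+        # at eval time, score the pair in both orders and average, since synergy
+        # should not depend on which drug is listed first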
         y_pred = (yp1 + yp2) / 2
     loss = loss_func(y_pred, y_true)
     return loss


-def train_epoch(model, loader, loss_func, optimizer, gpu_id=None):
+def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
+    print('this is in train epoch: ', ddi_graph)
     model.train()
     epoch_loss = 0

     for _, batch in enumerate(loader):
-        optimizer.zero_grad()
-        loss = step_batch(model, batch, loss_func, gpu_id)
-        loss.backward()
-        optimizer.step()
-        epoch_loss += loss.item()
+        # TODO: for batch ... desired subgraph
+        subgraph_mask = find_subgraph_mask(ddi_graph, batch)
+        print(ddi_graph)
+        neighbor_loader = NeighborLoader(ddi_graph,
+                                         num_neighbors=[3, 2], batch_size=100,
+                                         input_nodes=subgraph_mask,
+                                         directed=False, shuffle=True)
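+        # num_neighbors=[3, 2]: sample up to 3 one-hop and 2 two-hop neighbors per
+        # seed node; input_nodes takes the batch's drug node IDs as the seed set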
+
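+        # NOTE: the same synergy batch is re-fed for every sampled subgraph, so each
+        # batch contributes several optimizer steps and epoch_loss counts them all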
+        for subgraph in neighbor_loader:
+            optimizer.zero_grad()
+            loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item()

     return epoch_loss
@@ -71,11 +96,13 @@ def eval_epoch(model, loader, loss_func, gpu_id=None):


 def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
+                ddi_graph,
                 sl=False, mdl_dir=None):
+    print('this is in train: ', ddi_graph)
     min_loss = float('inf')
     angry = 0
     for epoch in range(1, n_epoch + 1):
-        trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id)
+        trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
         trn_loss /= train_loader.dataset_len
         val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
         val_loss /= valid_loader.dataset_len
@@ -97,7 +124,6 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch


 def create_model(data, hidden_size, gpu_id=None):
-    # TODO: use our own MLP model
     # get 256
     model = MLP(data.cell_feat_len() + 2 * 256, hidden_size, gpu_id)
     if gpu_id is not None:
@@ -118,6 +144,9 @@ def cv(args, out_dir):
     else:
         gpu_id = None

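+    # build the drug-drug interaction graph once, before the CV loop; every fold
+    # samples its training subgraphs from this shared graph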
+    ddiDataset = DDInteractionDataset(gpu_id=gpu_id)
+    ddi_graph = ddiDataset.get()
+
     n_folds = 5
     n_delimiter = 60
     loss_func = nn.MSELoss(reduction='sum')
@@ -148,7 +177,7 @@ def cv(args, out_dir):
                 logging.info(
                     "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
                 ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
-                                  args.epoch, args.patience, gpu_id, sl=False)
+                                  args.epoch, args.patience, gpu_id, ddi_graph=ddi_graph, sl=False)
                 ret_vals.append(ret)
                 del model