@@ -20,13 +20,14 @@ from torch_geometric.loader import NeighborLoader |
def find_subgraph_mask(ddi_graph, batch):
    # Despite the name, this returns the batch's unique drug node ids
    # (NeighborLoader accepts an index tensor for input_nodes, not only a mask).
    # print('----------------------------------')
    drug1_id, drug2_id, cell_feat, y_true = batch
    drug1_id = torch.flatten(drug1_id)
    drug2_id = torch.flatten(drug2_id)
    drug_ids = torch.cat((drug1_id, drug2_id), 0)
    drug_ids, count = drug_ids.unique(return_counts=True)
    print(drug_ids)
    # Remove negative (padding) ids; boolean indexing keeps the ids themselves
    drug_ids = drug_ids[drug_ids >= 0]
    return drug_ids
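
# Worked example (hypothetical toy ids): with drug1_id = [[1, -1], [3, 1]] and
# drug2_id = [[2, 2], [-1, 5]], the concatenation unique-sorts to
# tensor([-1, 1, 2, 3, 5]) and the filter keeps tensor([1, 2, 3, 5]) --
# the actual node ids, not the positions a .nonzero() call would return.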
@@ -62,26 +63,34 @@ def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True): |
def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
    # print('this is in train epoch: ', ddi_graph)
    model.train()
    epoch_loss = 0

    for batch in loader:
        # TODO: for batch ... desired subgraph
        subgraph_mask = find_subgraph_mask(ddi_graph, batch)
        print(ddi_graph)
        # print("subgraph_mask: ", subgraph_mask) ---> subgraph_mask is OK
        # print("This is ddi_graph:")
        # print(ddi_graph) ---> ddi_graph is OK
        neighbor_loader = NeighborLoader(ddi_graph,
                                         num_neighbors=[3, 2], batch_size=10,
                                         input_nodes=subgraph_mask,
                                         directed=False, shuffle=True)
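        # Sampling note: num_neighbors=[3, 2] draws up to 3 one-hop and 2
        # two-hop neighbors per seed, and input_nodes restricts the seeds to
        # this batch's drug ids, so each yielded subgraph is a small
        # neighborhood around the batch's drugs. (directed=False is the older
        # PyG spelling; recent releases deprecate it in favor of subgraph_type.)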

        # sampled_data = next(iter(neighbor_loader))
        # print("Sampled_data: ")
        # print(sampled_data.batch_size)

        for subgraph in neighbor_loader:
            optimizer.zero_grad()
            loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

    print("epoch_loss: ", epoch_loss)
    return epoch_loss
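
# Usage sketch (toy data; assumes a PyG `Data` graph and this project's
# (drug1_id, drug2_id, cell_feat, y_true) batch layout -- illustrative only):
#
#   import torch
#   from torch_geometric.data import Data
#
#   ddi_graph = Data(x=torch.randn(6, 8),
#                    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]))
#   batch = (torch.tensor([[1], [3]]), torch.tensor([[2], [-1]]),
#            torch.zeros(2, 4), torch.zeros(2))
#   seeds = find_subgraph_mask(ddi_graph, batch)   # -> tensor([1, 2, 3])
#   loader = NeighborLoader(ddi_graph, num_neighbors=[3, 2], batch_size=10,
#                           input_nodes=seeds, directed=False, shuffle=True)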
@@ -98,11 +107,13 @@ def eval_epoch(model, loader, loss_func, gpu_id=None): |
def train_model(model, optimizer, loss_func, train_loader, valid_loader,
                n_epoch, patience, gpu_id, ddi_graph, sl=False, mdl_dir=None):
    # print('this is in train: ', ddi_graph)
    min_loss = float('inf')
    angry = 0  # patience counter: epochs since the last improvement
    for epoch in range(1, n_epoch + 1):
        print("epoch: ", str(epoch))
        trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
        print("train loss: ", trn_loss)
        trn_loss /= train_loader.dataset_len  # normalize by dataset size
        val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
        val_loss /= valid_loader.dataset_len