|
|
@@ -27,7 +27,8 @@ def find_subgraph_mask(ddi_graph, batch): |
|
|
|
drug_ids = torch.cat((drug1_id, drug2_id), 0) |
|
|
|
drug_ids, count = drug_ids.unique(return_counts=True) |
|
|
|
# Removing negative ids |
|
|
|
drug_ids = (drug_ids >= 0).nonzero(as_tuple=True)[0] |
|
|
|
drug_id_range = (drug_ids >= 0).nonzero(as_tuple=True)[0] |
|
|
|
drug_ids = torch.index_select(drug_ids, 0, drug_id_range) |
|
|
|
return drug_ids |
|
|
|
|
|
|
|
|
|
|
@@ -66,41 +67,62 @@ def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None): |
|
|
|
# print('this is in train epoch: ', ddi_graph) |
|
|
|
model.train() |
|
|
|
epoch_loss = 0 |
|
|
|
|
|
|
|
ddi_graph.n_id = torch.arange(ddi_graph.num_nodes) |
|
|
|
|
|
|
|
for _, batch in enumerate(loader): |
|
|
|
# print("len of batch loader: ",len(loader)) |
|
|
|
for i, batch in enumerate(loader): |
|
|
|
# print(f"this is batch {i} ------------------") |
|
|
|
# TODO: for batch ... desired subgraph |
|
|
|
subgraph_mask = find_subgraph_mask(ddi_graph, batch) |
|
|
|
# print("this is subgraph mask: ---------------") |
|
|
|
# print(subgraph_mask) |
|
|
|
# print("subgraph_mask: ",subgraph_mask) ---> subgraph_mask is OK |
|
|
|
# print("This is ddi_graph:") |
|
|
|
# print(ddi_graph) ---> ddi_graph is OK |
|
|
|
subgraph_batch_size = len(subgraph_mask) |
|
|
|
neighbor_loader = NeighborLoader(ddi_graph, |
|
|
|
num_neighbors=[3,2], batch_size=10, |
|
|
|
num_neighbors=[3,2], batch_size=subgraph_batch_size, |
|
|
|
input_nodes=subgraph_mask, |
|
|
|
directed=False, shuffle=True) |
|
|
|
|
|
|
|
|
|
|
|
# sampled_data = next(iter(neighbor_loader)) |
|
|
|
# print(sampled_data) |
|
|
|
# print("---------------") |
|
|
|
# print(type(sampled_data)) |
|
|
|
# print("Sampled_data: ") |
|
|
|
# print(sampled_data.batch_size) |
|
|
|
|
|
|
|
|
|
|
|
for subgraph in neighbor_loader: |
|
|
|
# print("this is subgraph in cross_validation:") |
|
|
|
# print(subgraph) |
|
|
|
optimizer.zero_grad() |
|
|
|
loss = step_batch(model, batch, loss_func, subgraph, gpu_id) |
|
|
|
loss.backward() |
|
|
|
optimizer.step() |
|
|
|
epoch_loss += loss.item() |
|
|
|
|
|
|
|
print("epoch_loss: ", epoch_loss) |
|
|
|
# print("epoch_loss: ", epoch_loss) |
|
|
|
return epoch_loss |
|
|
|
|
|
|
|
|
|
|
|
def eval_epoch(model, loader, loss_func, ddi_graph, gpu_id=None):
    """Run one evaluation pass over ``loader`` and return the summed loss.

    For every batch, the drug ids referenced by the batch are looked up with
    ``find_subgraph_mask`` and used as seed nodes for a ``NeighborLoader``
    over ``ddi_graph``. Each sampled subgraph is passed, together with the
    batch, to ``step_batch`` with ``train=False``.

    Args:
        model: Model to evaluate; switched to eval mode by this function.
        loader: Iterable yielding evaluation batches.
        loss_func: Loss callable forwarded to ``step_batch``.
        ddi_graph: Graph object that neighbourhoods are sampled from.
        gpu_id: Device identifier forwarded to ``step_batch``.

    Returns:
        The sum of ``loss.item()`` over every (batch, subgraph) pair. The
        value is not averaged here.
    """
    model.eval()
    epoch_loss = 0

    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for batch in loader:
            subgraph_mask = find_subgraph_mask(ddi_graph, batch)

            # batch_size equals the number of seed nodes, so the neighbour
            # loader is expected to yield all seeds in one sampled subgraph.
            subgraph_batch_size = len(subgraph_mask)

            # NOTE(review): shuffle=True is kept from the original call;
            # confirm whether seed-order shuffling is wanted at eval time.
            neighbor_loader = NeighborLoader(
                ddi_graph,
                num_neighbors=[3, 2],
                batch_size=subgraph_batch_size,
                input_nodes=subgraph_mask,
                directed=False,
                shuffle=True,
            )

            for subgraph in neighbor_loader:
                loss = step_batch(
                    model, batch, loss_func, subgraph, gpu_id, train=False
                )
                epoch_loss += loss.item()

    return epoch_loss
|
|
|
|
|
|
|
|
|
|
@@ -110,13 +132,16 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch |
|
|
|
# print('this is in train: ', ddi_graph) |
|
|
|
min_loss = float('inf') |
|
|
|
angry = 0 |
|
|
|
|
|
|
|
for epoch in range(1, n_epoch + 1): |
|
|
|
start_time = time.time() |
|
|
|
print("epoch: ",str(epoch)) |
|
|
|
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id) |
|
|
|
print("train loss: ", trn_loss) |
|
|
|
trn_loss /= train_loader.dataset_len |
|
|
|
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) |
|
|
|
# print("train loss: ", trn_loss) |
|
|
|
val_loss = eval_epoch(model, valid_loader, loss_func, ddi_graph, gpu_id) |
|
|
|
val_loss /= valid_loader.dataset_len |
|
|
|
# print("val loss: ", val_loss) |
|
|
|
print("loss epoch " + str(epoch) + ": " + str(trn_loss)) |
|
|
|
if val_loss < min_loss: |
|
|
|
angry = 0 |
|
|
@@ -129,6 +154,8 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch |
|
|
|
angry += 1 |
|
|
|
if angry >= patience: |
|
|
|
break |
|
|
|
end_time = time.time() |
|
|
|
print("epoch time (min): ", (end_time - start_time)/60) |
|
|
|
if sl: |
|
|
|
model.load_state_dict(torch.load(find_best_model(mdl_dir))) |
|
|
|
return min_loss |