Browse Source

resolve all bugs related to DDI graph (batched)

main
MahsaYazdani 1 year ago
parent
commit
7ed13bfd09
2 changed files with 60 additions and 15 deletions
  1. 37
    10
      predictor/cross_validation.py
  2. 23
    5
      predictor/model/models.py

+ 37
- 10
predictor/cross_validation.py View File

@@ -27,7 +27,8 @@ def find_subgraph_mask(ddi_graph, batch):
drug_ids = torch.cat((drug1_id, drug2_id), 0)
drug_ids, count = drug_ids.unique(return_counts=True)
# Removing negative ids
drug_ids = (drug_ids >= 0).nonzero(as_tuple=True)[0]
drug_id_range = (drug_ids >= 0).nonzero(as_tuple=True)[0]
drug_ids = torch.index_select(drug_ids, 0, drug_id_range)
return drug_ids

@@ -66,41 +67,62 @@ def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
# print('this is in train epoch: ', ddi_graph)
model.train()
epoch_loss = 0

ddi_graph.n_id = torch.arange(ddi_graph.num_nodes)
for _, batch in enumerate(loader):
# print("len of batch loader: ",len(loader))
for i, batch in enumerate(loader):
# print(f"this is batch {i} ------------------")
# TODO: for batch ... desired subgraph
subgraph_mask = find_subgraph_mask(ddi_graph, batch)
# print("this is subgraph mask: ---------------")
# print(subgraph_mask)
# print("subgraph_mask: ",subgraph_mask) ---> subgraph_mask is OK
# print("This is ddi_graph:")
# print(ddi_graph) ---> ddi_graph is OK
subgraph_batch_size = len(subgraph_mask)
neighbor_loader = NeighborLoader(ddi_graph,
num_neighbors=[3,2], batch_size=10,
num_neighbors=[3,2], batch_size=subgraph_batch_size,
input_nodes=subgraph_mask,
directed=False, shuffle=True)

# sampled_data = next(iter(neighbor_loader))
# print(sampled_data)
# print("---------------")
# print(type(sampled_data))
# print("Sampled_data: ")
# print(sampled_data.batch_size)
for subgraph in neighbor_loader:
# print("this is subgraph in cross_validation:")
# print(subgraph)
optimizer.zero_grad()
loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print("epoch_loss: ", epoch_loss)
# print("epoch_loss: ", epoch_loss)
return epoch_loss


def eval_epoch(model, loader, loss_func, ddi_graph, gpu_id=None):
    """Run one evaluation pass over ``loader`` and return the summed loss.

    Mirrors train_epoch's batched-subgraph sampling: for every batch, the
    drug nodes appearing in that batch are selected from the DDI graph and
    a neighborhood subgraph is sampled around them before scoring.

    Args:
        model: the predictor; switched to eval mode for the pass.
        loader: iterable of evaluation batches.
        loss_func: loss function forwarded to step_batch.
        ddi_graph: full drug-drug interaction graph to sample subgraphs from.
        gpu_id: optional CUDA device id forwarded to step_batch.

    Returns:
        Sum of per-subgraph loss values over all batches (un-normalized;
        the caller divides by the dataset length).
    """
    model.eval()
    # No gradient bookkeeping is needed during evaluation.
    with torch.no_grad():
        epoch_loss = 0
        for batch in loader:
            # Node ids of the drugs present in this batch (negative ids removed).
            subgraph_mask = find_subgraph_mask(ddi_graph, batch)
            # Use one sampling step that covers every batch drug at once,
            # matching the batched behavior introduced in train_epoch.
            subgraph_batch_size = len(subgraph_mask)
            # NOTE(review): shuffle=True makes evaluation sampling
            # nondeterministic; consider shuffle=False here -- confirm intent.
            neighbor_loader = NeighborLoader(ddi_graph,
                                             num_neighbors=[3, 2],
                                             batch_size=subgraph_batch_size,
                                             input_nodes=subgraph_mask,
                                             directed=False, shuffle=True)
            for subgraph in neighbor_loader:
                loss = step_batch(model, batch, loss_func, subgraph, gpu_id,
                                  train=False)
                epoch_loss += loss.item()
    return epoch_loss


@@ -110,13 +132,16 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch
# print('this is in train: ', ddi_graph)
min_loss = float('inf')
angry = 0
for epoch in range(1, n_epoch + 1):
start_time = time.time()
print("epoch: ",str(epoch))
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
print("train loss: ", trn_loss)
trn_loss /= train_loader.dataset_len
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
# print("train loss: ", trn_loss)
val_loss = eval_epoch(model, valid_loader, loss_func, ddi_graph, gpu_id)
val_loss /= valid_loader.dataset_len
# print("val loss: ", val_loss)
print("loss epoch " + str(epoch) + ": " + str(trn_loss))
if val_loss < min_loss:
angry = 0
@@ -129,6 +154,8 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch
angry += 1
if angry >= patience:
break
end_time = time.time()
print("epoch time (min): ", (end_time - start_time)/60)
if sl:
model.load_state_dict(torch.load(find_best_model(mdl_dir)))
return min_loss

+ 23
- 5
predictor/model/models.py View File

@@ -25,15 +25,27 @@ class Connector(nn.Module):

def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
if self.gcn == None:
print("here is for initializing the GCN. num_features: ", subgraph.num_features)
# print("here is for initializing the GCN. num_features: ", subgraph.num_features)
self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)
print("this is subgraph: --------------")
print(subgraph)
# print("this is subgraph in connector model forward: --------------")
# print(subgraph)
# graph.get().x --> DDInteractionDataset
# subgraph = graph.get() --> Data
x = subgraph.x
edge_index = subgraph.edge_index
x = self.gcn(x, edge_index)
# print("node local indices:")
node_indices = edge_index.flatten().unique()
# print(node_indices)

# print("-----------------------")
# print("node global indices:")
node_indices = subgraph.n_id
if self.gpu_id is not None:
node_indices = node_indices.cuda(self.gpu_id)
# print(node_indices)
drug1_idx = torch.flatten(drug1_idx)
drug2_idx = torch.flatten(drug2_idx)
#drug1_feat = x[drug1_idx]
@@ -41,9 +53,15 @@ class Connector(nn.Module):
drug1_feat = torch.empty((len(drug1_idx), len(x[0])))
drug2_feat = torch.empty((len(drug2_idx), len(x[0])))
for index, element in enumerate(drug1_idx):
drug1_feat[index] = (x[element])
x_element = element
if element >= 0:
x_element = (node_indices == element).nonzero().squeeze()
drug1_feat[index] = (x[x_element])
for index, element in enumerate(drug2_idx):
drug2_feat[index] = (x[element])
x_element = element
if element >= 0:
x_element = (node_indices == element).nonzero().squeeze()
drug2_feat[index] = (x[x_element])
if self.gpu_id is not None:
drug1_feat = drug1_feat.cuda(self.gpu_id)
drug2_feat = drug2_feat.cuda(self.gpu_id)

Loading…
Cancel
Save