Browse Source

resolve NeighborLoader exit with no error message

main
MahsaYazdani 1 year ago
parent
commit
3472adab5f
3 changed files with 26 additions and 10 deletions
  1. 18
    7
      predictor/cross_validation.py
  2. 2
    2
      predictor/model/datasets.py
  3. 6
    1
      predictor/model/models.py

+ 18
- 7
predictor/cross_validation.py View File

@@ -20,13 +20,14 @@ from torch_geometric.loader import NeighborLoader


def find_subgraph_mask(ddi_graph, batch):
print('----------------------------------')
# print('----------------------------------')
drug1_id, drug2_id, cell_feat, y_true = batch
drug1_id = torch.flatten(drug1_id)
drug2_id = torch.flatten(drug2_id)
drug_ids = torch.cat((drug1_id, drug2_id), 0)
drug_ids, count = drug_ids.unique(return_counts=True)
print(drug_ids)
# Removing negative ids
drug_ids = (drug_ids >= 0).nonzero(as_tuple=True)[0]
return drug_ids

@@ -62,26 +63,34 @@ def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True):


def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
print('this is in train epoch: ', ddi_graph)
# print('this is in train epoch: ', ddi_graph)
model.train()
epoch_loss = 0
for _, batch in enumerate(loader):
# TODO: for batch ... desired subgraph
subgraph_mask = find_subgraph_mask(ddi_graph, batch)
print(ddi_graph)
# print("subgraph_mask: ",subgraph_mask) ---> subgraph_mask is OK
# print("This is ddi_graph:")
# print(ddi_graph) ---> ddi_graph is OK
neighbor_loader = NeighborLoader(ddi_graph,
num_neighbors=[3,2], batch_size=100,
num_neighbors=[3,2], batch_size=10,
input_nodes=subgraph_mask,
directed=False, shuffle=True)

# sampled_data = next(iter(neighbor_loader))
# print("Sampled_data: ")
# print(sampled_data.batch_size)
for subgraph in neighbor_loader:
optimizer.zero_grad()
loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print("epoch_loss: ", epoch_loss)
return epoch_loss


@@ -98,11 +107,13 @@ def eval_epoch(model, loader, loss_func, gpu_id=None):
def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
ddi_graph,
sl=False, mdl_dir=None):
print('this is in train: ', ddi_graph)
# print('this is in train: ', ddi_graph)
min_loss = float('inf')
angry = 0
for epoch in range(1, n_epoch + 1):
print("epoch: ",str(epoch))
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
print("train loss: ", trn_loss)
trn_loss /= train_loader.dataset_len
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
val_loss /= valid_loader.dataset_len

+ 2
- 2
predictor/model/datasets.py View File

@@ -111,8 +111,8 @@ class FastSynergyDataset(Dataset):
# print('-----------------------------')
# print(self.samples)
# print('-----------------')
print(self.samples[0])
print(self.samples[i][0] and self.samples[i][1] for i in indices)
# print(self.samples[0])
# print(self.samples[i][0] and self.samples[i][1] for i in indices)
d1 = torch.cat([torch.unsqueeze(self.samples[i][0], 0) for i in indices], dim=0)
d2 = torch.cat([torch.unsqueeze(self.samples[i][1], 0) for i in indices], dim=0)
c = torch.cat([torch.unsqueeze(self.samples[i][2], 0) for i in indices], dim=0)

+ 6
- 1
predictor/model/models.py View File

@@ -25,8 +25,13 @@ class Connector(nn.Module):

def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
if self.gcn == None:
print("here is for initializing the GCN. num_features: ", subgraph.num_features)
self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)
x = subgraph.get().x
print("this is subgraph: --------------")
print(subgraph)
# graph.get().x --> DDInteractionDataset
# subgraph = graph.get() --> Data
x = subgraph.x
edge_index = subgraph.edge_index
x = self.gcn(x, edge_index)
drug1_idx = torch.flatten(drug1_idx)

Loading…
Cancel
Save