Browse Source

fix tensor indices bug

main
rezav 2 years ago
parent
commit
35f675ca33
1 changed files with 12 additions and 2 deletions
  1. 12
    2
      predictor/model/models.py

+ 12
- 2
predictor/model/models.py View File

@@ -15,6 +15,7 @@ from model.utils import get_FP_by_negative_index

class Connector(nn.Module):
def __init__(self, gpu_id=None):
self.gpu_id = gpu_id
super(Connector, self).__init__()

self.ddiDataset = DDInteractionDataset()
@@ -29,8 +30,17 @@ class Connector(nn.Module):
x = self.gcn(x, edge_index)
drug1_idx = torch.flatten(drug1_idx)
drug2_idx = torch.flatten(drug2_idx)
drug1_feat = x[drug1_idx]
drug2_feat = x[drug2_idx]
#drug1_feat = x[drug1_idx]
#drug2_feat = x[drug2_idx]
drug1_feat = torch.empty((len(drug1_idx), len(x[0])))
drug2_feat = torch.empty((len(drug2_idx), len(x[0])))
for index, element in enumerate(drug1_idx):
drug1_feat[index] = (x[element])
for index, element in enumerate(drug2_idx):
drug2_feat[index] = (x[element])
if self.gpu_id is not None:
drug1_feat = drug1_feat.cuda(self.gpu_id)
drug2_feat = drug2_feat.cuda(self.gpu_id)
for i, x in enumerate(drug1_idx):
if x < 0:
drug1_feat[i] = get_FP_by_negative_index(x)

Loading…
Cancel
Save