| @@ -15,6 +15,7 @@ from model.utils import get_FP_by_negative_index | |||
| class Connector(nn.Module): | |||
| def __init__(self, gpu_id=None): | |||
| self.gpu_id = gpu_id | |||
| super(Connector, self).__init__() | |||
| self.ddiDataset = DDInteractionDataset() | |||
| @@ -29,8 +30,17 @@ class Connector(nn.Module): | |||
| x = self.gcn(x, edge_index) | |||
| drug1_idx = torch.flatten(drug1_idx) | |||
| drug2_idx = torch.flatten(drug2_idx) | |||
| drug1_feat = x[drug1_idx] | |||
| drug2_feat = x[drug2_idx] | |||
| #drug1_feat = x[drug1_idx] | |||
| #drug2_feat = x[drug2_idx] | |||
| drug1_feat = torch.empty((len(drug1_idx), len(x[0]))) | |||
| drug2_feat = torch.empty((len(drug2_idx), len(x[0]))) | |||
| for index, element in enumerate(drug1_idx): | |||
| drug1_feat[index] = (x[element]) | |||
| for index, element in enumerate(drug2_idx): | |||
| drug2_feat[index] = (x[element]) | |||
| if self.gpu_id is not None: | |||
| drug1_feat = drug1_feat.cuda(self.gpu_id) | |||
| drug2_feat = drug2_feat.cuda(self.gpu_id) | |||
| for i, x in enumerate(drug1_idx): | |||
| if x < 0: | |||
| drug1_feat[i] = get_FP_by_negative_index(x) | |||