| class Connector(nn.Module): | class Connector(nn.Module): | ||||
| def __init__(self, gpu_id=None): | def __init__(self, gpu_id=None): | ||||
| self.gpu_id = gpu_id | |||||
| super(Connector, self).__init__() | super(Connector, self).__init__() | ||||
| self.ddiDataset = DDInteractionDataset() | self.ddiDataset = DDInteractionDataset() | ||||
| x = self.gcn(x, edge_index) | x = self.gcn(x, edge_index) | ||||
| drug1_idx = torch.flatten(drug1_idx) | drug1_idx = torch.flatten(drug1_idx) | ||||
| drug2_idx = torch.flatten(drug2_idx) | drug2_idx = torch.flatten(drug2_idx) | ||||
| drug1_feat = x[drug1_idx] | |||||
| drug2_feat = x[drug2_idx] | |||||
| #drug1_feat = x[drug1_idx] | |||||
| #drug2_feat = x[drug2_idx] | |||||
| drug1_feat = torch.empty((len(drug1_idx), len(x[0]))) | |||||
| drug2_feat = torch.empty((len(drug2_idx), len(x[0]))) | |||||
| for index, element in enumerate(drug1_idx): | |||||
| drug1_feat[index] = (x[element]) | |||||
| for index, element in enumerate(drug2_idx): | |||||
| drug2_feat[index] = (x[element]) | |||||
| if self.gpu_id is not None: | |||||
| drug1_feat = drug1_feat.cuda(self.gpu_id) | |||||
| drug2_feat = drug2_feat.cuda(self.gpu_id) | |||||
| for i, x in enumerate(drug1_idx): | for i, x in enumerate(drug1_idx): | ||||
| if x < 0: | if x < 0: | ||||
| drug1_feat[i] = get_FP_by_negative_index(x) | drug1_feat[i] = get_FP_by_negative_index(x) |