| @@ -15,6 +15,7 @@ from model.utils import get_FP_by_negative_index | |||
| class Connector(nn.Module): | |||
| def __init__(self, gpu_id=None): | |||
| self.gpu_id = gpu_id | |||
| super(Connector, self).__init__() | |||
| # self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id) | |||
| self.gcn = None | |||
| @@ -30,14 +31,23 @@ class Connector(nn.Module): | |||
| x = self.gcn(x, edge_index) | |||
| drug1_idx = torch.flatten(drug1_idx) | |||
| drug2_idx = torch.flatten(drug2_idx) | |||
| drug1_feat = x[drug1_idx] | |||
| drug2_feat = x[drug2_idx] | |||
| for i, idx in enumerate(drug1_idx): | |||
| if idx < 0: | |||
| drug1_feat[i] = get_FP_by_negative_index(idx) | |||
| for i, idx in enumerate(drug2_idx): | |||
| if idx < 0: | |||
| drug2_feat[i] = get_FP_by_negative_index(idx) | |||
| #drug1_feat = x[drug1_idx] | |||
| #drug2_feat = x[drug2_idx] | |||
| drug1_feat = torch.empty((len(drug1_idx), len(x[0]))) | |||
| drug2_feat = torch.empty((len(drug2_idx), len(x[0]))) | |||
| for index, element in enumerate(drug1_idx): | |||
| drug1_feat[index] = (x[element]) | |||
| for index, element in enumerate(drug2_idx): | |||
| drug2_feat[index] = (x[element]) | |||
| if self.gpu_id is not None: | |||
| drug1_feat = drug1_feat.cuda(self.gpu_id) | |||
| drug2_feat = drug2_feat.cuda(self.gpu_id) | |||
| for i, x in enumerate(drug1_idx): | |||
| if x < 0: | |||
| drug1_feat[i] = get_FP_by_negative_index(x) | |||
| for i, x in enumerate(drug2_idx): | |||
| if x < 0: | |||
| drug2_feat[i] = get_FP_by_negative_index(x) | |||
| feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | |||
| return feat | |||