|
|
@@ -15,6 +15,7 @@ from model.utils import get_FP_by_negative_index |
|
|
|
|
|
|
|
class Connector(nn.Module): |
|
|
|
def __init__(self, gpu_id=None): |
|
|
|
self.gpu_id = gpu_id |
|
|
|
super(Connector, self).__init__() |
|
|
|
# self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id) |
|
|
|
self.gcn = None |
|
|
@@ -30,14 +31,23 @@ class Connector(nn.Module): |
|
|
|
x = self.gcn(x, edge_index) |
|
|
|
drug1_idx = torch.flatten(drug1_idx) |
|
|
|
drug2_idx = torch.flatten(drug2_idx) |
|
|
|
drug1_feat = x[drug1_idx] |
|
|
|
drug2_feat = x[drug2_idx] |
|
|
|
for i, idx in enumerate(drug1_idx): |
|
|
|
if idx < 0: |
|
|
|
drug1_feat[i] = get_FP_by_negative_index(idx) |
|
|
|
for i, idx in enumerate(drug2_idx): |
|
|
|
if idx < 0: |
|
|
|
drug2_feat[i] = get_FP_by_negative_index(idx) |
|
|
|
#drug1_feat = x[drug1_idx] |
|
|
|
#drug2_feat = x[drug2_idx] |
|
|
|
drug1_feat = torch.empty((len(drug1_idx), len(x[0]))) |
|
|
|
drug2_feat = torch.empty((len(drug2_idx), len(x[0]))) |
|
|
|
for index, element in enumerate(drug1_idx): |
|
|
|
drug1_feat[index] = (x[element]) |
|
|
|
for index, element in enumerate(drug2_idx): |
|
|
|
drug2_feat[index] = (x[element]) |
|
|
|
if self.gpu_id is not None: |
|
|
|
drug1_feat = drug1_feat.cuda(self.gpu_id) |
|
|
|
drug2_feat = drug2_feat.cuda(self.gpu_id) |
|
|
|
for i, x in enumerate(drug1_idx): |
|
|
|
if x < 0: |
|
|
|
drug1_feat[i] = get_FP_by_negative_index(x) |
|
|
|
for i, x in enumerate(drug2_idx): |
|
|
|
if x < 0: |
|
|
|
drug2_feat[i] = get_FP_by_negative_index(x) |
|
|
|
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) |
|
|
|
return feat |
|
|
|
|