| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 | import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys 
PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
sys.path.insert(0, PROJ_DIR)
from drug.models import GCN
from drug.datasets import DDInteractionDataset
from model.utils import get_FP_by_negative_index
class Connector(nn.Module):
    """Build the joint input vector for a (drug1, drug2, cell line) batch.

    Drugs with a non-negative index are looked up as nodes of the given
    interaction subgraph and represented by GCN embeddings of that subgraph.
    Drugs with a negative index are represented by whatever
    ``get_FP_by_negative_index`` returns for that index.

    Args:
        gpu_id: CUDA device ordinal or ``None``. Kept for backward
            compatibility. ``forward`` places its result on ``cell_feat``'s
            device, which is the only device ``torch.cat`` accepts anyway.
        in_features: node-feature width of the subgraphs that will be passed
            to ``forward``. When given, the GCN is built here, so its
            parameters exist before an optimizer is created or a state dict
            is loaded. When ``None`` (the original behaviour), the GCN is
            built lazily on the first ``forward`` call. Any optimizer built
            before that call will NOT see the GCN parameters.
    """

    def __init__(self, gpu_id=None, in_features=None):
        # nn.Module.__init__ must run before any attribute is assigned on self.
        super(Connector, self).__init__()
        self.gpu_id = gpu_id
        self.gcn = None
        if in_features is not None:
            self.gcn = GCN(in_features, in_features // 2)

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        """Return the concatenated drug/drug/cell features for a batch.

        Args:
            drug1_idx, drug2_idx: tensors of drug indices; they are flattened
                before use. Non-negative entries are looked up in
                ``subgraph.n_id``. Negative entries are encoded with
                ``get_FP_by_negative_index``.
            cell_feat: 2-D tensor appended along dim 1 after the two drug
                blocks. Its device is the device of the returned tensor.
            subgraph: graph object providing ``x``, ``edge_index``, ``n_id``
                and ``num_features``.

        Returns:
            ``torch.cat([drug1_feat, drug2_feat, cell_feat], 1)``.

        Raises:
            ValueError: a non-negative drug index does not occur in
                ``subgraph.n_id``.
        """
        if self.gcn is None:
            # Lazy construction. Put the weights where the node features
            # live; otherwise a GPU subgraph would hit CPU weights.
            self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2).to(
                subgraph.x.device
            )
        x = self.gcn(subgraph.x, subgraph.edge_index)

        # NOTE(review): assumes subgraph.n_id[k] is the global id of the node
        # in row k of subgraph.x (the original code relies on the same thing).
        node_ids = subgraph.n_id.to(x.device)

        out_device = cell_feat.device
        drug1_feat = self._drug_features(drug1_idx, x, node_ids, out_device)
        drug2_feat = self._drug_features(drug2_idx, x, node_ids, out_device)
        return torch.cat([drug1_feat, drug2_feat, cell_feat], 1)

    @staticmethod
    def _drug_features(drug_idx, x, node_ids, device):
        """Return a ``(numel(drug_idx), x.size(1))`` feature matrix on ``device``."""
        flat_idx = torch.flatten(drug_idx)
        idx = flat_idx.to(x.device)
        in_graph = idx >= 0

        # matches[i, k] is True when drug i is node k of the subgraph. This is
        # one vectorized comparison instead of a per-drug nonzero() loop
        # (the loop forced a device sync per element).
        matches = (idx.unsqueeze(1) == node_ids.unsqueeze(0)) & in_graph.unsqueeze(1)
        missing = in_graph & ~matches.any(dim=1)
        if missing.any():
            raise ValueError(
                f"drug indices {idx[missing].tolist()} are not nodes of the subgraph"
            )
        rows, cols = matches.nonzero(as_tuple=True)

        feat = x.new_zeros((idx.numel(), x.size(1)))
        # Index assignment keeps the autograd link back to the GCN output.
        feat[rows] = x[cols]
        feat = feat.to(device)

        # Negative indices are not graph nodes; fill their rows with the
        # fingerprint provided by the project helper.
        for i, element in enumerate(flat_idx):
            if element < 0:
                feat[i] = get_FP_by_negative_index(element)
        return feat
class MLP(nn.Module):
    """Regression head that scores a (drug1, drug2, cell line) batch.

    A ``Connector`` produces the concatenated feature vector. A three-layer
    perceptron (Linear -> ReLU -> BatchNorm1d, twice, then a final Linear)
    maps it to one output per sample.

    Args:
        input_size: width of the vector produced by the connector.
        hidden_size: width of the first hidden layer; the second is half of it.
        gpu_id: forwarded unchanged to ``Connector``.
    """

    def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
        super().__init__()
        second_width = hidden_size // 2
        # Built in the same order as the layer stack reads, so parameter
        # initialisation consumes the RNG identically and Sequential indices
        # 0..6 stay the same.
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, second_width),
            nn.ReLU(),
            nn.BatchNorm1d(second_width),
            nn.Linear(second_width, 1),
        ]
        self.layers = nn.Sequential(*stack)
        # Registered after `layers` to keep parameter/state_dict ordering.
        self.connector = Connector(gpu_id)

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        """Return the perceptron output for the connector's joint features."""
        joint = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
        return self.layers(joint)
 |