class DDInteractionDataset(Dataset): | class DDInteractionDataset(Dataset): | ||||
def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None): | |||||
def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None, gpu_id=None): | |||||
self.gpu_id = gpu_id | |||||
super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "/drug/data/", transform, pre_transform, pre_filter) | super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "/drug/data/", transform, pre_transform, pre_filter) | ||||
@property | @property | ||||
def num_features(self): | def num_features(self): | ||||
return self._num_features | return self._num_features | ||||
# --------------------------------------------------------------- | # --------------------------------------------------------------- | ||||
data = Data(x = node_features, edge_index = edge_index) | data = Data(x = node_features, edge_index = edge_index) | ||||
if self.gpu_id is not None: | |||||
data = data.cuda(self.gpu_id) | |||||
if self.pre_filter is not None and not self.pre_filter(data): | if self.pre_filter is not None and not self.pre_filter(data): | ||||
pass | pass | ||||
data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt')) | data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt')) | ||||
return data | return data | ||||
ddiDataset = DDInteractionDataset(root = "drug/data/") | |||||
print(ddiDataset.get().edge_index.t()) | |||||
# run for checking | |||||
# ddiDataset = DDInteractionDataset(root = "drug/data/") | |||||
# print(ddiDataset.get().edge_index.t()) | |||||
# print(ddiDataset.get().x) | # print(ddiDataset.get().x) | ||||
print(ddiDataset.num_features) | |||||
# print(ddiDataset.num_features) |
# base from this notebook: https://colab.research.google.com/drive/1LJir3T6M6Omc2Vn2GV2cDW_GV2YfI53_?usp=sharing#scrollTo=jNsToorfSgS0 | # base from this notebook: https://colab.research.google.com/drive/1LJir3T6M6Omc2Vn2GV2cDW_GV2YfI53_?usp=sharing#scrollTo=jNsToorfSgS0 | ||||
class GCN(torch.nn.Module): | class GCN(torch.nn.Module): | ||||
def __init__(self, num_features, hidden_channels, gpu_id=None): # num_features = dataset.num_features | |||||
def __init__(self, num_features, hidden_channels): # num_features = dataset.num_features | |||||
super(GCN, self).__init__() | super(GCN, self).__init__() | ||||
torch.manual_seed(42) | torch.manual_seed(42) | ||||
# Initialize the layers | # Initialize the layers | ||||
self.conv1 = GCNConv(num_features, hidden_channels) | self.conv1 = GCNConv(num_features, hidden_channels) | ||||
self.conv2 = GCNConv(hidden_channels, num_features) | self.conv2 = GCNConv(hidden_channels, num_features) | ||||
self.gpu_id = gpu_id | |||||
def forward(self, x, edge_index): | def forward(self, x, edge_index): | ||||
# First Message Passing Layer (Transformation) | # First Message Passing Layer (Transformation) | ||||
x = x.to(torch.float32) | x = x.to(torch.float32) | ||||
if self.gpu_id is not None: | |||||
x = x.cuda(self.gpu_id) | |||||
edge_index = edge_index.cuda(self.gpu_id) | |||||
x = self.conv1(x, edge_index) | x = self.conv1(x, edge_index) | ||||
x = x.relu() | x = x.relu() | ||||
x = F.dropout(x, p=0.5, training=self.training) | x = F.dropout(x, p=0.5, training=self.training) |
drug1_id = drug1_id.cuda(gpu_id) | drug1_id = drug1_id.cuda(gpu_id) | ||||
drug2_id = drug2_id.cuda(gpu_id) | drug2_id = drug2_id.cuda(gpu_id) | ||||
cell_feat = cell_feat.cuda(gpu_id) | cell_feat = cell_feat.cuda(gpu_id) | ||||
pass | |||||
y_true = y_true.cuda(gpu_id) | |||||
if train: | if train: | ||||
y_pred = model(drug1_id, drug2_id, cell_feat) | y_pred = model(drug1_id, drug2_id, cell_feat) | ||||
else: | else: | ||||
logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss)) | logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss)) | ||||
logging.info("-" * n_delimiter) | logging.info("-" * n_delimiter) | ||||
losses.append(inner_loss) | losses.append(inner_loss) | ||||
torch.cuda.memory_summary(device=None, abbreviated=False) | |||||
gc.collect() | gc.collect() | ||||
torch.cuda.empty_cache() | torch.cuda.empty_cache() | ||||
time.sleep(10) | time.sleep(10) |
class Connector(nn.Module): | class Connector(nn.Module): | ||||
def __init__(self, gpu_id=None): | def __init__(self, gpu_id=None): | ||||
super(Connector, self).__init__() | super(Connector, self).__init__() | ||||
self.ddiDataset = DDInteractionDataset() | |||||
self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2, gpu_id) | |||||
self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id) | |||||
self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2) | |||||
#Cell line features | #Cell line features | ||||
# np.load('cell_feat.npy') | # np.load('cell_feat.npy') |