Browse Source

Solve big graph problem

main
MahsaYazdani 1 year ago
parent
commit
ad923450e8
3 changed files with 82 additions and 29 deletions:
  1. drug/datasets.py (+24, −3)
  2. predictor/cross_validation.py (+42, −13)
  3. predictor/model/models.py (+16, −13)

drug/datasets.py (+24, −3) — View File

@@ -2,6 +2,8 @@ import os.path as osp
import pandas as pd
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.loader import NeighborLoader
import numpy as np
import random
import os
@@ -28,7 +30,7 @@ class DDInteractionDataset(Dataset):

@property
def processed_file_names(self):
    """File name(s) expected in `processed_dir`; when present, processing is skipped.

    Fix: the diff residue left two consecutive `return` statements (the old
    'ddi_processed.pt' line made the new one unreachable); only the
    post-commit file name is kept.
    """
    return ['ddi_graph_dataset.pt']
@property
def raw_dir(self):
@@ -111,7 +113,26 @@ class DDInteractionDataset(Dataset):
return data

# Smoke test / debugging helper.
# Fix: the dataset build and print previously ran at module import time,
# so every importer (e.g. predictor.cross_validation) paid the side effect;
# a duplicated `num_features` comment line is also removed.
if __name__ == "__main__":
    ddiDataset = DDInteractionDataset(root="drug/data/")
    # ddiDataset.get() returns a torch_geometric.data.data.Data object
    print(ddiDataset.get())
    # print(ddiDataset.get().edge_index.t())
    # print(ddiDataset.get().x)
    # print(ddiDataset.num_features)

    # test for data batch loading
    # dataloader = DataLoader(ddiDataset, batch_size=10)
    # for data in dataloader:
    #     print(data)

    # true batching way for the knowledge graph:
    # data = ddiDataset.get()
    # neighbor_loader = NeighborLoader(data, num_neighbors=[3, 2],
    #                                  batch_size=100, directed=False,
    #                                  shuffle=True)
    # for batch in neighbor_loader:
    #     print(batch.x, batch.edge_index)

predictor/cross_validation.py (+42, −13) — View File

@@ -13,8 +13,21 @@ time_str = str(datetime.now().strftime('%y%m%d%H%M'))

from model.datasets import FastSynergyDataset, FastTensorDataLoader
from model.models import MLP
from drug.datasets import DDInteractionDataset
from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE
from torch_geometric.loader import NeighborLoader


def find_subgraph_mask(ddi_graph, batch):
    """Return the sorted unique drug node ids occurring in a synergy batch.

    These ids seed NeighborLoader sampling over the DDI graph in
    `train_epoch`.  `ddi_graph` is currently unused but kept so the caller
    interface stays stable.

    Fixes: removed leftover debug prints and the unused `count` from
    `unique(return_counts=True)`.
    NOTE(review): negative ids (fingerprint-only drugs) are not filtered
    here — confirm NeighborLoader tolerates them as input nodes.
    """
    drug1_id, drug2_id, cell_feat, y_true = batch
    seed_ids = torch.cat((torch.flatten(drug1_id), torch.flatten(drug2_id)), 0)
    return seed_ids.unique()

def eval_model(model, optimizer, loss_func, train_data, test_data,
@@ -30,7 +43,7 @@ def eval_model(model, optimizer, loss_func, train_data, test_data,
return test_loss


def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True):
    """Run the model on one batch and return the loss tensor.

    Fix: the diff residue left both the old (no-subgraph) and new signatures
    and both sets of model calls in place; this is the post-commit version.

    Args:
        model: callable taking (drug1_id, drug2_id, cell_feat, subgraph).
        batch: tuple (drug1_id, drug2_id, cell_feat, y_true).
        loss_func: callable (y_pred, y_true) -> loss tensor.
        subgraph: sampled DDI subgraph forwarded to the model.
        gpu_id: optional CUDA device id; batch tensors are moved there.
        train: when False, predictions for both drug orderings are averaged
            so the score is symmetric in the drug pair.
    """
    drug1_id, drug2_id, cell_feat, y_true = batch
    if gpu_id is not None:
        drug1_id = drug1_id.cuda(gpu_id)
        drug2_id = drug2_id.cuda(gpu_id)
        cell_feat = cell_feat.cuda(gpu_id)
        y_true = y_true.cuda(gpu_id)

    if train:
        y_pred = model(drug1_id, drug2_id, cell_feat, subgraph)
    else:
        yp1 = model(drug1_id, drug2_id, cell_feat, subgraph)
        yp2 = model(drug2_id, drug1_id, cell_feat, subgraph)
        y_pred = (yp1 + yp2) / 2
    loss = loss_func(y_pred, y_true)
    return loss


def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
    """Train for one epoch; return the summed loss over all optimizer steps.

    For every synergy batch, the unique drug ids in the batch seed a
    NeighborLoader over the full DDI graph, and one optimizer step is taken
    per sampled subgraph.

    Fixes: the diff residue left the old per-batch step loop in place
    alongside the new subgraph loop (double-stepping each batch); debug
    prints removed.
    NOTE(review): epoch_loss accumulates the same batch's loss once per
    sampled subgraph, while the caller normalizes by dataset_len — confirm
    this is the intended scaling.
    """
    model.train()
    epoch_loss = 0
    for batch in loader:
        # Seed nodes: the unique drug ids present in this batch.
        subgraph_mask = find_subgraph_mask(ddi_graph, batch)
        neighbor_loader = NeighborLoader(ddi_graph,
                                         num_neighbors=[3, 2], batch_size=100,
                                         input_nodes=subgraph_mask,
                                         directed=False, shuffle=True)
        for subgraph in neighbor_loader:
            optimizer.zero_grad()
            loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
    return epoch_loss


@@ -71,11 +96,13 @@ def eval_epoch(model, loader, loss_func, gpu_id=None):


def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
ddi_graph,
sl=False, mdl_dir=None):
print('this is in train: ', ddi_graph)
min_loss = float('inf')
angry = 0
for epoch in range(1, n_epoch + 1):
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id)
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
trn_loss /= train_loader.dataset_len
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
val_loss /= valid_loader.dataset_len
@@ -97,7 +124,6 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch


def create_model(data, hidden_size, gpu_id=None):
# TODO: use our own MLP model
# get 256
model = MLP(data.cell_feat_len() + 2 * 256, hidden_size, gpu_id)
if gpu_id is not None:
@@ -118,6 +144,9 @@ def cv(args, out_dir):
else:
gpu_id = None

ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
ddi_graph = ddiDataset.get()

n_folds = 5
n_delimiter = 60
loss_func = nn.MSELoss(reduction='sum')
@@ -148,7 +177,7 @@ def cv(args, out_dir):
logging.info(
"Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
args.epoch, args.patience, gpu_id, sl=False)
args.epoch, args.patience, gpu_id, ddi_graph = ddi_graph, sl=False, )
ret_vals.append(ret)
del model


predictor/model/models.py (+16, −13) — View File

@@ -16,26 +16,28 @@ from model.utils import get_FP_by_negative_index
class Connector(nn.Module):
    """Builds the joint (drug1, drug2, cell) feature vector consumed by MLP.

    Drug features are produced by a GCN over the sampled DDI subgraph; drugs
    with a negative index have no graph node and fall back to fingerprint
    features via `get_FP_by_negative_index`.
    """

    def __init__(self, gpu_id=None):
        super(Connector, self).__init__()
        # Built lazily on the first forward pass, once the subgraph's
        # feature width is known.
        # NOTE(review): a module created after optimizer construction is not
        # registered with the optimizer — confirm the training loop handles
        # the GCN's parameters.
        self.gcn = None

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        if self.gcn is None:
            self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)
        # Bug fix: `subgraph` is a Data object, not a Dataset — it has no
        # .get() (the committed code read `subgraph.get().x` while already
        # using `subgraph.edge_index` directly on the next line).
        x = self.gcn(subgraph.x, subgraph.edge_index)
        drug1_idx = torch.flatten(drug1_idx)
        drug2_idx = torch.flatten(drug2_idx)
        drug1_feat = x[drug1_idx]
        drug2_feat = x[drug2_idx]
        # Negative indices denote drugs absent from the graph; replace their
        # rows with fingerprint-derived features.
        for i, idx in enumerate(drug1_idx):
            if idx < 0:
                drug1_feat[i] = get_FP_by_negative_index(idx)
        for i, idx in enumerate(drug2_idx):
            if idx < 0:
                drug2_feat[i] = get_FP_by_negative_index(idx)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        return feat

@@ -55,8 +57,9 @@ class MLP(nn.Module):

self.connector = Connector(gpu_id)
def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
    """Predict synergy from drug node indices, cell features and the
    batch's sampled DDI subgraph.

    Fix: the diff residue left both the old (no-subgraph) and new
    signatures in place; this is the post-commit version.  The Connector
    assembles the joint feature vector, which the MLP layers score.
    """
    feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
    out = self.layers(feat)
    return out


Loading…
Cancel
Save