Browse Source

Solve big graph problem

main
MahsaYazdani 1 year ago
parent
commit
ad923450e8
3 changed files with 82 additions and 29 deletions
  1. 24
    3
      drug/datasets.py
  2. 42
    13
      predictor/cross_validation.py
  3. 16
    13
      predictor/model/models.py

+ 24
- 3
drug/datasets.py View File

import pandas as pd import pandas as pd
import torch import torch
from torch_geometric.data import Dataset, Data from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.loader import NeighborLoader
import numpy as np import numpy as np
import random import random
import os import os


@property
def processed_file_names(self):
    """Name of the file the processed DDI graph is stored under."""
    return ['ddi_graph_dataset.pt']
@property
def raw_dir(self):
    # NOTE(review): `data` is not defined inside this method — presumably a
    # module-level constant (or a string path mangled by extraction); confirm
    # against the full drug/datasets.py before relying on this property.
    return data


# run for checking
# NOTE(review): module-level smoke-test code — this executes on every import
# of drug.datasets; consider guarding with `if __name__ == "__main__":`.
ddiDataset = DDInteractionDataset(root = "drug/data/")
print(ddiDataset.get()) # type of ddiDataset.get() is torch_geometric.data.data.Data
# print(ddiDataset.get().edge_index.t())
# print(ddiDataset.get().x)
# print(ddiDataset.num_features)

# test for data batch loading
# dataloader = DataLoader(ddiDataset, batch_size=10)
# for data in dataloader:
#     print(data)


# # true batching way for the knowledge graph
# data = ddiDataset.get()
# print(len(data))
# neighbor_loader = NeighborLoader(data,
#                                  num_neighbors=[3,2], batch_size=100,
#                                  directed=False, shuffle=True)

# for batch in neighbor_loader:
#     print(batch.x, batch.edge_index)

+ 42
- 13
predictor/cross_validation.py View File



from model.datasets import FastSynergyDataset, FastTensorDataLoader from model.datasets import FastSynergyDataset, FastTensorDataLoader
from model.models import MLP from model.models import MLP
from drug.datasets import DDInteractionDataset
from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE
from torch_geometric.loader import NeighborLoader


def find_subgraph_mask(ddi_graph, batch):
    """Collect the unique drug node ids referenced by one training batch.

    The returned ids seed a ``NeighborLoader`` over the drug-drug
    interaction graph (see ``train_epoch``).

    Args:
        ddi_graph: the full DDI graph. Unused here; kept so the signature
            stays stable for the caller.
        batch: tuple ``(drug1_id, drug2_id, cell_feat, y_true)`` where the
            two id entries are integer tensors of arbitrary shape.

    Returns:
        1-D tensor of sorted, unique drug ids appearing in the batch.
    """
    drug1_id, drug2_id, cell_feat, y_true = batch
    drug1_id = torch.flatten(drug1_id)
    drug2_id = torch.flatten(drug2_id)
    drug_ids = torch.cat((drug1_id, drug2_id), 0)
    # Only the id set is needed; the original also requested (and ignored)
    # per-id counts, and printed debug output — both removed.
    drug_ids = drug_ids.unique()
    return drug_ids


def eval_model(model, optimizer, loss_func, train_data, test_data, def eval_model(model, optimizer, loss_func, train_data, test_data,
return test_loss return test_loss




def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True):
    """Run the model on one synergy batch and return the loss.

    Args:
        model: predictor taking ``(drug1_id, drug2_id, cell_feat, subgraph)``.
        batch: tuple ``(drug1_id, drug2_id, cell_feat, y_true)``.
        loss_func: callable applied as ``loss_func(y_pred, y_true)``.
        subgraph: sampled DDI subgraph forwarded to the model.
        gpu_id: CUDA device index, or ``None`` to stay on CPU.
        train: when ``False``, predictions are averaged over both drug
            orderings so evaluation is symmetric in the drug pair.

    Returns:
        Scalar loss tensor.
    """
    drug1_id, drug2_id, cell_feat, y_true = batch
    if gpu_id is not None:
        drug1_id = drug1_id.cuda(gpu_id)
        # NOTE(review): drug2_id and cell_feat are not moved to the GPU in the
        # code visible here — confirm the collapsed diff context transfers
        # them, otherwise mixed-device inputs will fail at model() time.
        y_true = y_true.cuda(gpu_id)
    if train:
        y_pred = model(drug1_id, drug2_id, cell_feat, subgraph)
    else:
        # Symmetric evaluation: average over both (drug1, drug2) orderings.
        yp1 = model(drug1_id, drug2_id, cell_feat, subgraph)
        yp2 = model(drug2_id, drug1_id, cell_feat, subgraph)
        y_pred = (yp1 + yp2) / 2
    loss = loss_func(y_pred, y_true)
    return loss




def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
    """Train the model for one epoch, sampling a DDI subgraph per batch.

    For every synergy batch, the drug ids occurring in the batch seed a
    ``NeighborLoader`` over ``ddi_graph``; the model is stepped once per
    sampled subgraph.

    Args:
        model: predictor module (trained in place).
        loader: iterable of ``(drug1_id, drug2_id, cell_feat, y_true)`` batches.
        loss_func: loss callable passed through to ``step_batch``.
        optimizer: optimizer for ``model``'s parameters.
        ddi_graph: full drug-drug interaction graph to sample from.
        gpu_id: CUDA device index, or ``None`` for CPU.

    Returns:
        Sum of batch losses over the epoch (caller normalises by dataset size).
    """
    model.train()
    epoch_loss = 0
    for _, batch in enumerate(loader):
        # Seed nodes: unique drug ids appearing in this batch.
        subgraph_mask = find_subgraph_mask(ddi_graph, batch)
        neighbor_loader = NeighborLoader(ddi_graph,
                                         num_neighbors=[3, 2], batch_size=100,
                                         input_nodes=subgraph_mask,
                                         directed=False, shuffle=True)
        # NOTE(review): the same batch is re-stepped for every sampled
        # subgraph, so epoch_loss counts each batch len(neighbor_loader)
        # times — confirm this is intended before dividing by dataset_len.
        for subgraph in neighbor_loader:
            optimizer.zero_grad()
            loss = step_batch(model, batch, loss_func, subgraph, gpu_id)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
    return epoch_loss








def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
ddi_graph,
sl=False, mdl_dir=None): sl=False, mdl_dir=None):
print('this is in train: ', ddi_graph)
min_loss = float('inf') min_loss = float('inf')
angry = 0 angry = 0
for epoch in range(1, n_epoch + 1): for epoch in range(1, n_epoch + 1):
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id)
trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id)
trn_loss /= train_loader.dataset_len trn_loss /= train_loader.dataset_len
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
val_loss /= valid_loader.dataset_len val_loss /= valid_loader.dataset_len




def create_model(data, hidden_size, gpu_id=None): def create_model(data, hidden_size, gpu_id=None):
# TODO: use our own MLP model
# get 256 # get 256
model = MLP(data.cell_feat_len() + 2 * 256, hidden_size, gpu_id) model = MLP(data.cell_feat_len() + 2 * 256, hidden_size, gpu_id)
if gpu_id is not None: if gpu_id is not None:
else: else:
gpu_id = None gpu_id = None


ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
ddi_graph = ddiDataset.get()

n_folds = 5 n_folds = 5
n_delimiter = 60 n_delimiter = 60
loss_func = nn.MSELoss(reduction='sum') loss_func = nn.MSELoss(reduction='sum')
logging.info( logging.info(
"Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds)) "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
ret = train_model(model, optimizer, loss_func, train_loader, valid_loader, ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
args.epoch, args.patience, gpu_id, sl=False)
args.epoch, args.patience, gpu_id, ddi_graph = ddi_graph, sl=False, )
ret_vals.append(ret) ret_vals.append(ret)
del model del model



+ 16
- 13
predictor/model/models.py View File

class Connector(nn.Module):
    """Builds the joint (drug1, drug2, cell) feature vector for the predictor.

    Drug embeddings come from a GCN run over the batch's sampled DDI
    subgraph; drugs without a node in the graph (negative indices) fall
    back to fingerprint features via ``get_FP_by_negative_index``.
    """

    def __init__(self, gpu_id=None):
        super(Connector, self).__init__()
        # The GCN is created lazily on the first forward pass, once the
        # subgraph's node-feature dimensionality is known.
        self.gcn = None
        # Cell line features
        # np.load('cell_feat.npy')

    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        if self.gcn is None:  # was `== None`; identity test is the idiom
            self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)
        # BUG FIX: `subgraph` is a Data object yielded by NeighborLoader, not
        # a Dataset — it has no no-argument .get(). Access node features
        # directly, matching the edge_index access on the next line.
        x = subgraph.x
        edge_index = subgraph.edge_index
        x = self.gcn(x, edge_index)
        drug1_idx = torch.flatten(drug1_idx)
        drug2_idx = torch.flatten(drug2_idx)
        drug1_feat = x[drug1_idx]
        drug2_feat = x[drug2_idx]
        # Negative indices mark drugs absent from the DDI graph: replace the
        # (meaningless) embedding gathered above with fingerprint features.
        for i, idx in enumerate(drug1_idx):
            if idx < 0:
                drug1_feat[i] = get_FP_by_negative_index(idx)
        for i, idx in enumerate(drug2_idx):
            if idx < 0:
                drug2_feat[i] = get_FP_by_negative_index(idx)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        return feat




self.connector = Connector(gpu_id) self.connector = Connector(gpu_id)
# prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor,
# cell_feat: torch.Tensor; subgraph: related subgraph for the batch
def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
    """Assemble the joint feature vector via the connector, then score it."""
    combined = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
    return self.layers(combined)



Loading…
Cancel
Save