| 
				
			 | 
			
			 | 
			@@ -13,8 +13,21 @@ time_str = str(datetime.now().strftime('%y%m%d%H%M')) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from model.datasets import FastSynergyDataset, FastTensorDataLoader | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from model.models import MLP | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from drug.datasets import DDInteractionDataset | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			from torch_geometric.loader import NeighborLoader | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def find_subgraph_mask(ddi_graph, batch): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    print('----------------------------------') | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    drug1_id, drug2_id, cell_feat, y_true = batch | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    drug1_id = torch.flatten(drug1_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    drug2_id = torch.flatten(drug2_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    drug_ids = torch.cat((drug1_id, drug2_id), 0) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    drug_ids, count = drug_ids.unique(return_counts=True) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    print(drug_ids) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    return drug_ids | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def eval_model(model, optimizer, loss_func, train_data, test_data, | 
		
		
	
	
		
			
			| 
				
			 | 
			
			 | 
			@@ -30,7 +43,7 @@ def eval_model(model, optimizer, loss_func, train_data, test_data, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    return test_loss | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def step_batch(model, batch, loss_func, gpu_id=None, train=True): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def step_batch(model, batch, loss_func, subgraph, gpu_id=None, train=True): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    drug1_id, drug2_id, cell_feat, y_true = batch | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    if gpu_id is not None: | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        drug1_id = drug1_id.cuda(gpu_id) | 
		
		
	
	
		
			
			| 
				
			 | 
			
			 | 
			@@ -39,24 +52,36 @@ def step_batch(model, batch, loss_func, gpu_id=None, train=True): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        y_true = y_true.cuda(gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    if train: | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        y_pred = model(drug1_id, drug2_id, cell_feat) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        y_pred = model(drug1_id, drug2_id, cell_feat, subgraph) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    else: | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        yp1 = model(drug1_id, drug2_id, cell_feat) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        yp2 = model(drug2_id, drug1_id, cell_feat) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        yp1 = model(drug1_id, drug2_id, cell_feat, subgraph) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        yp2 = model(drug2_id, drug1_id, cell_feat, subgraph) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        y_pred = (yp1 + yp2) / 2 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    loss = loss_func(y_pred, y_true) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    return loss | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def train_epoch(model, loader, loss_func, optimizer, gpu_id=None): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    print('this is in train epoch: ', ddi_graph) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    model.train() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    epoch_loss = 0 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			     | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    for _, batch in enumerate(loader): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        optimizer.zero_grad() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        loss = step_batch(model, batch, loss_func, gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        loss.backward() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        optimizer.step() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        epoch_loss += loss.item() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        # TODO: for batch ... desired subgraph | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        subgraph_mask = find_subgraph_mask(ddi_graph, batch) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        print(ddi_graph) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        neighbor_loader  = NeighborLoader(ddi_graph, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                              num_neighbors=[3,2], batch_size=100, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                              input_nodes=subgraph_mask, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                              directed=False, shuffle=True) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			         | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        for subgraph in neighbor_loader: | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            optimizer.zero_grad() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss = step_batch(model, batch, loss_func, subgraph, gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            loss.backward() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            optimizer.step() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			            epoch_loss += loss.item() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			             | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    return epoch_loss | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
	
		
			
			| 
				
			 | 
			
			 | 
			@@ -71,11 +96,13 @@ def eval_epoch(model, loader, loss_func, gpu_id=None): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                ddi_graph, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                sl=False, mdl_dir=None): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    print('this is in train: ', ddi_graph) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    min_loss = float('inf') | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    angry = 0 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    for epoch in range(1, n_epoch + 1): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        trn_loss = train_epoch(model, train_loader, loss_func, optimizer, ddi_graph, gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        trn_loss /= train_loader.dataset_len | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        val_loss /= valid_loader.dataset_len | 
		
		
	
	
		
			
			| 
				
			 | 
			
			 | 
			@@ -97,7 +124,6 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			def create_model(data, hidden_size, gpu_id=None): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    # TODO: use our own MLP model | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    # get 256 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    model = MLP(data.cell_feat_len() + 2 * 256, hidden_size, gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    if gpu_id is not None: | 
		
		
	
	
		
			
			| 
				
			 | 
			
			 | 
			@@ -118,6 +144,9 @@ def cv(args, out_dir): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    else: | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			        gpu_id = None | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    ddiDataset = DDInteractionDataset(gpu_id = gpu_id) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    ddi_graph = ddiDataset.get() | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    n_folds = 5 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    n_delimiter = 60 | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			    loss_func = nn.MSELoss(reduction='sum') | 
		
		
	
	
		
			
			| 
				
			 | 
			
			 | 
			@@ -148,7 +177,7 @@ def cv(args, out_dir): | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                    logging.info( | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                        "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds)) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                    ret = train_model(model, optimizer, loss_func, train_loader, valid_loader, | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                                      args.epoch, args.patience, gpu_id, sl=False) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                                      args.epoch, args.patience, gpu_id, ddi_graph = ddi_graph, sl=False, ) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                    ret_vals.append(ret) | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			                    del model | 
		
		
	
		
			
			 | 
			 | 
			
			 | 
			
  |