123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
- import argparse
- import os
- import logging
- import time
- import pickle
- import torch
- import torch.nn as nn
- import gc
-
- from datetime import datetime
-
# Timestamp (yymmddHHMM) captured once at import time; main() uses it as the
# default value of --suffix, which names the output directory.
time_str = datetime.now().strftime('%y%m%d%H%M')
-
- from model.datasets import FastSynergyDataset, FastTensorDataLoader
- from model.models import MLP
- from model.utils import save_args, save_best_model, find_best_model, arg_min, random_split_indices, calc_stat, conf_inv
- from const import SYNERGY_FILE, CELL_FEAT_FILE, CELL2ID_FILE, OUTPUT_DIR, DRUGNAME_2_DRUGBANKID_FILE, DTI_FEAT_FILE, FP_FEAT_FILE, DRUG2ID_FILE
-
-
def eval_model(model, optimizer, loss_func, train_data, test_data,
               batch_size, n_epoch, patience, gpu_id, mdl_dir):
    """Fit ``model`` with early stopping, then score it on ``test_data``.

    ``train_data`` is split with ``random_split_indices(..., test_rate=0.1)``;
    the first index set feeds the training loader and the second one the
    early-stopping loader. ``train_model`` is called with ``sl=True`` so the
    best checkpoint is written to ``mdl_dir`` and reloaded before testing.

    Returns:
        Tuple ``(test_loss, test_accuracy, test_recall)`` where ``test_loss``
        is the loss accumulated by ``eval_epoch`` divided by
        ``len(test_data)``.
    """
    fit_idx, stop_idx = random_split_indices(len(train_data), test_rate=0.1)
    fit_loader = FastTensorDataLoader(
        *train_data.tensor_samples(fit_idx), batch_size=batch_size, shuffle=True)
    # Validation and test sets are evaluated in four roughly equal batches.
    stop_loader = FastTensorDataLoader(
        *train_data.tensor_samples(stop_idx), batch_size=len(stop_idx) // 4)
    test_loader = FastTensorDataLoader(
        *test_data.tensor_samples(), batch_size=len(test_data) // 4)

    train_model(model, optimizer, loss_func, fit_loader, stop_loader,
                n_epoch, patience, gpu_id, sl=True, mdl_dir=mdl_dir)

    mean_test_loss = eval_epoch(model, test_loader, loss_func, gpu_id) / len(test_data)
    accuracy, recall = eval_acc(model, test_loader, loss_func, gpu_id, len(test_data))
    return mean_test_loss, accuracy, recall
-
-
def step_batch(model, batch, loss_func, gpu_id=None, train=True):
    """Run ``model`` on one batch and return ``loss_func(y_pred, y_true)``.

    Args:
        model: callable taking (drug1_idx, drug2_idx, drug1_fp, drug2_fp,
            drug1_dti, drug2_dti, cell_feat).
        batch: 8-tuple of tensors in the order above, followed by ``y_true``.
        loss_func: callable ``(y_pred, y_true) -> loss``.
        gpu_id: CUDA device index; when not None every tensor of the batch is
            moved to that device first.
        train: when True a single forward pass is used; otherwise the
            prediction is the mean of two passes, the second with the
            fingerprint and DTI features of the two drugs exchanged.

    Returns:
        The loss value produced by ``loss_func``.
    """
    tensors = tuple(batch)
    if gpu_id is not None:
        tensors = tuple(t.cuda(gpu_id) for t in tensors)
    (d1_idx, d2_idx, d1_fp, d2_fp,
     d1_dti, d2_dti, cell_feat, y_true) = tensors

    y_pred = model(d1_idx, d2_idx, d1_fp, d2_fp, d1_dti, d2_dti, cell_feat)
    if not train:
        # NOTE(review): only the fp/dti features are swapped here; the index
        # arguments keep their original order -- confirm this is intended.
        swapped = model(d1_idx, d2_idx, d2_fp, d1_fp, d2_dti, d1_dti, cell_feat)
        y_pred = (y_pred + swapped) / 2

    # A disabled experiment once weighted the loss here: samples whose
    # prediction and label fell on opposite sides of 10 were split out and
    # their loss multiplied by 5 before being added to the loss of the rest.
    return loss_func(y_pred, y_true)
-
-
def train_epoch(model, loader, loss_func, optimizer, gpu_id=None):
    """Run one optimisation pass over ``loader``.

    Puts ``model`` in training mode and, for every batch, zeroes the
    gradients, computes the loss with ``step_batch``, back-propagates and
    steps ``optimizer``.

    Returns:
        The sum of the per-batch loss values (not normalised here).
    """
    model.train()
    running_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        batch_loss = step_batch(model, batch, loss_func, gpu_id)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss
-
-
def eval_epoch(model, loader, loss_func, gpu_id=None):
    """Accumulate the evaluation loss of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled; each batch
    goes through ``step_batch`` with ``train=False``.

    Returns:
        The sum of the per-batch loss values (not normalised here).
    """
    model.eval()
    total = 0
    with torch.no_grad():
        for batch in loader:
            total += step_batch(model, batch, loss_func, gpu_id, train=False).item()
    return total
-
def step_acc(model, batch, loss_func, gpu_id=None, train=True, threshold=10):
    """Count true positives, true negatives and false negatives for a batch.

    Predictions and labels are binarised with the same ``threshold``: a
    sample is positive when its label is strictly above the threshold. A
    prediction is counted as correct when prediction and label lie strictly
    on the same side of the threshold (values exactly equal to the threshold
    never count as correct).

    Earlier revisions decided the positive class with ``y_true > 0`` while
    judging correctness against 10, which counted correctly predicted
    negatives with labels in (0, 10) as true positives and false positives
    with such labels as false negatives. TP + TN is unaffected by the fix.

    Args:
        model: callable taking (drug1_idx, drug2_idx, drug1_fp, drug2_fp,
            drug1_dti, drug2_dti, cell_feat).
        batch: 8-tuple of tensors in the order above, followed by ``y_true``.
        loss_func: unused; kept so existing callers keep working.
        gpu_id: CUDA device index; when not None every tensor of the batch is
            moved to that device first.
        train: when True a single forward pass is used; otherwise the
            prediction is the mean of two passes, the second with the
            fingerprint and DTI features of the two drugs exchanged.
        threshold: score above which a sample is considered positive.

    Returns:
        Tuple of Python ints ``(TP, TN, FN)``.
    """
    tensors = tuple(batch)
    if gpu_id is not None:
        tensors = tuple(t.cuda(gpu_id) for t in tensors)
    (d1_idx, d2_idx, d1_fp, d2_fp,
     d1_dti, d2_dti, cell_feat, y_true) = tensors

    y_pred = model(d1_idx, d2_idx, d1_fp, d2_fp, d1_dti, d2_dti, cell_feat)
    if not train:
        swapped = model(d1_idx, d2_idx, d2_fp, d1_fp, d2_dti, d1_dti, cell_feat)
        y_pred = (y_pred + swapped) / 2

    # Flatten both so a trailing singleton dimension on either tensor cannot
    # trigger broadcasting in the element-wise comparisons below.
    y_pred = y_pred.reshape(-1)
    y_true = y_true.reshape(-1)

    is_positive = y_true > threshold
    is_correct = (is_positive & (y_pred > threshold)) | \
                 ((y_true < threshold) & (y_pred < threshold))

    # One reduction per counter instead of a Python loop over tensor
    # elements, which forces a device sync for every comparison.
    TP = int((is_correct & is_positive).sum().item())
    TN = int((is_correct & ~is_positive).sum().item())
    FN = int((~is_correct & is_positive).sum().item())
    return TP, TN, FN
-
def eval_acc(model, loader, loss_func, gpu_id, test_len):
    """Compute accuracy and recall of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled. Per-batch
    (TP, TN, FN) counts from ``step_acc`` are summed over all batches.

    Args:
        model: network to evaluate.
        loader: iterable of batches accepted by ``step_acc``.
        loss_func: forwarded to ``step_acc``.
        gpu_id: CUDA device index or None.
        test_len: number of samples used as the accuracy denominator.

    Returns:
        Tuple ``(accuracy, recall)``. A ratio whose denominator is zero
        (``test_len == 0`` or no positive samples) is reported as 0.0
        instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    tp = tn = fn = 0
    with torch.no_grad():
        for batch in loader:
            batch_tp, batch_tn, batch_fn = step_acc(model, batch, loss_func, gpu_id, train=False)
            tp += batch_tp
            tn += batch_tn
            fn += batch_fn

    acc = (tp + tn) / test_len if test_len else 0.0
    n_positive = tp + fn
    recall = tp / n_positive if n_positive else 0.0
    return acc, recall
-
-
def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
                sl=False, mdl_dir=None):
    """Train ``model`` for up to ``n_epoch`` epochs with early stopping.

    After every epoch the validation loss (normalised by
    ``valid_loader.dataset_len``) is compared with the best value so far.
    Training stops once ``patience`` consecutive epochs bring no improvement.

    Args:
        sl: when True, the state dict is saved to ``mdl_dir`` on every
            improvement and the best checkpoint is loaded back into ``model``
            before returning.
        mdl_dir: checkpoint directory, only used when ``sl`` is True.

    Returns:
        The lowest per-sample validation loss observed.
    """
    best_val_loss = float('inf')
    epochs_without_gain = 0
    for epoch in range(1, n_epoch + 1):
        trn_loss = train_epoch(model, train_loader, loss_func, optimizer, gpu_id) \
            / train_loader.dataset_len
        val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) \
            / valid_loader.dataset_len
        print(f"loss epoch {epoch}: {trn_loss}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_gain = 0
            if sl:
                # Keep only the most recent best checkpoint on disk.
                save_best_model(model.state_dict(), mdl_dir, epoch, keep=1)
            continue

        epochs_without_gain += 1
        if epochs_without_gain >= patience:
            break

    if sl:
        model.load_state_dict(torch.load(find_best_model(mdl_dir)))
    return best_val_loss
-
-
def create_model(data, hidden_size, gpu_id=None):
    """Instantiate the MLP for ``data`` and optionally move it to a GPU.

    The input width is ``data.cell_feat_len()`` plus three fixed per-drug
    widths (40, 256 and 64), each counted twice.
    NOTE(review): presumably these are the drug embedding / fingerprint / DTI
    feature sizes for the two drugs -- confirm against the MLP definition.

    Returns:
        The model, on CUDA device ``gpu_id`` when it is not None.
    """
    # TODO: use our own MLP model
    input_size = data.cell_feat_len() + 2 * 40 + 2 * 256 + 2 * 64
    net = MLP(input_size, hidden_size, gpu_id)
    return net if gpu_id is None else net.cuda(gpu_id)
-
-
def _load_dataset(use_folds, **kwargs):
    """Build a FastSynergyDataset restricted to ``use_folds`` from the configured files."""
    return FastSynergyDataset(DRUG2ID_FILE, CELL2ID_FILE, CELL_FEAT_FILE,
                              SYNERGY_FILE, DTI_FEAT_FILE, FP_FEAT_FILE,
                              use_folds=use_folds, **kwargs)


def cv(args, out_dir):
    """Run 5-fold nested cross-validation and log/pickle the test results.

    For every outer test fold, each (hidden size, learning rate) pair from
    ``args.hidden`` x ``args.lr`` is scored by its mean validation loss over
    the inner folds. The best pair is then retrained on all outer training
    folds through ``eval_model`` and evaluated on the held-out fold.

    Args:
        args: parsed CLI namespace; reads ``gpu``, ``hidden``, ``lr``,
            ``batch``, ``epoch`` and ``patience``.
        out_dir: existing directory that receives ``args.json``,
            ``test_loss.pkl`` and one checkpoint sub-directory per test fold.

    Side effects:
        Writes the files above, prints CUDA availability to stdout and logs
        progress plus MSE/RMSE/accuracy/recall summaries via ``logging``.
    """
    gc.collect()

    save_args(args, os.path.join(out_dir, 'args.json'))
    test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
    cuda_ok = torch.cuda.is_available()
    print("cuda available: " + str(cuda_ok))

    # Fall back to CPU when CUDA is missing even if a device was requested.
    gpu_id = args.gpu if (cuda_ok and args.gpu is not None) else None

    n_folds = 5
    n_delimiter = 60
    # Summed (not averaged) loss; callers divide by the dataset size.
    loss_func = nn.MSELoss(reduction='sum')
    test_losses = []
    all_acc = []
    all_recall = []
    for test_fold in range(n_folds):
        outer_trn_folds = [x for x in range(n_folds) if x != test_fold]
        logging.info("Outer: train folds {}, test folds {}".format(outer_trn_folds, test_fold))
        logging.info("-" * n_delimiter)
        param = []
        losses = []
        for hs in args.hidden:
            for lr in args.lr:
                param.append((hs, lr))
                logging.info("Hidden size: {} | Learning rate: {}".format(hs, lr))
                ret_vals = []
                for valid_fold in outer_trn_folds:
                    inner_trn_folds = [x for x in outer_trn_folds if x != valid_fold]
                    valid_folds = [valid_fold]
                    train_data = _load_dataset(inner_trn_folds)
                    valid_data = _load_dataset(valid_folds, train=False)
                    train_loader = FastTensorDataLoader(*train_data.tensor_samples(),
                                                        batch_size=args.batch, shuffle=True)
                    valid_loader = FastTensorDataLoader(*valid_data.tensor_samples(),
                                                        batch_size=len(valid_data) // 4)
                    model = create_model(train_data, hs, gpu_id)
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                    logging.info(
                        "Start inner loop: train folds {}, valid folds {}".format(inner_trn_folds, valid_folds))
                    ret = train_model(model, optimizer, loss_func, train_loader, valid_loader,
                                      args.epoch, args.patience, gpu_id, sl=False)
                    ret_vals.append(ret)
                    # Drop the reference so the next candidate does not
                    # coexist with this one in memory.
                    del model

                inner_loss = sum(ret_vals) / len(ret_vals)
                logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
                logging.info("-" * n_delimiter)
                losses.append(inner_loss)

        # NOTE(review): cleanup is assumed to run once per outer fold, after
        # the whole hyper-parameter grid -- confirm against the upstream file.
        if cuda_ok:
            # The summary is only meaningful (and only safe to request) with
            # CUDA; emitted at DEBUG so it stays out of INFO-level logs.
            logging.debug(torch.cuda.memory_summary(device=None, abbreviated=False))
        gc.collect()
        torch.cuda.empty_cache()
        # NOTE(review): presumably gives the GPU time to release memory
        # before the final fit -- confirm the pause is still needed.
        time.sleep(10)

        _, min_idx = arg_min(losses)
        best_hs, best_lr = param[min_idx]
        train_data = _load_dataset(outer_trn_folds)
        test_data = _load_dataset([test_fold], train=False)
        model = create_model(train_data, best_hs, gpu_id)
        optimizer = torch.optim.Adam(model.parameters(), lr=best_lr)

        logging.info("Best hidden size: {} | Best learning rate: {}".format(best_hs, best_lr))
        logging.info("Start test on fold {}.".format([test_fold]))
        test_mdl_dir = os.path.join(out_dir, str(test_fold))
        os.makedirs(test_mdl_dir, exist_ok=True)
        test_loss, acc, recall = eval_model(model, optimizer, loss_func, train_data, test_data,
                                            args.batch, args.epoch, args.patience, gpu_id, test_mdl_dir)
        all_acc.append(acc)
        all_recall.append(recall)
        test_losses.append(test_loss)
        logging.info("Test loss: {:.4f}".format(test_loss))
        logging.info("Test accuracy: {:.4f}".format(acc))
        logging.info("Test recall: {:.4f}".format(recall))
        logging.info("*" * n_delimiter + '\n')

    logging.info("CV completed")
    with open(test_loss_file, 'wb') as f:
        pickle.dump(test_losses, f)
    mu, sigma = calc_stat(test_losses)
    logging.info("MSE: {:.4f} ± {:.4f}".format(mu, sigma))
    lo, hi = conf_inv(mu, sigma, len(test_losses))
    logging.info("Confidence interval: [{:.4f}, {:.4f}]".format(lo, hi))
    rmse_loss = [x ** 0.5 for x in test_losses]
    mu, sigma = calc_stat(rmse_loss)
    logging.info("RMSE: {:.4f} ± {:.4f}".format(mu, sigma))
    logging.info("average accuracy on test: {}".format(sum(all_acc) / len(all_acc)))
    # Pass the number itself; wrapping it in braces built a one-element set
    # and logged it as "{0.85}".
    logging.info("average recall on test: {}".format(sum(all_recall) / len(all_recall)))
-
-
def main():
    """Parse CLI options, prepare the output directory and log file, run ``cv``.

    Results go to ``OUTPUT_DIR/cv_<suffix>``; logging is written to
    ``cv.log`` inside that directory at INFO level.
    """
    cli_options = (
        ('--epoch', dict(type=int, default=500, help="n epoch")),
        ('--batch', dict(type=int, default=256, help="batch size")),
        ('--gpu', dict(type=int, default=None, help="cuda device")),
        ('--patience', dict(type=int, default=100, help='patience for early stop')),
        ('--suffix', dict(type=str, default=time_str, help="model dir suffix")),
        ('--hidden', dict(type=int, nargs='+', default=[2048, 4096, 8192], help="hidden size")),
        ('--lr', dict(type=float, nargs='+', default=[1e-3, 1e-4, 1e-5], help="learning rate")),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    out_dir = os.path.join(OUTPUT_DIR, f'cv_{args.suffix}')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    logging.basicConfig(
        filename=os.path.join(out_dir, 'cv.log'),
        format='%(asctime)s %(message)s',
        datefmt='[%Y-%m-%d %H:%M:%S]',
        level=logging.INFO,
    )
    cv(args, out_dir)


# Run the cross-validation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|