| from model import DeepTraCDR, ModelOptimizer | |||||
| from data_sampler import NewSampler | |||||
| import torch | |||||
| def run_DeepTraCDR(cell_exprs, drug_finger, res_mat, null_mask, target_dim, target_index, | |||||
| evaluate_fun, args): | |||||
| # Convert to numpy if needed | |||||
| if isinstance(res_mat, torch.Tensor): | |||||
| res_mat = res_mat.numpy() | |||||
| if isinstance(null_mask, torch.Tensor): | |||||
| null_mask = null_mask.numpy() | |||||
| # Initialize sampler with required parameters | |||||
| sampler = NewSampler(res_mat, null_mask, target_dim, target_index) | |||||
| # Initialize model with correct device parameter name | |||||
| model = DeepTraCDR( | |||||
| adj_mat=sampler.train_data, | |||||
| cell_exprs=cell_exprs, | |||||
| drug_fingerprints=drug_finger, | |||||
| layer_size=args.layer_size, | |||||
| gamma=args.gamma, | |||||
| device=args.device | |||||
| ) | |||||
| # Initialize optimizer with correct parameter order | |||||
| opt = ModelOptimizer( | |||||
| model=model, | |||||
| train_data=sampler.train_data, | |||||
| test_data=sampler.test_data, | |||||
| test_mask=sampler.test_mask, | |||||
| train_mask=sampler.train_mask, | |||||
| adj_matrix=res_mat, | |||||
| evaluate_fun=evaluate_fun, | |||||
| lr=args.lr, | |||||
| wd=args.wd, | |||||
| epochs=args.epochs, | |||||
| test_freq=args.test_freq if hasattr(args, 'test_freq') else 20, | |||||
| device=args.device | |||||
| ) | |||||
| # Call train method and unpack results | |||||
| results = opt.train() | |||||
| return results[0], results[1] # true_data, predict_data, auc_data |