| @@ -0,0 +1,44 @@ | |||
| from model import DeepTraCDR, ModelOptimizer | |||
| from data_sampler import NewSampler | |||
| import torch | |||
| def run_DeepTraCDR(cell_exprs, drug_finger, res_mat, null_mask, target_dim, target_index, | |||
| evaluate_fun, args): | |||
| # Convert to numpy if needed | |||
| if isinstance(res_mat, torch.Tensor): | |||
| res_mat = res_mat.numpy() | |||
| if isinstance(null_mask, torch.Tensor): | |||
| null_mask = null_mask.numpy() | |||
| # Initialize sampler with required parameters | |||
| sampler = NewSampler(res_mat, null_mask, target_dim, target_index) | |||
| # Initialize model with correct device parameter name | |||
| model = DeepTraCDR( | |||
| adj_mat=sampler.train_data, | |||
| cell_exprs=cell_exprs, | |||
| drug_fingerprints=drug_finger, | |||
| layer_size=args.layer_size, | |||
| gamma=args.gamma, | |||
| device=args.device | |||
| ) | |||
| # Initialize optimizer with correct parameter order | |||
| opt = ModelOptimizer( | |||
| model=model, | |||
| train_data=sampler.train_data, | |||
| test_data=sampler.test_data, | |||
| test_mask=sampler.test_mask, | |||
| train_mask=sampler.train_mask, | |||
| adj_matrix=res_mat, | |||
| evaluate_fun=evaluate_fun, | |||
| lr=args.lr, | |||
| wd=args.wd, | |||
| epochs=args.epochs, | |||
| test_freq=args.test_freq if hasattr(args, 'test_freq') else 20, | |||
| device=args.device | |||
| ) | |||
| # Call train method and unpack results | |||
| results = opt.train() | |||
| return results[0], results[1] # true_data, predict_data, auc_data | |||