1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- from model import DeepTraCDR, ModelOptimizer
- from data_sampler import NewSampler
- import torch
def run_DeepTraCDR(cell_exprs, drug_finger, res_mat, null_mask, target_dim, target_index,
                   evaluate_fun, args):
    """Build, train and evaluate a DeepTraCDR model on one sampler split.

    ``res_mat`` and ``null_mask`` are converted to NumPy arrays when they are
    given as ``torch.Tensor`` objects. They are then passed to ``NewSampler``
    together with ``target_dim`` and ``target_index``. The sampler supplies
    the train/test data and masks. A ``DeepTraCDR`` model is built on
    ``sampler.train_data`` and trained by a ``ModelOptimizer``.

    Args:
        cell_exprs: Passed to ``DeepTraCDR`` as ``cell_exprs``.
        drug_finger: Passed to ``DeepTraCDR`` as ``drug_fingerprints``.
        res_mat: NumPy array or ``torch.Tensor``. It is passed to
            ``NewSampler`` and, in full, to ``ModelOptimizer`` as
            ``adj_matrix``. Presumably the drug-response matrix; confirm
            with callers.
        null_mask: NumPy array or ``torch.Tensor`` passed to ``NewSampler``.
        target_dim: Forwarded to ``NewSampler``.
        target_index: Forwarded to ``NewSampler``.
        evaluate_fun: Evaluation callback forwarded to ``ModelOptimizer``.
        args: Config object. Reads ``layer_size``, ``gamma``, ``device``,
            ``lr``, ``wd`` and ``epochs``. Also reads ``test_freq``,
            optionally; it defaults to 20 when the attribute is absent.

    Returns:
        A 2-tuple of the first two elements of ``ModelOptimizer.train()``'s
        result.
    """
    # Tensor.numpy() raises for tensors on a non-CPU device or tensors that
    # require grad. Detach and move to CPU first. For a plain CPU tensor this
    # is the same zero-copy conversion as calling .numpy() directly.
    if isinstance(res_mat, torch.Tensor):
        res_mat = res_mat.detach().cpu().numpy()
    if isinstance(null_mask, torch.Tensor):
        null_mask = null_mask.detach().cpu().numpy()

    sampler = NewSampler(res_mat, null_mask, target_dim, target_index)

    model = DeepTraCDR(
        adj_mat=sampler.train_data,
        cell_exprs=cell_exprs,
        drug_fingerprints=drug_finger,
        layer_size=args.layer_size,
        gamma=args.gamma,
        device=args.device,
    )

    opt = ModelOptimizer(
        model=model,
        train_data=sampler.train_data,
        test_data=sampler.test_data,
        test_mask=sampler.test_mask,
        train_mask=sampler.train_mask,
        # NOTE(review): this is the full res_mat, not sampler.train_data.
        # Confirm the optimizer does not use held-out entries from it.
        adj_matrix=res_mat,
        evaluate_fun=evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        test_freq=getattr(args, "test_freq", 20),
        device=args.device,
    )

    results = opt.train()
    # NOTE(review): a previous comment listed (true_data, predict_data,
    # auc_data), but only the first two elements are returned. Confirm that
    # dropping any further elements is intended.
    return results[0], results[1]
|