DeepTraCDR: Predicting Cancer Drug Response using multimodal deep learning with Transformers
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

new_model.py 1.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from model import DeepTraCDR, ModelOptimizer
  2. from data_sampler import NewSampler
  3. import torch
  4. def run_DeepTraCDR(cell_exprs, drug_finger, res_mat, null_mask, target_dim, target_index,
  5. evaluate_fun, args):
  6. # Convert to numpy if needed
  7. if isinstance(res_mat, torch.Tensor):
  8. res_mat = res_mat.numpy()
  9. if isinstance(null_mask, torch.Tensor):
  10. null_mask = null_mask.numpy()
  11. # Initialize sampler with required parameters
  12. sampler = NewSampler(res_mat, null_mask, target_dim, target_index)
  13. # Initialize model with correct device parameter name
  14. model = DeepTraCDR(
  15. adj_mat=sampler.train_data,
  16. cell_exprs=cell_exprs,
  17. drug_fingerprints=drug_finger,
  18. layer_size=args.layer_size,
  19. gamma=args.gamma,
  20. device=args.device
  21. )
  22. # Initialize optimizer with correct parameter order
  23. opt = ModelOptimizer(
  24. model=model,
  25. train_data=sampler.train_data,
  26. test_data=sampler.test_data,
  27. test_mask=sampler.test_mask,
  28. train_mask=sampler.train_mask,
  29. adj_matrix=res_mat,
  30. evaluate_fun=evaluate_fun,
  31. lr=args.lr,
  32. wd=args.wd,
  33. epochs=args.epochs,
  34. test_freq=args.test_freq if hasattr(args, 'test_freq') else 20,
  35. device=args.device
  36. )
  37. # Call train method and unpack results
  38. results = opt.train()
  39. return results[0], results[1] # true_data, predict_data, auc_data