|
|
@@ -0,0 +1,44 @@ |
|
|
|
from model import DeepTraCDR, ModelOptimizer |
|
|
|
from data_sampler import NewSampler |
|
|
|
import torch |
|
|
|
def run_DeepTraCDR(cell_exprs, drug_finger, res_mat, null_mask, target_dim, target_index, |
|
|
|
evaluate_fun, args): |
|
|
|
|
|
|
|
# Convert to numpy if needed |
|
|
|
if isinstance(res_mat, torch.Tensor): |
|
|
|
res_mat = res_mat.numpy() |
|
|
|
if isinstance(null_mask, torch.Tensor): |
|
|
|
null_mask = null_mask.numpy() |
|
|
|
# Initialize sampler with required parameters |
|
|
|
sampler = NewSampler(res_mat, null_mask, target_dim, target_index) |
|
|
|
|
|
|
|
# Initialize model with correct device parameter name |
|
|
|
model = DeepTraCDR( |
|
|
|
adj_mat=sampler.train_data, |
|
|
|
cell_exprs=cell_exprs, |
|
|
|
drug_fingerprints=drug_finger, |
|
|
|
layer_size=args.layer_size, |
|
|
|
gamma=args.gamma, |
|
|
|
device=args.device |
|
|
|
) |
|
|
|
|
|
|
|
# Initialize optimizer with correct parameter order |
|
|
|
opt = ModelOptimizer( |
|
|
|
model=model, |
|
|
|
train_data=sampler.train_data, |
|
|
|
test_data=sampler.test_data, |
|
|
|
test_mask=sampler.test_mask, |
|
|
|
train_mask=sampler.train_mask, |
|
|
|
adj_matrix=res_mat, |
|
|
|
evaluate_fun=evaluate_fun, |
|
|
|
lr=args.lr, |
|
|
|
wd=args.wd, |
|
|
|
epochs=args.epochs, |
|
|
|
test_freq=args.test_freq if hasattr(args, 'test_freq') else 20, |
|
|
|
device=args.device |
|
|
|
) |
|
|
|
|
|
|
|
# Call train method and unpack results |
|
|
|
results = opt.train() |
|
|
|
|
|
|
|
return results[0], results[1] # true_data, predict_data, auc_data |