Browse Source

Upload files to 'Scenario2/new'

master
Zahra Asgari 1 day ago
parent
commit
0f6278ca57
1 changed files with 44 additions and 0 deletions
  1. 44
    0
      Scenario2/new/new_model.py

+ 44
- 0
Scenario2/new/new_model.py View File

@@ -0,0 +1,44 @@
from model import DeepTraCDR, ModelOptimizer
from data_sampler import NewSampler
import torch
def run_DeepTraCDR(cell_exprs, drug_finger, res_mat, null_mask, target_dim, target_index,
evaluate_fun, args):

# Convert to numpy if needed
if isinstance(res_mat, torch.Tensor):
res_mat = res_mat.numpy()
if isinstance(null_mask, torch.Tensor):
null_mask = null_mask.numpy()
# Initialize sampler with required parameters
sampler = NewSampler(res_mat, null_mask, target_dim, target_index)
# Initialize model with correct device parameter name
model = DeepTraCDR(
adj_mat=sampler.train_data,
cell_exprs=cell_exprs,
drug_fingerprints=drug_finger,
layer_size=args.layer_size,
gamma=args.gamma,
device=args.device
)
# Initialize optimizer with correct parameter order
opt = ModelOptimizer(
model=model,
train_data=sampler.train_data,
test_data=sampler.test_data,
test_mask=sampler.test_mask,
train_mask=sampler.train_mask,
adj_matrix=res_mat,
evaluate_fun=evaluate_fun,
lr=args.lr,
wd=args.wd,
epochs=args.epochs,
test_freq=args.test_freq if hasattr(args, 'test_freq') else 20,
device=args.device
)
# Call train method and unpack results
results = opt.train()
return results[0], results[1] # true_data, predict_data, auc_data

Loading…
Cancel
Save