Browse Source

add ddi_graph to eval_model

main
MahsaYazdani 1 year ago
parent
commit
3f1c3f94d9
1 changed files with 4 additions and 4 deletions
  1. 4
    4
      predictor/cross_validation.py

+ 4
- 4
predictor/cross_validation.py View File



def eval_model(model, optimizer, loss_func, train_data, test_data, def eval_model(model, optimizer, loss_func, train_data, test_data,
batch_size, n_epoch, patience, gpu_id, mdl_dir):
batch_size, n_epoch, patience, ddi_graph, gpu_id, mdl_dir):
tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1) tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1)
train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True) train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True)
valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4) valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4)
test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4) test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4)
train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id,
train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, ddi_graph,
sl=True, mdl_dir=mdl_dir) sl=True, mdl_dir=mdl_dir)
test_loss = eval_epoch(model, test_loader, loss_func, gpu_id)
test_loss = eval_epoch(model, test_loader, loss_func, ddi_graph, gpu_id)
test_loss /= len(test_data) test_loss /= len(test_data)
return test_loss return test_loss


if not os.path.exists(test_mdl_dir): if not os.path.exists(test_mdl_dir):
os.makedirs(test_mdl_dir) os.makedirs(test_mdl_dir)
test_loss = eval_model(model, optimizer, loss_func, train_data, test_data, test_loss = eval_model(model, optimizer, loss_func, train_data, test_data,
args.batch, args.epoch, args.patience, gpu_id, test_mdl_dir)
args.batch, args.epoch, args.patience, ddi_graph, gpu_id, test_mdl_dir)
test_losses.append(test_loss) test_losses.append(test_loss)
logging.info("Test loss: {:.4f}".format(test_loss)) logging.info("Test loss: {:.4f}".format(test_loss))
logging.info("*" * n_delimiter + '\n') logging.info("*" * n_delimiter + '\n')

Loading…
Cancel
Save