Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

optuna_main.py 2.3KB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import joblib
  2. import optuna
  3. from optuna.pruners import MedianPruner
  4. from optuna.trial import TrialState
  5. from data_loaders import make_dfs, build_loaders
  6. from learner import supervised_train
  7. from test_main import test
  8. def objective(trial, config, train_loader, validation_loader):
  9. config.optuna(trial=trial)
  10. print('Trial', trial.number, 'parameters', trial.params)
  11. accuracy = supervised_train(config, train_loader, validation_loader, trial=trial)
  12. return accuracy
  13. def optuna_main(config, n_trials=100):
  14. train_df, test_df, validation_df = make_dfs(config)
  15. train_loader = build_loaders(config, train_df, mode="train")
  16. validation_loader = build_loaders(config, validation_df, mode="validation")
  17. test_loader = build_loaders(config, test_df, mode="test")
  18. study = optuna.create_study(study_name=config.output_path.split('/')[-1],
  19. sampler=optuna.samplers.TPESampler(),
  20. storage=f'sqlite:///{config.output_path + "/optuna.db"}',
  21. load_if_exists=True,
  22. direction="maximize",
  23. pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=10)
  24. )
  25. study.optimize(lambda trial: objective(trial, config, train_loader, validation_loader), n_trials=n_trials)
  26. pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
  27. complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
  28. joblib.dump(study, str(config.output_path) + '/study_optuna_model' + '.pkl')
  29. print("Study statistics: ")
  30. print(" Number of finished trials: ", len(study.trials))
  31. print(" Number of pruned trials: ", len(pruned_trials))
  32. print(" Number of complete trials: ", len(complete_trials))
  33. s = ''
  34. print("Best trial:")
  35. trial = study.best_trial
  36. print(' Number: ', trial.number)
  37. print(" Value: ", trial.value)
  38. s += 'number: ' + str(trial.number) + '\n'
  39. s += 'value: ' + str(trial.value) + '\n'
  40. print(" Params: ")
  41. s += 'params: \n'
  42. for key, value in trial.params.items():
  43. print(" {}: {}".format(key, value))
  44. s += " {}: {}\n".format(key, value)
  45. with open(config.output_path+'/optuna_results.txt', 'w') as f:
  46. f.write(s)
  47. test(config, test_loader, trial_number=trial.number)