Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 1.9KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import argparse
  2. from lime_main import lime_main
  3. from optuna_main import optuna_main
  4. from test_main import test_main
  5. from torch_main import torch_main
  6. from data.twitter.config import TwitterConfig
  7. from data.weibo.config import WeiboConfig
  8. from torch.utils.tensorboard import SummaryWriter
  9. from utils import make_directory
  10. if __name__ == '__main__':
  11. import os
  12. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument('--data', type=str, required=True)
  15. parser.add_argument('--use_optuna', type=int, required=False)
  16. parser.add_argument('--use_lime', type=int, required=False)
  17. parser.add_argument('--just_test', type=int, required=False)
  18. parser.add_argument('--batch', type=int, required=False)
  19. parser.add_argument('--epoch', type=int, required=False)
  20. parser.add_argument('--extra', type=str, required=False)
  21. args = parser.parse_args()
  22. if args.data == 'twitter':
  23. config = TwitterConfig()
  24. elif args.data == 'weibo':
  25. config = WeiboConfig()
  26. else:
  27. raise Exception('Enter a valid dataset name', args.data)
  28. if args.batch:
  29. config.batch_size = args.batch
  30. if args.epoch:
  31. config.epochs = args.epoch
  32. if args.use_optuna:
  33. config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra)
  34. else:
  35. config.output_path += 'logs/' + args.data + '_' + str(args.extra)
  36. use_optuna = True
  37. if not args.extra or 'temp' in args.extra:
  38. config.output_path = str(args.extra)
  39. use_optuna = False
  40. make_directory(config.output_path)
  41. config.writer = SummaryWriter(config.output_path)
  42. if args.use_optuna and use_optuna:
  43. optuna_main(config, args.use_optuna)
  44. elif args.just_test:
  45. test_main(config, args.just_test)
  46. elif args.use_lime:
  47. lime_main(config, args.use_lime)
  48. else:
  49. torch_main(config)
  50. config.writer.close()