Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 1.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import argparse
  2. from optuna_main import optuna_main
  3. from test_main import test_main
  4. from torch_main import torch_main
  5. from data.twitter.config import TwitterConfig
  6. from data.weibo.config import WeiboConfig
  7. from torch.utils.tensorboard import SummaryWriter
  8. if __name__ == '__main__':
  9. import os
  10. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument('--data', type=str, required=True)
  13. parser.add_argument('--use_optuna', type=int, required=False)
  14. parser.add_argument('--just_test', type=int, required=False)
  15. parser.add_argument('--batch', type=int, required=False)
  16. parser.add_argument('--epoch', type=int, required=False)
  17. parser.add_argument('--extra', type=str, required=False)
  18. args = parser.parse_args()
  19. if args.data == 'twitter':
  20. config = TwitterConfig()
  21. elif args.data == 'weibo':
  22. config = WeiboConfig()
  23. else:
  24. raise Exception('Enter a valid dataset name', args.data)
  25. if args.batch:
  26. config.batch_size = args.batch
  27. if args.epoch:
  28. config.epochs = args.epoch
  29. if args.use_optuna:
  30. config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra)
  31. else:
  32. config.output_path += 'logs/' + args.data + '_' + str(args.extra)
  33. use_optuna = True
  34. if not args.extra or 'temp' in args.extra:
  35. config.output_path = str(args.extra)
  36. use_optuna = False
  37. try:
  38. os.mkdir(config.output_path)
  39. except OSError:
  40. print("Creation of the directory failed")
  41. else:
  42. print("Successfully created the directory")
  43. config.writer = SummaryWriter(config.output_path)
  44. if args.use_optuna and use_optuna:
  45. optuna_main(config, args.use_optuna)
  46. elif args.just_test:
  47. test_main(config, args.just_test)
  48. else:
  49. torch_main(config)
  50. config.writer.close()