Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main.py 1.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import argparse
  2. from optuna_main import optuna_main
  3. from test_main import test_main
  4. from torch_main import torch_main
  5. from data.twitter.config import TwitterConfig
  6. from data.weibo.config import WeiboConfig
  7. from torch.utils.tensorboard import SummaryWriter
  8. from utils import make_directory
  9. if __name__ == '__main__':
  10. import os
  11. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument('--data', type=str, required=True)
  14. parser.add_argument('--use_optuna', type=int, required=False)
  15. parser.add_argument('--use_lime', type=int, required=False)
  16. parser.add_argument('--just_test', type=int, required=False)
  17. parser.add_argument('--batch', type=int, required=False)
  18. parser.add_argument('--epoch', type=int, required=False)
  19. parser.add_argument('--extra', type=str, required=False)
  20. args = parser.parse_args()
  21. if args.data == 'twitter':
  22. config = TwitterConfig()
  23. elif args.data == 'weibo':
  24. config = WeiboConfig()
  25. else:
  26. raise Exception('Enter a valid dataset name', args.data)
  27. if args.batch:
  28. config.batch_size = args.batch
  29. if args.epoch:
  30. config.epochs = args.epoch
  31. if args.use_optuna:
  32. config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra)
  33. else:
  34. config.output_path += 'logs/' + args.data + '_' + str(args.extra)
  35. make_directory(config.output_path)
  36. # config.writer = SummaryWriter(config.output_path)
  37. if args.use_optuna is not None:
  38. optuna_main(config, args.use_optuna)
  39. elif args.just_test is not None:
  40. test_main(config, args.just_test)
  41. else:
  42. torch_main(config)
  43. # config.writer.close()