123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
import argparse
import os

from torch.utils.tensorboard import SummaryWriter

from data.twitter.config import TwitterConfig
from data.weibo.config import WeiboConfig
from lime_main import lime_main
from optuna_main import optuna_main
from test_main import test_main
from torch_main import torch_main
from utils import make_directory
-
# Accepted values for --data; keep in sync with the config lookup in main().
DATASETS = ('twitter', 'weibo')


def parse_args(argv=None):
    """Parse the command line for this entry point.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the required ``data`` plus the optional
        integer flags ``use_optuna``, ``use_lime``, ``just_test``, ``batch``,
        ``epoch`` and the optional string ``extra`` (each ``None`` when the
        flag is not given).

    Raises:
        SystemExit: via argparse when ``--data`` is missing or not one of
        ``DATASETS``, or when an integer flag gets a non-integer value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True, choices=DATASETS)
    parser.add_argument('--use_optuna', type=int, required=False)
    parser.add_argument('--use_lime', type=int, required=False)
    parser.add_argument('--just_test', type=int, required=False)
    parser.add_argument('--batch', type=int, required=False)
    parser.add_argument('--epoch', type=int, required=False)
    parser.add_argument('--extra', type=str, required=False)
    return parser.parse_args(argv)


def log_subdir(args):
    """Return the run-specific suffix appended to ``config.output_path``.

    The result has the form ``logs/<data>[_optuna]_<extra>``. ``extra`` is
    stringified as-is, so an omitted ``--extra`` yields the literal suffix
    ``_None`` (the original naming scheme is preserved).
    """
    # Use the same "was the flag given?" test as the dispatch in main(), so a
    # run that executes Optuna (even with `--use_optuna 0`) always lands in
    # the `_optuna` directory.
    mode = '_optuna' if args.use_optuna is not None else ''
    return 'logs/' + args.data + mode + '_' + str(args.extra)


def main(argv=None):
    """Build the dataset config from the CLI and run the selected mode.

    Exactly one mode runs. Flags are checked in this order:
    ``--use_optuna`` (``optuna_main``), ``--just_test`` (``test_main``),
    ``--use_lime`` (``lime_main``), otherwise ``torch_main``. The integer
    value of the chosen flag is forwarded to its entry point unchanged.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Presumably silences the HuggingFace `tokenizers` fork/parallelism
    # warning; set before any mode code runs.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = parse_args(argv)

    config_classes = {'twitter': TwitterConfig, 'weibo': WeiboConfig}
    config = config_classes[args.data]()

    # A missing (or 0) value keeps the dataset config's default.
    if args.batch:
        config.batch_size = args.batch
    if args.epoch:
        config.epochs = args.epoch

    # Plain string concatenation: output_path is assumed to end with a path
    # separator -- TODO confirm in the config classes.
    config.output_path += log_subdir(args)
    make_directory(config.output_path)

    # TensorBoard logging is currently disabled:
    # config.writer = SummaryWriter(config.output_path)

    if args.use_optuna is not None:
        optuna_main(config, args.use_optuna)
    elif args.just_test is not None:
        test_main(config, args.just_test)
    elif args.use_lime is not None:
        lime_main(config, args.use_lime)
    else:
        torch_main(config)

    # config.writer.close()


if __name__ == '__main__':
    main()
|