|
1234567891011121314151617181920212223242526272829303132 |
- import argparse
- import glob
- import os
- import tarfile
- from multiprocessing import Pool
-
-
# Command-line interface: the directory holding the per-class tar archives,
# the directory to unpack them into, and the size of the worker pool.
parser = argparse.ArgumentParser()
parser.add_argument('-s', dest='source', help='Class tars directory', required=True)
parser.add_argument('-t', dest='target', help='train set directory', default='data/imagenet/extracted')
parser.add_argument('-n', dest='num_threads', help='number of threads', default=1, type=int)
args = parser.parse_args()

# glob returns matches in arbitrary order; sort so archives are dispatched in
# a reproducible order from run to run.
class_tars = sorted(glob.glob(os.path.join(args.source, '*.tar')))

# Explicit check instead of `assert`: assertions are stripped under
# `python -O`, which would let a wrong archive count through unnoticed.
# SystemExit with a string prints it to stderr and exits with status 1.
if len(class_tars) != 1000:
    raise SystemExit(f"class_tars length: {len(class_tars)}")
-
-
def extract(class_tar):
    """Unpack one class archive into its own directory under ``args.target``.

    The destination directory is named after the archive file with its final
    extension removed (``n01440764.tar`` -> ``<target>/n01440764``) and is
    created if it does not exist yet.

    Args:
        class_tar: Path to the tar archive to unpack.

    Reads the module-level ``args`` namespace for the target directory.
    """
    filename = os.path.basename(class_tar)
    # Strip only the trailing extension. str.replace('.tar', '') would also
    # delete a '.tar' occurring in the middle of the name.
    class_name = os.path.splitext(filename)[0]
    print('Extract ' + filename)
    class_dir = os.path.join(args.target, class_name)
    os.makedirs(class_dir, exist_ok=True)

    with tarfile.open(class_tar) as archive:
        if hasattr(tarfile, 'data_filter'):
            # Interpreter supports extraction filters: refuse members that
            # would escape class_dir (absolute paths, '..', outside links).
            archive.extractall(class_dir, filter='data')
        else:
            # Older interpreter without extraction filters.
            archive.extractall(class_dir)
-
-
-
# Only start the pool when run as a script. Under the "spawn" start method
# every worker re-imports this module, and without the guard each worker
# would try to create a pool of its own.
if __name__ == '__main__':
    # Despite the option name, Pool starts worker processes, not threads.
    # The context manager releases the pool even if map() raises.
    with Pool(args.num_threads) as pool:
        # map() blocks until every archive has been handled, so the pool is
        # idle by the time the with-block exits.
        pool.map(extract, class_tars)
|