You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

extract_train.py 985B

1 year ago
1234567891011121314151617181920212223242526272829303132
  1. import argparse
  2. import glob
  3. import os
  4. import tarfile
  5. from multiprocessing import Pool
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument('-s', dest='source', help='Class tars directory', required=True)
  8. parser.add_argument('-t', dest='target', help='train set directory', default='data/imagenet/extracted')
  9. parser.add_argument('-n', dest='num_threads', help='number of threads', default=1, type=int)
  10. args = parser.parse_args()
  11. class_tars = glob.glob(os.path.join(args.source, '*.tar'))
  12. assert len(class_tars) == 1000, f"class_tars length: {len(class_tars)}"
  13. def extract(class_tar):
  14. filename = os.path.basename(class_tar)
  15. class_name = filename.replace('.tar', '')
  16. print('Extract ' + os.path.basename(class_tar))
  17. class_fname = os.path.join(args.target, class_name)
  18. os.makedirs(class_fname, exist_ok=True)
  19. with tarfile.open(class_tar) as f:
  20. f.extractall(class_fname)
  21. pool = Pool(args.num_threads)
  22. pool.map(extract, class_tars)
  23. pool.close()