You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

extract_val.py 2.0KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. """Prepare the ImageNet dataset"""
  2. # import torch
  3. import os
  4. import argparse
  5. import tarfile
  6. import pickle
  7. import gzip
  8. # import subprocess
  9. # from tqdm import tqdm
  10. # from mxnet.gluon.utils import check_sha1
  11. # from gluoncv.utils import download, makedirs
# File name of the validation image tarball; main() joins it onto the
# download directory to locate the archive to extract.
_VAL_TAR = 'ILSVRC2012_img_val.tar'
  13. def parse_args():
  14. parser = argparse.ArgumentParser(
  15. description='Setup the ImageNet dataset.',
  16. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  17. parser.add_argument('-s', dest='download_dir', required=True,
  18. help="The directory that contains downloaded tar files")
  19. parser.add_argument('-m', dest='mapping', required=True,
  20. help="The mapping file for validation set")
  21. parser.add_argument('--target-dir', default='data/imagenet/extracted',
  22. help="The directory to store extracted images")
  23. args = parser.parse_args()
  24. return args
  25. def check_file(filename):
  26. if not os.path.exists(filename):
  27. raise ValueError('File not found: '+filename)
  28. def extract_val(tar_fname, target_dir, val_maps_file):
  29. os.makedirs(target_dir)
  30. print('Extracting ' + tar_fname)
  31. with tarfile.open(tar_fname) as tar:
  32. tar.extractall(target_dir)
  33. # build rec file before images are moved into subfolders
  34. # move images to proper subfolders
  35. with gzip.open(val_maps_file, 'rb') as f:
  36. dirs, mappings = pickle.load(f)
  37. for d in dirs:
  38. os.makedirs(os.path.join(target_dir, d))
  39. for m in mappings:
  40. os.rename(os.path.join(target_dir, m[0]), os.path.join(target_dir, m[1], m[0]))
  41. def main():
  42. args = parse_args()
  43. target_dir = os.path.expanduser(args.target_dir)
  44. if os.path.exists(target_dir):
  45. raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')
  46. download_dir = os.path.expanduser(args.download_dir)
  47. val_tar_fname = os.path.join(download_dir, _VAL_TAR)
  48. check_file(val_tar_fname)
  49. extract_val(val_tar_fname, os.path.join(target_dir, 'val'), args.mapping)
# Run the extraction only when this file is executed as a script.
if __name__ == '__main__':
    main()