1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
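"""Prepare the ImageNet validation set.

Extract ILSVRC2012_img_val.tar and move each image into its class
subdirectory according to a gzip-compressed pickle mapping file.
"""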
- # import torch
- import os
- import argparse
- import tarfile
- import pickle
- import gzip
- # import subprocess
- # from tqdm import tqdm
- # from mxnet.gluon.utils import check_sha1
- # from gluoncv.utils import download, makedirs
-
# File name of the validation tarball; main() looks for it inside the
# directory given by the -s option.
_VAL_TAR = 'ILSVRC2012_img_val.tar'
-
def parse_args(argv=None):
    """Parse the command-line options for the ImageNet validation setup.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Useful for tests and programmatic callers;
            the default ``None`` keeps the original command-line behaviour.

    Returns:
        argparse.Namespace with ``download_dir`` (``-s``, required),
        ``mapping`` (``-m``, required) and ``target_dir``
        (``--target-dir``, default ``data/imagenet/extracted``).
    """
    parser = argparse.ArgumentParser(
        description='Setup the ImageNet dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-s', dest='download_dir', required=True,
                        help="The directory that contains downloaded tar files")
    parser.add_argument('-m', dest='mapping', required=True,
                        help="The mapping file for validation set")
    parser.add_argument('--target-dir', default='data/imagenet/extracted',
                        help="The directory to store extracted images")
    return parser.parse_args(argv)
-
def check_file(filename):
    """Verify that something exists at *filename*.

    Raises:
        ValueError: when ``os.path.exists(filename)`` is false.
    """
    if os.path.exists(filename):
        return
    message = 'File not found: ' + filename
    raise ValueError(message)
-
def _load_val_mapping(val_maps_file):
    """Load and return ``(dirs, mappings)`` from a gzip-compressed pickle."""
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only pass a mapping file obtained from a trusted source.
    with gzip.open(val_maps_file, 'rb') as f:
        dirs, mappings = pickle.load(f)
    return dirs, mappings


def _extract_tar(tar_fname, target_dir):
    """Extract every member of *tar_fname* into *target_dir*.

    When the running Python provides tarfile extraction filters, the
    'data' filter is used so members with absolute paths, '..' components
    or links pointing outside *target_dir* are rejected, not written.
    """
    with tarfile.open(tar_fname) as tar:
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(target_dir, filter='data')
        else:
            # Older interpreters without extraction filters.
            tar.extractall(target_dir)


def extract_val(tar_fname, target_dir, val_maps_file):
    """Extract the validation tarball and sort images into class folders.

    Args:
        tar_fname: Path to the validation tar archive.
        target_dir: Directory to create and extract into; must not exist
            (``os.makedirs`` raises ``FileExistsError`` otherwise).
        val_maps_file: gzip-compressed pickle holding ``(dirs, mappings)``.
            ``dirs`` lists subdirectory names to create. Each ``mappings``
            entry is indexed as ``m[0]`` = extracted file name and
            ``m[1]`` = subdirectory to move it into.
    """
    # Load the mapping first, so a missing or unreadable mapping file fails
    # before the target directory is created or anything is extracted.
    dirs, mappings = _load_val_mapping(val_maps_file)

    os.makedirs(target_dir)
    print('Extracting ' + tar_fname)
    _extract_tar(tar_fname, target_dir)

    # Move each extracted image into its subfolder.
    for d in dirs:
        # exist_ok tolerates duplicate names in the mapping's dir list.
        os.makedirs(os.path.join(target_dir, d), exist_ok=True)
    for m in mappings:
        fname, subdir = m[0], m[1]
        os.rename(os.path.join(target_dir, fname),
                  os.path.join(target_dir, subdir, fname))
-
def main():
    """Command-line entry point: extract and organise the validation set.

    Expands ``~`` in every path option, refuses to reuse an existing target
    directory, checks that both the validation tarball (``_VAL_TAR`` inside
    the download directory) and the mapping file exist, then extracts into
    ``<target_dir>/val``.

    Raises:
        ValueError: if the target directory already exists or an input
            file is missing.
    """
    args = parse_args()

    target_dir = os.path.expanduser(args.target_dir)
    if os.path.exists(target_dir):
        raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')

    download_dir = os.path.expanduser(args.download_dir)
    val_tar_fname = os.path.join(download_dir, _VAL_TAR)
    check_file(val_tar_fname)

    # Expand '~' for the mapping path too, consistent with the other path
    # options, and validate it up front rather than after extraction.
    val_maps_file = os.path.expanduser(args.mapping)
    check_file(val_maps_file)

    extract_val(val_tar_fname, os.path.join(target_dir, 'val'), val_maps_file)


if __name__ == '__main__':
    main()
|