"""Prepare the ImageNet dataset.

Extracts the ILSVRC2012 validation tarball and sorts the extracted images
into one sub-directory per class, driven by a gzip-compressed pickle mapping
file.
"""
import os
import argparse
import tarfile
import pickle
import gzip

# File name of the validation archive expected inside the download directory.
_VAL_TAR = 'ILSVRC2012_img_val.tar'


def parse_args():
    """Parse the command-line arguments.

    Returns
    -------
    argparse.Namespace
        Namespace with ``download_dir``, ``mapping`` and ``target_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Setup the ImageNet dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-s', dest='download_dir', required=True,
                        help="The directory that contains downloaded tar files")
    parser.add_argument('-m', dest='mapping', required=True,
                        help="The mapping file for validation set")
    parser.add_argument('--target-dir', default='data/imagenet/extracted',
                        help="The directory to store extracted images")
    args = parser.parse_args()
    return args


def check_file(filename):
    """Ensure ``filename`` exists.

    Raises
    ------
    ValueError
        If nothing exists at ``filename``.
    """
    if not os.path.exists(filename):
        raise ValueError('File not found: '+filename)


def extract_val(tar_fname, target_dir, val_maps_file):
    """Extract the validation tarball and sort images into class folders.

    Parameters
    ----------
    tar_fname : str
        Path of the validation tar archive.
    target_dir : str
        Directory to create and extract into. It must not exist yet, because
        ``os.makedirs`` is called without ``exist_ok``.
    val_maps_file : str
        Path of a gzip-compressed pickle holding a ``(dirs, mappings)`` pair:
        ``dirs`` is an iterable of sub-directory names to create and
        ``mappings`` an iterable whose items give the image file name at
        index 0 and its destination sub-directory at index 1.
    """
    os.makedirs(target_dir)
    print('Extracting ' + tar_fname)
    with tarfile.open(tar_fname) as tar:
        # Prefer the 'data' extraction filter when this Python provides it:
        # it rejects members that would be written outside target_dir
        # (absolute paths, '..' components, escaping links).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(target_dir, filter='data')
        else:
            tar.extractall(target_dir)
    # move images to proper subfolders
    # NOTE(security): pickle.load can execute arbitrary code; only use a
    # mapping file obtained from a trusted source.
    with gzip.open(val_maps_file, 'rb') as f:
        dirs, mappings = pickle.load(f)
    for d in dirs:
        os.makedirs(os.path.join(target_dir, d))
    for m in mappings:
        os.rename(os.path.join(target_dir, m[0]),
                  os.path.join(target_dir, m[1], m[0]))


def main():
    """Validate the inputs, then extract the validation set.

    Raises
    ------
    ValueError
        If the target directory already exists, or if the validation tarball
        or the mapping file cannot be found.
    """
    args = parse_args()
    target_dir = os.path.expanduser(args.target_dir)
    if os.path.exists(target_dir):
        raise ValueError('Target dir ['+target_dir+'] exists. Remove it first')
    download_dir = os.path.expanduser(args.download_dir)
    val_tar_fname = os.path.join(download_dir, _VAL_TAR)
    check_file(val_tar_fname)
    # Validate the mapping file up front so a wrong path fails before the
    # lengthy extraction rather than after it, leaving no partial target dir.
    mapping_fname = os.path.expanduser(args.mapping)
    check_file(mapping_fname)
    extract_val(val_tar_fname, os.path.join(target_dir, 'val'), mapping_fname)


if __name__ == '__main__':
    main()