1234567891011121314151617181920212223242526272829303132333435363738394041 |
- from typing import List, Tuple
- import argparse
- from os import makedirs, path, listdir
- import numpy as np
- import cv2
- from multiprocessing import Pool
-
-
def resize_image(img_dir: str, img_save_dir: str, res: int) -> None:
    """Read one image, resize it to a ``res`` x ``res`` square, and save it.

    Args:
        img_dir: Path of the source image file (despite the name, a file,
            not a directory).
        img_save_dir: Destination file path; OpenCV picks the output format
            from its extension.
        res: Target width and height in pixels. The image is stretched to a
            square, so the aspect ratio is not preserved.

    Raises:
        ValueError: If OpenCV cannot read/decode ``img_dir``.
        OSError: If OpenCV reports that writing ``img_save_dir`` failed.
    """
    img = cv2.imread(img_dir)
    # cv2.imread signals failure by returning None rather than raising; stop
    # here with the offending path instead of failing obscurely in cv2.resize.
    if img is None:
        raise ValueError(f'Could not read image: {img_dir}')

    # Normalize to a built-in int so NumPy integer scalars passed by callers
    # are handled the same way as plain ints.
    side = int(res)
    img = cv2.resize(img, dsize=(side, side))

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(img_save_dir, img):
        raise OSError(f'Could not write image: {img_save_dir}')
-
-
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'kaggle_dataset_dir', type=str,
        help='download the dataset from '
             'https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data, '
             'extract it, and pass its path as this argument')
    parser.add_argument(
        'resolution', type=int,
        help='The resolution required for your model, 224 for resnet, '
             '299 for vanilla inception, 256 for modified inception')
    parser.add_argument(
        'cores', type=int,
        help='The number of cores for multiprocessing.')
    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts are stripped under
    # `python -O`, whereas parser.error always prints usage and exits.
    if args.resolution < 1:
        parser.error('resolution must be a positive integer')
    if args.cores < 1:
        parser.error('cores must be a positive integer')

    kaggle_dataset_dir = args.kaggle_dataset_dir
    if not path.isdir(kaggle_dataset_dir):
        parser.error(f'{kaggle_dataset_dir} does not exist or is not a directory!')

    # reading rsna images names
    rsna_imgs_path = path.join(kaggle_dataset_dir, 'stage_2_train_images')
    if not path.isdir(rsna_imgs_path):
        parser.error('Make sure there is a folder named stage_2_train_images '
                     'in the passed kaggle_dataset_dir!')

    # NOTE(review): the Kaggle download presumably ships these images as DICOM
    # (.dcm), which cv2.imread cannot decode -- confirm they were converted to
    # a raster format (PNG/JPEG) beforehand, otherwise every task will fail.
    #
    # Keep visible regular files only: listdir also returns sub-directories
    # and hidden files (e.g. .DS_Store) that are not images. Sorting makes the
    # processing order reproducible.
    imgs_names = sorted(
        name for name in listdir(rsna_imgs_path)
        if not name.startswith('.')
        and path.isfile(path.join(rsna_imgs_path, name))
    )
    # Fail early with a clear message when there is nothing to process.
    if not imgs_names:
        parser.error(f'No image files found in {rsna_imgs_path}')

    save_dir = f'data/RSNA-Kaggle_R{args.resolution}'
    # Create the output folder only after every check passed, so a bad
    # invocation does not leave an empty directory behind.
    makedirs(save_dir, exist_ok=True)

    # One (source path, destination path, resolution) triple per image; each
    # output keeps its source file name.
    tasks = [
        (path.join(rsna_imgs_path, name), path.join(save_dir, name), args.resolution)
        for name in imgs_names
    ]

    # The with-block always shuts the worker processes down (even if a task
    # raises); starmap blocks until the submitted work is done.
    with Pool(args.cores) as pool:
        pool.starmap(resize_image, tasks)
|