You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

prepare_rsna_data.py 1.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from typing import List, Tuple
  2. import argparse
  3. from os import makedirs, path, listdir
  4. import numpy as np
  5. import cv2
  6. from multiprocessing import Pool
  7. def resize_image(img_dir: str, img_save_dir: str, res: int) -> None:
  8. img = cv2.imread(img_dir)
  9. img = cv2.resize(img, dsize=(res, res))
  10. cv2.imwrite(img_save_dir, img)
  11. if __name__ == '__main__':
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument('kaggle_dataset_dir', type=str, help='download the dataset from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data, extract it, and pass its path is this argument')
  14. parser.add_argument('resolution', type=int, help='The resolution required for your model, 224 for resnet, 299 for vanilla inception, 256 for modified inception')
  15. parser.add_argument('cores', type=int, help='The number of cores for multiprocessing.')
  16. args = parser.parse_args()
  17. save_dir = f'data/RSNA-Kaggle_R{args.resolution}'
  18. makedirs(save_dir, exist_ok=True)
  19. kaggle_dataset_dir = args.kaggle_dataset_dir
  20. assert path.exists(kaggle_dataset_dir), f'{kaggle_dataset_dir} does not exist!'
  21. # reading rsna images names
  22. rsna_imgs_path = path.join(kaggle_dataset_dir, 'stage_2_train_images')
  23. assert path.exists(rsna_imgs_path), 'Make sure there is a folder named stage_2_train_images in the passed kaggle_directory!'
  24. imgs_names = np.asarray(listdir(rsna_imgs_path))
  25. imgs_src = np.vectorize(lambda x: path.join(rsna_imgs_path, x))(imgs_names)
  26. imgs_dst = np.vectorize(lambda x: path.join(save_dir, x))(imgs_names)
  27. pool = Pool(args.cores)
  28. pool.starmap(resize_image, zip(imgs_src, imgs_dst, np.full((len(imgs_src),), args.resolution)))
  29. pool.close()