You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 3.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from typing import TYPE_CHECKING, Optional
  2. import numpy as np
  3. import cv2
  4. from skimage.morphology import dilation, erosion
  5. if TYPE_CHECKING:
  6. from ..configs.base_config import BaseConfig
  7. from . import Interpreter
  8. def overlay_interpretation(image: np.ndarray, int_map: np.ndarray, config: 'BaseConfig') -> np.ndarray:
  9. """
  10. image (np.ndarray): The image to overlay interpretations over. Of shape C W H. The numbers are assumed to be in range [0, 1]
  11. int_map (np.ndarray): The map containing interpretation scores.
  12. config ('Config'): An object containing required information about the configurations.
  13. Returns:
  14. (np.ndarray): The overlayed image
  15. """
  16. int_map = process_interpretations(int_map, config) # C W H
  17. int_map = np.moveaxis(int_map, 0, -1) # W H C
  18. # if there are more than 3 channels, use the first 3!
  19. if int_map.shape[-1] > 3:
  20. int_map = int_map[..., :3]
  21. # norm by max to make sure!
  22. int_map = int_map * 1.0 / (1e-6 + np.amax(int_map))
  23. alpha = 0.5
  24. if np.amin(image) < 0:
  25. image += 0.5
  26. image = np.moveaxis(image, 0, -1) # W H C
  27. image = image * 255
  28. o_color_by_c = {
  29. 1: [255, 0, 0],
  30. 2: [255, 255, 0],
  31. 3: [255, 255, 255],
  32. }
  33. o_color = np.asarray(o_color_by_c.get(int_map.shape[-1], None))
  34. if o_color is None:
  35. raise Exception("Trying to overlay more than 3 maps on the image.")
  36. # if int map has two channels, append one zeros to the end!
  37. if int_map.shape[-1] == 2:
  38. int_map = np.concatenate((int_map, np.zeros((int_map.shape[-3], int_map.shape[-2], 1))), axis=-1)
  39. overlayed = (1 - int_map) * image + \
  40. int_map * (1 - alpha) * image + \
  41. int_map * alpha * o_color[np.newaxis, np.newaxis, :]
  42. overlayed = np.round(overlayed).astype(np.uint8)
  43. return overlayed
  44. def process_interpretations(int_map: np.ndarray, config: 'BaseConfig', interpreter: Optional['Interpreter'] = None) -> np.ndarray:
  45. """
  46. int_map (np.ndarray): The map of interpretation scores. Is assumed to have a shape of C ...
  47. config ('Config'): An object containing the information about the required configurations
  48. Returns:
  49. np.ndarray: Processed interpretation maps, with numbers in the range of [0, 1]
  50. """
  51. if interpreter and config.dynamic_threshold:
  52. int_map = interpreter.dynamic_threshold(int_map)
  53. if not config.global_threshold:
  54. # Treating map as negative=reverse effect; discarding negatives
  55. int_map[int_map < 0] = 0
  56. qval = np.amax(int_map)
  57. if qval > 0:
  58. int_map[int_map > qval] = qval
  59. int_map = int_map / qval
  60. # applying cut threshold!
  61. int_map[int_map <= config.cut_threshold] = config.cut_threshold
  62. int_map -= config.cut_threshold
  63. kw = int(np.round(0.1 * int_map.shape[-1]))
  64. mv = np.amax(int_map)
  65. int_map /= (mv + 1e-6)
  66. # dilation and erosion to make continuous objects and erase noises
  67. for c in range(int_map.shape[0]):
  68. int_map[c] = dilation(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw + 3, kw + 3)))
  69. int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw, kw)))
  70. int_map[c] = erosion(int_map[c], cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)))
  71. int_map[c] = cv2.blur(np.round(255 * int_map[c]).astype(np.uint8), (5, 5)) * 1.0 / 255.0
  72. return int_map # [0, 1]