123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from typing import TYPE_CHECKING, Optional
-
- import numpy as np
- import cv2
- from skimage.morphology import dilation, erosion
-
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
- from . import Interpreter
-
-
def overlay_interpretation(image: np.ndarray, int_map: np.ndarray, config: 'BaseConfig') -> np.ndarray:
    """Blend interpretation maps over an image as coloured highlights.

    Args:
        image: The image to overlay interpretations over, of shape C W H.
            Values are assumed to be in the range [0, 1]. If any value is
            negative, 0.5 is added to the whole image (presumably to undo a
            zero-centred normalisation -- confirm against callers).
        int_map: The map containing interpretation scores, of shape C W H.
            Only the first three channels are used.
        config: Configuration object forwarded to ``process_interpretations``.

    Returns:
        The overlayed image as a ``uint8`` array of shape W H C.

    Raises:
        ValueError: If the processed interpretation map has no channels.
    """
    # The helper may edit its argument in place, so hand it a private copy to
    # keep the caller's array untouched.
    int_map = process_interpretations(np.array(int_map), config)  # C W H
    int_map = np.moveaxis(int_map, 0, -1)  # W H C

    # If there are more than 3 channels, use the first 3.
    if int_map.shape[-1] > 3:
        int_map = int_map[..., :3]

    # Highlight colour depends on how many maps are overlaid.
    o_color_by_c = {
        1: [255, 0, 0],
        2: [255, 255, 0],
        3: [255, 255, 255],
    }
    # Check the lookup result *before* wrapping it in an array:
    # np.asarray(None) is a 0-d object array and never compares ``is None``.
    o_color_rgb = o_color_by_c.get(int_map.shape[-1])
    if o_color_rgb is None:
        raise ValueError(
            f"Cannot overlay an interpretation map with {int_map.shape[-1]} channels; expected 1 to 3."
        )
    o_color = np.asarray(o_color_rgb)

    # Normalise by the maximum so the blending weights stay within [0, 1].
    int_map = int_map * 1.0 / (1e-6 + np.amax(int_map))

    alpha = 0.5

    # Out-of-place addition: does not modify the caller's image and also
    # works for integer-typed input.
    if np.amin(image) < 0:
        image = image + 0.5

    image = np.moveaxis(image, 0, -1)  # W H C
    image = image * 255

    # If the map has two channels, append a zero channel so it lines up with
    # the three colour components.
    if int_map.shape[-1] == 2:
        zero_channel = np.zeros((int_map.shape[-3], int_map.shape[-2], 1))
        int_map = np.concatenate((int_map, zero_channel), axis=-1)

    overlayed = (1 - int_map) * image + \
        int_map * (1 - alpha) * image + \
        int_map * alpha * o_color[np.newaxis, np.newaxis, :]

    # Clip before the cast so out-of-range pixels saturate instead of
    # wrapping around in uint8.
    overlayed = np.clip(np.round(overlayed), 0, 255).astype(np.uint8)
    return overlayed
-
-
def process_interpretations(int_map: np.ndarray, config: 'BaseConfig', interpreter: Optional['Interpreter'] = None) -> np.ndarray:
    """Threshold, normalise and smooth a stack of interpretation maps.

    Args:
        int_map: The map of interpretation scores. Is assumed to have a shape
            of C ... (each ``int_map[c]`` is processed as one 2-D map). The
            caller's array is not modified.
        config: Configuration object; ``dynamic_threshold``,
            ``global_threshold`` and ``cut_threshold`` are read from it.
        interpreter: Optional interpreter whose ``dynamic_threshold`` method is
            applied first when ``config.dynamic_threshold`` is truthy.

    Returns:
        Processed interpretation maps, with numbers in the range of [0, 1].
        An empty input is returned as an empty copy.
    """
    if interpreter and config.dynamic_threshold:
        int_map = interpreter.dynamic_threshold(int_map)

    # Work on a private floating-point copy: the steps below edit the array in
    # place, which would otherwise leak into the caller's data and fail on
    # integer input.
    int_map = np.array(int_map)
    if not np.issubdtype(int_map.dtype, np.floating):
        int_map = int_map.astype(np.float64)

    # Nothing to process; also avoids np.amax raising on an empty array.
    if int_map.size == 0:
        return int_map

    # NOTE(review): assumes only the negative-discard and max-normalisation
    # steps are gated by global_threshold -- confirm the intended nesting.
    if not config.global_threshold:
        # Treating map as negative=reverse effect; discarding negatives.
        int_map[int_map < 0] = 0

        peak = np.amax(int_map)
        if peak > 0:
            int_map = int_map / peak

    # Applying cut threshold: everything at or below it becomes zero.
    int_map[int_map <= config.cut_threshold] = config.cut_threshold
    int_map -= config.cut_threshold

    # Rescale so the strongest response is (just under) 1.
    int_map /= (np.amax(int_map) + 1e-6)

    # Kernel size is ~10% of the last axis; keep it at least 1 so that small
    # maps never request a zero-sized structuring element.
    kw = max(1, int(np.round(0.1 * int_map.shape[-1])))

    # The structuring elements do not depend on the channel; build them once.
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw + 3, kw + 3))
    shrink_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kw, kw))
    trim_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # Dilation and erosion to make continuous objects and erase noise.
    for c in range(int_map.shape[0]):
        int_map[c] = dilation(int_map[c], grow_kernel)
        int_map[c] = erosion(int_map[c], shrink_kernel)
        int_map[c] = erosion(int_map[c], trim_kernel)

        # Smooth on an 8-bit quantised copy, then map back to [0, 1].
        int_map[c] = cv2.blur(np.round(255 * int_map[c]).astype(np.uint8), (5, 5)) * 1.0 / 255.0

    return int_map  # [0, 1]
-
|