"""Evaluation of 2D binary interpretation maps against ground-truth masks.

Accumulates normalized intersection-over-union (IoU) statistics separately for
sound predictions (the network's label matches the ground truth) and mistakes,
both for the binarized interpretations and for a top-K variant in which K is
the number of "on" pixels in the ground-truth mask.
"""
from typing import List, TYPE_CHECKING, Tuple

import numpy as np

from .utils import process_interpretations

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig
    from . import Interpreter


class BinaryInterpretationEvaluator2D:

    def __init__(self, n_samples: int, config: 'BaseConfig'):
        self._config = config
        self._n_samples = n_samples
        self._min_intersection_threshold = config.acceptable_min_intersection_threshold

        # Index 0 accumulates statistics for mistaken predictions, index 1 for sound ones.
        self._normed_intersection_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._normed_union_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._tk_normed_intersection_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._tk_normed_union_per_soundness: np.ndarray = np.zeros(2, dtype=float)

    def reset(self) -> None:
        """Clears all accumulated statistics."""
        self._normed_intersection_per_soundness = np.zeros(2, dtype=float)
        self._normed_union_per_soundness = np.zeros(2, dtype=float)
        self._tk_normed_intersection_per_soundness = np.zeros(2, dtype=float)
        self._tk_normed_union_per_soundness = np.zeros(2, dtype=float)

    def update_summaries(
            self, m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray,
            net_preds: np.ndarray, ground_truth_labels: np.ndarray,
            batch_inds: np.ndarray, interpreter: 'Interpreter') -> None:
        assert ground_truth_interpretations.ndim == 3, \
            f'GT interpretations must have a shape of BxWxH but it is {ground_truth_interpretations.shape}'
        assert m_interpretations.ndim == 3, \
            f'Model interpretations must have a shape of BxWxH but it is {m_interpretations.shape}'

        # Skipping samples without ground-truth interpretations (NaN masks).
        has_interpretation_mask = np.logical_not(
            np.any(np.isnan(ground_truth_interpretations), axis=(1, 2)))
        m_interpretations = m_interpretations[has_interpretation_mask]
        ground_truths = ground_truth_interpretations[has_interpretation_mask]
        net_preds = net_preds[has_interpretation_mask]
        ground_truth_labels = ground_truth_labels[has_interpretation_mask]
        batch_inds = batch_inds[has_interpretation_mask]

        # Finding class labels: a 1D array or a single column is treated as the
        # positive-class probability; otherwise the highest-scoring class wins.
        if net_preds.ndim == 1:
            net_preds = (net_preds >= 0.5).astype(int)
        elif net_preds.shape[1] == 1:
            net_preds = (net_preds[:, 0] >= 0.5).astype(int)
        else:
            net_preds = net_preds.argmax(axis=-1)

        # A prediction is sound iff the predicted label matches the ground truth.
        soundnesses = (net_preds == ground_truth_labels).astype(int)

        c_interpretations = np.clip(m_interpretations, 0, np.amax(m_interpretations))
        b_interpretations = np.stack([
            process_interpretations(m_interpretations[ind][None, ...], self._config, interpreter)
            for ind in range(len(m_interpretations))
        ], axis=0)[:, 0, ...]

        # Binarizing both maps; ground truths are thresholded at 0.5 so the
        # values stay binary even if a resize has been applied.
        b_interpretations = b_interpretations > 0
        ground_truths = ground_truths >= 0.5
        assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], \
            f'Ground truth and model interpretations must have the same shape, ' \
            f'found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}'

        # Per-sample intersections/unions, normalized by the number of pixels
        # per map; np.add.at accumulates unbuffered, so repeated soundness
        # indices within a batch are all counted.
        norm_factor = 1.0 * b_interpretations.shape[1] * b_interpretations.shape[2]
        np.add.at(self._normed_intersection_per_soundness, soundnesses,
                  np.sum(b_interpretations & ground_truths, axis=(1, 2)) / norm_factor)
        np.add.at(self._normed_union_per_soundness, soundnesses,
                  np.sum(b_interpretations | ground_truths, axis=(1, 2)) / norm_factor)

        for i in range(len(b_interpretations)):
            s = soundnesses[i]

            # Keeping the top-K interpretation pixels with K equal to the number
            # of "on" pixels in the ground truth, i.e. thresholding the continuous
            # interpretation at the matching quantile.
            n_on_gt = np.sum(ground_truths[i])
            # 1 is added because the thresholding below uses > rather than >=.
            q = (1 + n_on_gt) * 1.0 / (ground_truths.shape[-1] * ground_truths.shape[-2])
            if q < 1:
                tints = c_interpretations[i]
                th = max(0, np.quantile(tints.reshape(-1), 1 - q))
                tints = tints > th
            else:
                tints = c_interpretations[i] > 0

            # Top-K IoU statistics.
            tk_intersection = np.sum(tints & ground_truths[i])
            tk_union = np.sum(tints | ground_truths[i])
            self._tk_normed_intersection_per_soundness[s] += tk_intersection * 1.0 / norm_factor
            self._tk_normed_union_per_soundness[s] += tk_union * 1.0 / norm_factor

    @staticmethod
    def get_titles_of_evaluation_metrics() -> List[str]:
        # S: sound predictions, M: mistaken predictions, A: all; TK: top-K variant.
        return ['S-IOU', 'S-TK-IOU', 'M-IOU', 'M-TK-IOU', 'A-IOU', 'A-TK-IOU']

    @staticmethod
    def _get_eval_metrics(normed_intersection: float, normed_union: float, title: str,
                          tk_normed_intersection: float, tk_normed_union: float) -> Tuple[str, str]:
        # The epsilon keeps the ratios well-defined when a union is empty.
        iou = (1e-6 + normed_intersection) / (1e-6 + normed_union)
        tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union)
        return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,)

    def get_values_of_evaluation_metrics(self) -> List[str]:
        return \
            list(self._get_eval_metrics(
                self._normed_intersection_per_soundness[1],
                self._normed_union_per_soundness[1], 'Sounds',
                self._tk_normed_intersection_per_soundness[1],
                self._tk_normed_union_per_soundness[1])) + \
            list(self._get_eval_metrics(
                self._normed_intersection_per_soundness[0],
                self._normed_union_per_soundness[0], 'Mistakes',
                self._tk_normed_intersection_per_soundness[0],
                self._tk_normed_union_per_soundness[0])) + \
            list(self._get_eval_metrics(
                sum(self._normed_intersection_per_soundness),
                sum(self._normed_union_per_soundness), 'All',
                sum(self._tk_normed_intersection_per_soundness),
                sum(self._tk_normed_union_per_soundness)))

    def print_summaries(self) -> None:
        titles = self.get_titles_of_evaluation_metrics()
        vals = self.get_values_of_evaluation_metrics()
        nc = 2
        for r in range(3):
            print(', '.join('%s: %s' % (titles[nc * r + i], vals[nc * r + i])
                            for i in range(nc)))
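

# A minimal usage sketch with synthetic data. `_DemoConfig` is a hypothetical
# stand-in: the evaluator itself only reads
# `acceptable_min_intersection_threshold` from the config, but both the config
# and the forwarded `interpreter` are also consumed by
# `process_interpretations`, whose exact requirements depend on the project;
# `interpreter=None` is used purely for illustration and a real Interpreter
# may be needed in practice. Because of the relative imports, run this as a
# module (`python -m <package>.<this_module>`).
if __name__ == '__main__':

    class _DemoConfig:
        acceptable_min_intersection_threshold = 0.5

    rng = np.random.default_rng(0)
    batch = 4
    evaluator = BinaryInterpretationEvaluator2D(n_samples=batch, config=_DemoConfig())

    gt_masks = (rng.random((batch, 8, 8)) >= 0.5).astype(float)  # BxWxH binary GT masks
    model_maps = rng.random((batch, 8, 8))                       # BxWxH continuous saliency maps
    probs = rng.random(batch)                                    # positive-class probabilities
    labels = rng.integers(0, 2, size=batch)                      # ground-truth class labels

    evaluator.update_summaries(
        model_maps, gt_masks, probs, labels,
        batch_inds=np.arange(batch), interpreter=None)
    evaluator.print_summaries()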