|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- from typing import List, TYPE_CHECKING, Tuple
-
- import numpy as np
- from skimage import measure
-
- from .utils import process_interpretations
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
- from . import Interpreter
-
-
class BinaryInterpretationEvaluator2D:
    """Accumulate IoU metrics between binarized 2D interpretations and GT masks.

    Statistics are kept per "soundness" bucket: index 1 collects samples whose
    predicted class equals the ground-truth label, index 0 collects the rest.
    Two IoU variants are tracked:

    * IOU    -- model interpretations binarized by ``process_interpretations``
                (a pixel is "on" when the processed value is > 0).
    * TK-IOU -- clipped model interpretations reduced to (roughly) their top-k
                pixels, where k is the number of "on" pixels in the GT mask.

    Intersections and unions are divided by the spatial size of each batch and
    summed across calls to ``update_summaries``; the reported IoU is the ratio
    of these running sums.
    """

    def __init__(self, n_samples: int, config: 'BaseConfig'):
        self._config = config

        # Neither of the next two attributes is read by any method of this
        # class; they are kept so external code relying on them keeps working.
        self._n_samples = n_samples
        self._min_intersection_threshold = config.acceptable_min_intersection_threshold

        self.reset()

    def reset(self) -> None:
        """Zero all accumulated statistics."""
        # Index 0: mistakes (prediction != label); index 1: sound predictions.
        self._normed_intersection_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._normed_union_per_soundness: np.ndarray = np.zeros(2, dtype=float)

        self._tk_normed_intersection_per_soundness: np.ndarray = np.zeros(2, dtype=float)
        self._tk_normed_union_per_soundness: np.ndarray = np.zeros(2, dtype=float)

    @staticmethod
    def _to_class_labels(net_preds: np.ndarray) -> np.ndarray:
        """Convert raw network outputs into integer class labels."""
        # A single output (1-D, or one column) is treated as a probability and
        # thresholded at 0.5; otherwise the arg-max over the last axis is used.
        if net_preds.ndim == 1:
            return (net_preds >= 0.5).astype(int)
        if net_preds.shape[1] == 1:
            return (net_preds[:, 0] >= 0.5).astype(int)
        return net_preds.argmax(axis=-1)

    @staticmethod
    def _top_k_mask(interpretation: np.ndarray, ground_truth: np.ndarray) -> np.ndarray:
        """Keep roughly the k highest pixels of one map, k = number of GT-on pixels."""
        # The quantile is chosen from the GT size; 1 is added because the
        # comparison below is strict (>) rather than >=.
        q = (1 + np.sum(ground_truth)) / (ground_truth.shape[-1] * ground_truth.shape[-2])
        if q < 1:
            # The threshold never drops below 0, so non-positive values are never kept.
            th = max(0, np.quantile(interpretation.reshape(-1), 1 - q))
            return interpretation > th
        return interpretation > 0

    def update_summaries(
        self,
        m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray,
        net_preds: np.ndarray, ground_truth_labels: np.ndarray,
        batch_inds: np.ndarray, interpreter: 'Interpreter'
    ) -> None:
        """Add one batch to the running IoU statistics.

        Args:
            m_interpretations: model interpretation maps, a 3-D array.
            ground_truth_interpretations: GT interpretation maps, a 3-D array.
                Samples whose GT map contains any NaN are skipped entirely.
            net_preds: raw network outputs. A 1-D array or a single column is
                thresholded at 0.5; otherwise the arg-max of the last axis is used.
            ground_truth_labels: true labels, compared element-wise with the
                predicted labels to decide each sample's soundness bucket.
            batch_inds: accepted for interface compatibility; not used.
            interpreter: forwarded to ``process_interpretations``.

        Raises:
            AssertionError: if an interpretation array is not 3-D, or if the
                processed model maps and GT maps differ in spatial shape.
            ValueError: if the predicted labels and ``ground_truth_labels``
                (after masking) have different shapes.
        """
        assert len(ground_truth_interpretations.shape) == 3, f'GT interpretations must have a shape of BxWxH but it is {ground_truth_interpretations.shape}'
        assert len(m_interpretations.shape) == 3, f'Model interpretations must have a shape of BxWxH but it is {m_interpretations.shape}'

        # Skip samples without interpretations (GT maps containing NaN).
        has_interpretation_mask = ~np.any(np.isnan(ground_truth_interpretations), axis=(1, 2))
        if not has_interpretation_mask.any():
            # Nothing to accumulate; the reductions below would fail on empty input.
            return

        m_interpretations = m_interpretations[has_interpretation_mask]
        ground_truths = ground_truth_interpretations[has_interpretation_mask]
        # Predictions must be masked exactly like the labels to stay aligned.
        net_preds = np.asarray(net_preds)[has_interpretation_mask]
        ground_truth_labels = np.asarray(ground_truth_labels)[has_interpretation_mask]

        pred_labels = self._to_class_labels(net_preds)
        if pred_labels.shape != ground_truth_labels.shape:
            raise ValueError(
                f'Predicted labels of shape {pred_labels.shape} do not match '
                f'ground-truth labels of shape {ground_truth_labels.shape}'
            )

        # 1 for a correct prediction, 0 for a mistake (index into accumulators).
        soundnesses = (pred_labels == ground_truth_labels).astype(int)

        # Negative attributions never count as "on" in the top-k metric.
        c_interpretations = np.clip(m_interpretations, 0, None)

        b_interpretations = np.stack([
            process_interpretations(sample[None, ...], self._config, interpreter)
            for sample in m_interpretations
        ], axis=0)[:, 0, ...]
        b_interpretations = b_interpretations > 0
        # Re-binarize in case GT maps were resized/interpolated upstream.
        ground_truths = ground_truths >= 0.5

        assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], f'Ground truth and model interpretations must have the same shape, found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}'

        norm_factor = float(b_interpretations.shape[1] * b_interpretations.shape[2])

        # np.add.at accumulates correctly even when a bucket index repeats.
        np.add.at(self._normed_intersection_per_soundness, soundnesses,
                  np.sum(b_interpretations & ground_truths, axis=(1, 2)) / norm_factor)
        np.add.at(self._normed_union_per_soundness, soundnesses,
                  np.sum(b_interpretations | ground_truths, axis=(1, 2)) / norm_factor)

        # NOTE(review): the top-k mask is built from the unprocessed model maps,
        # which assumes process_interpretations preserves spatial size — confirm.
        for s, interpretation, gt in zip(soundnesses, c_interpretations, ground_truths):
            tints = self._top_k_mask(interpretation, gt)
            self._tk_normed_intersection_per_soundness[s] += np.sum(tints & gt) / norm_factor
            self._tk_normed_union_per_soundness[s] += np.sum(tints | gt) / norm_factor

    @staticmethod
    def get_titles_of_evaluation_metrics() -> List[str]:
        """Return metric titles: S = sound predictions, M = mistakes, A = all."""
        return ['S-IOU', 'S-TK-IOU',
                'M-IOU', 'M-TK-IOU',
                'A-IOU', 'A-TK-IOU']

    @staticmethod
    def _get_eval_metrics(normed_intersection, normed_union, title,
                          tk_normed_intersection, tk_normed_union) -> Tuple[str, str]:
        """Format IoU and top-k IoU as percentage strings with four decimals.

        ``title`` is accepted for interface compatibility but not used. A 1e-6
        term is added to numerator and denominator so that empty accumulators
        yield 100% instead of a division by zero.
        """
        iou = (1e-6 + normed_intersection) / (1e-6 + normed_union)
        tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union)

        return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,)

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return formatted metric values in the order of the metric titles."""
        values: List[str] = []
        # Bucket 1 (sound predictions) first, then bucket 0 (mistakes).
        for idx, title in ((1, 'Sounds'), (0, 'Mistakes')):
            values.extend(self._get_eval_metrics(
                self._normed_intersection_per_soundness[idx],
                self._normed_union_per_soundness[idx],
                title,
                self._tk_normed_intersection_per_soundness[idx],
                self._tk_normed_union_per_soundness[idx],
            ))
        values.extend(self._get_eval_metrics(
            sum(self._normed_intersection_per_soundness),
            sum(self._normed_union_per_soundness),
            'All',
            sum(self._tk_normed_intersection_per_soundness),
            sum(self._tk_normed_union_per_soundness),
        ))
        return values

    def print_summaries(self) -> None:
        """Print the metrics as three lines (sounds, mistakes, all), two per line."""
        titles = self.get_titles_of_evaluation_metrics()
        vals = self.get_values_of_evaluation_metrics()

        per_line = 2
        for start in range(0, len(titles), per_line):
            pairs = zip(titles[start:start + per_line], vals[start:start + per_line])
            print(', '.join(f'{t}: {v}' for t, v in pairs))
|