You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

binary_interpretation_evaluator_2d.py 7.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from typing import List, TYPE_CHECKING, Tuple
  2. import numpy as np
  3. from skimage import measure
  4. from .utils import process_interpretations
  5. if TYPE_CHECKING:
  6. from ..configs.base_config import BaseConfig
  7. from . import Interpreter
  8. class BinaryInterpretationEvaluator2D:
  9. def __init__(self, n_samples: int, config: 'BaseConfig'):
  10. self._config = config
  11. self._n_samples = n_samples
  12. self._min_intersection_threshold = config.acceptable_min_intersection_threshold
  13. self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  14. self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  15. self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  16. self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  17. def reset(self):
  18. self._normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  19. self._normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  20. self._tk_normed_intersection_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  21. self._tk_normed_union_per_soundness: np.ndarray = np.asarray([0, 0], dtype=float)
  22. def update_summaries(
  23. self,
  24. m_interpretations: np.ndarray, ground_truth_interpretations: np.ndarray,
  25. net_preds: np.ndarray, ground_truth_labels: np.ndarray,
  26. batch_inds: np.ndarray, interpreter: 'Interpreter'
  27. ) -> None:
  28. assert len(ground_truth_interpretations.shape) == 3, f'GT interpretations must have a shape of BxWxH but it is {ground_truth_interpretations.shape}'
  29. assert len(m_interpretations.shape) == 3, f'Model interpretations must have a shape of BxWxH but it is {m_interpretations.shape}'
  30. # skipping samples without interpretations!
  31. has_interpretation_mask = np.logical_not(np.any(np.isnan(ground_truth_interpretations), axis=(1, 2)))
  32. m_interpretations = m_interpretations[has_interpretation_mask]
  33. ground_truths = ground_truth_interpretations[has_interpretation_mask]
  34. net_preds[has_interpretation_mask]
  35. # finding class labels
  36. if len(net_preds.shape) == 1:
  37. net_preds = (net_preds >= 0.5).astype(int)
  38. elif net_preds.shape[1] == 1:
  39. net_preds = (net_preds[:, 0] >= 0.5).astype(int)
  40. else:
  41. net_preds = net_preds.argmax(axis=-1)
  42. ground_truth_labels = ground_truth_labels[has_interpretation_mask]
  43. batch_inds = batch_inds[has_interpretation_mask]
  44. # Checking shapes
  45. if net_preds.shape == ground_truth_labels.shape:
  46. net_preds = np.round(net_preds, 0).astype(int)
  47. else:
  48. net_preds = np.argmax(net_preds, axis=1)
  49. # calculating soundness
  50. soundnesses = (net_preds == ground_truth_labels).astype(int)
  51. c_interpretations = np.clip(m_interpretations, 0, np.amax(m_interpretations))
  52. b_interpretations = np.stack(tuple([
  53. process_interpretations(m_interpretations[ind][None, ...], self._config, interpreter)
  54. for ind in range(len(m_interpretations))
  55. ]), axis=0)[:, 0, ...]
  56. b_interpretations = (b_interpretations > 0).astype(bool)
  57. ground_truths = (ground_truths >= 0.5).astype(bool) #making sure values are 0 and 1 even if resize has been applied
  58. assert ground_truths.shape[-2:] == b_interpretations.shape[-2:], f'Ground truth and model interpretations must have the same shape, found {ground_truths.shape[-2:]} and {b_interpretations.shape[-2:]}'
  59. norm_factor = 1.0 * b_interpretations.shape[1] * b_interpretations.shape[2]
  60. np.add.at(self._normed_intersection_per_soundness, soundnesses,
  61. np.sum(b_interpretations & ground_truths, axis=(1, 2)) * 1.0 / norm_factor)
  62. np.add.at(self._normed_union_per_soundness, soundnesses,
  63. np.sum(b_interpretations | ground_truths, axis=(1, 2)) * 1.0 / norm_factor)
  64. for i in range(len(b_interpretations)):
  65. has_nonzero_captured_bbs = False
  66. has_nonzero_captured_bbs_by_topk = False
  67. s = soundnesses[i]
  68. org_labels = measure.label(ground_truths[i, :, :])
  69. check_labels = measure.label(b_interpretations[i, :, :])
  70. # keeping topK interpretations with k = n_GT! = finding a threshold by quantile! calculating quantile by GT
  71. n_on_gt = np.sum(ground_truths[i])
  72. q = (1 + n_on_gt) * 1.0 / (ground_truths.shape[-1] * ground_truths.shape[-2])
  73. # 1 is added because we have > in thresholding not >=
  74. if q < 1:
  75. tints = c_interpretations[i]
  76. th = max(0, np.quantile(tints.reshape(-1), 1 - q))
  77. tints = (tints > th)
  78. else:
  79. tints = (c_interpretations[i] > 0)
  80. # TOPK METRICS
  81. tk_intersection = np.sum(tints & ground_truths[i])
  82. tk_union = np.sum(tints | ground_truths[i])
  83. self._tk_normed_intersection_per_soundness[s] += tk_intersection * 1.0 / norm_factor
  84. self._tk_normed_union_per_soundness[s] += tk_union * 1.0 / norm_factor
  85. @staticmethod
  86. def get_titles_of_evaluation_metrics() -> List[str]:
  87. return ['S-IOU', 'S-TK-IOU',
  88. 'M-IOU', 'M-TK-IOU',
  89. 'A-IOU', 'A-TK-IOU']
  90. @staticmethod
  91. def _get_eval_metrics(normed_intersection, normed_union, title,
  92. tk_normed_intersection, tk_normed_union) -> \
  93. Tuple[str, str, str, str, str]:
  94. iou = (1e-6 + normed_intersection) / (1e-6 + normed_union)
  95. tk_iou = (1e-6 + tk_normed_intersection) / (1e-6 + tk_normed_union)
  96. return '%.4f' % (iou * 100,), '%.4f' % (tk_iou * 100,)
  97. def get_values_of_evaluation_metrics(self) -> List[str]:
  98. return \
  99. list(self._get_eval_metrics(
  100. self._normed_intersection_per_soundness[1],
  101. self._normed_union_per_soundness[1],
  102. 'Sounds',
  103. self._tk_normed_intersection_per_soundness[1],
  104. self._tk_normed_union_per_soundness[1],
  105. )) + \
  106. list(self._get_eval_metrics(
  107. self._normed_intersection_per_soundness[0],
  108. self._normed_union_per_soundness[0],
  109. 'Mistakes',
  110. self._tk_normed_intersection_per_soundness[0],
  111. self._tk_normed_union_per_soundness[0],
  112. )) + \
  113. list(self._get_eval_metrics(
  114. sum(self._normed_intersection_per_soundness),
  115. sum(self._normed_union_per_soundness),
  116. 'All',
  117. sum(self._tk_normed_intersection_per_soundness),
  118. sum(self._tk_normed_union_per_soundness),
  119. ))
  120. def print_summaries(self):
  121. titles = self.get_titles_of_evaluation_metrics()
  122. vals = self.get_values_of_evaluation_metrics()
  123. nc = 2
  124. for r in range(3):
  125. print(', '.join(['%s: %s' % (titles[nc * r + i], vals[nc * r + i]) for i in range(nc)]))