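"""Evaluator that tracks binary-classification metrics (accuracy, sensitivity,
specificity, and their average) separately for every model output whose key
starts with a given prefix."""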
from typing import List, TYPE_CHECKING, Dict, Callable
from functools import partial

import numpy as np
from torch import Tensor

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class BinForetellerEvaluator(Evaluator):
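    """Accumulates binary confusion counts (TP/FP/TN/FN) for every model output
    whose key starts with `kw_prefix` and reports accuracy, sensitivity,
    specificity, and their average per key."""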

    def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)

        self._tp_by_kw: Dict[str, int] = dict()
        self._fp_by_kw: Dict[str, int] = dict()
        self._tn_by_kw: Dict[str, int] = dict()
        self._fn_by_kw: Dict[str, int] = dict()

        self._kw_prefix = kw_prefix

    def reset(self):
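        """Clears the accumulated confusion counts for all tracked output keys."""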
        self._tp_by_kw = dict()
        self._fp_by_kw = dict()
        self._tn_by_kw = dict()
        self._fn_by_kw = dict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:
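        """Thresholds each prefixed output at 0.5 and adds the resulting TP/FP/TN/FN
        counts for the current batch to the per-key totals."""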

        gt = self.data_loader.get_current_batch_samples_labels()

        # accumulate confusion counts for every output key that carries the prefix
        for k in model_output.keys():
            if k.startswith(self._kw_prefix):

                if k not in self._tp_by_kw:
                    self._tp_by_kw[k] = 0
                    self._tn_by_kw[k] = 0
                    self._fp_by_kw[k] = 0
                    self._fn_by_kw[k] = 0

                # binarize predicted probabilities at 0.5
                pred = (model_output[k].detach().cpu().numpy() >= 0.5).astype(int)
                self._tp_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 1)))
                self._fp_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 1)))
                self._tn_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 0)))
                self._fn_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 0)))

    def get_titles_of_evaluation_metrics(self) -> List[str]:
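        """Returns one title per tracked output key and metric, e.g. '<key>_Acc'."""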
        return [f'{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

    def _get_values_of_evaluation_metrics(self, kw: str) -> List[str]:
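        """Computes accuracy, sensitivity, specificity, and their average (in percent)
        for a single output key; a metric whose denominator is zero is reported as -1."""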

        tp = self._tp_by_kw[kw]
        tn = self._tn_by_kw[kw]
        fp = self._fp_by_kw[kw]
        fn = self._fn_by_kw[kw]

        p = tp + fn
        n = tn + fp

        if p + n > 0:
            accuracy = (tp + tn) * 100.0 / (p + n)
        else:
            accuracy = -1

        if p > 0:
            sensitivity = 100.0 * tp / p
        else:
            sensitivity = -1

        if n > 0:
            specificity = 100.0 * tn / n
        else:
            specificity = -1

        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return ['%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss]

    def get_values_of_evaluation_metrics(self) -> List[str]:
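        """Returns the formatted metric values for all tracked keys, in the same order
        as get_titles_of_evaluation_metrics."""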
        return [
            v
            for k in self._tp_by_kw.keys()
            for v in self._get_values_of_evaluation_metrics(k)]

    @classmethod
    def standard_creator(cls, kw_prefix: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']:
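        """Returns a factory with the keyword prefix pre-bound; the factory can then be
        called as factory(model, data_loader, conf) to construct the evaluator."""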
        return partial(cls, kw_prefix)

    def print_evaluation_metrics(self, title: str) -> None:
        """Prints the metrics of all tracked output keys in a readable, per-key layout."""

        print(f'{title}:')
        for k in self._tp_by_kw.keys():
            print(
                f'\t{k}:: ' +
                ', '.join([f'{m_name}: {m_val}'
                           for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'],
                                                    self._get_values_of_evaluation_metrics(k))]))
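

# Usage sketch (illustrative only; the 'foretell_' prefix and the surrounding
# evaluation loop are assumptions, not part of this module):
#
#     creator = BinForetellerEvaluator.standard_creator('foretell_')
#     evaluator = creator(model, data_loader, conf)
#     evaluator.reset()
#     for model_output in ...:  # one Dict[str, Tensor] per batch
#         evaluator.update_summaries_based_on_model_output(model_output)
#     evaluator.print_evaluation_metrics('Validation')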