from typing import List, TYPE_CHECKING, Dict, Callable
from functools import partial

import numpy as np
from torch import Tensor

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class BinForetellerEvaluator(Evaluator):
    """Evaluator for binary predictions.

    Tracks per-key confusion-matrix counts (TP/FP/TN/FN) for every model output
    whose key starts with the given prefix, and reports accuracy, sensitivity,
    specificity, and their average for each of those keys.
    """

    def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)

        # Confusion-matrix counts, keyed by the matching model output name.
        self._tp_by_kw: Dict[str, int] = dict()
        self._fp_by_kw: Dict[str, int] = dict()
        self._tn_by_kw: Dict[str, int] = dict()
        self._fn_by_kw: Dict[str, int] = dict()

        self._kw_prefix = kw_prefix

    def reset(self) -> None:
        """Clears all accumulated confusion-matrix counts."""
        self._tp_by_kw = dict()
        self._fp_by_kw = dict()
        self._tn_by_kw = dict()
        self._fn_by_kw = dict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:
        gt = self.data_loader.get_current_batch_samples_labels()

        # Accumulate confusion-matrix counts for every output key matching the prefix.
        for k in model_output.keys():
            if k.startswith(self._kw_prefix):
                if k not in self._tp_by_kw:
                    self._tp_by_kw[k] = 0
                    self._tn_by_kw[k] = 0
                    self._fp_by_kw[k] = 0
                    self._fn_by_kw[k] = 0

                # Binarize the predicted probabilities with a 0.5 threshold.
                pred = (model_output[k].cpu().numpy() >= 0.5).astype(int)
                self._tp_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 1))
                self._fp_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 1))
                self._tn_by_kw[k] += np.sum(np.logical_and(gt == 0, pred == 0))
                self._fn_by_kw[k] += np.sum(np.logical_and(gt == 1, pred == 0))

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return [f'{k}_{metric}'
                for k in self._tp_by_kw.keys()
                for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

    def _get_values_of_evaluation_metrics(self, kw: str) -> List[str]:
        tp = self._tp_by_kw[kw]
        tn = self._tn_by_kw[kw]
        fp = self._fp_by_kw[kw]
        fn = self._fn_by_kw[kw]

        p = tp + fn
        n = tn + fp

        # Metrics are reported as percentages; -1 marks an undefined metric
        # (no samples observed for the corresponding denominator).
        if p + n > 0:
            accuracy = (tp + tn) * 100.0 / (p + n)
        else:
            accuracy = -1

        if p > 0:
            sensitivity = 100.0 * tp / p
        else:
            sensitivity = -1

        if n > 0:
            specificity = 100.0 * tn / n
        else:
            specificity = -1

        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return ['%.2f' % accuracy, '%.2f' % sensitivity, '%.2f' % specificity, '%.2f' % avg_ss]

    def get_values_of_evaluation_metrics(self) -> List[str]:
        return [
            v
            for k in self._tp_by_kw.keys()
            for v in self._get_values_of_evaluation_metrics(k)]

    @classmethod
    def standard_creator(cls, prefix_kw: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']:
        """Returns a factory that creates the evaluator with the keyword prefix already bound."""
        return partial(cls, prefix_kw)

    def print_evaluation_metrics(self, title: str) -> None:
        """Prints the evaluation metrics in a readable, per-key format."""
        print(f'{title}:')
        for k in self._tp_by_kw.keys():
            print(
                f'\t{k}:: ' +
                ', '.join([
                    f'{m_name}: {m_val}'
                    for m_name, m_val in zip(
                        ['Acc', 'Sens', 'Spec', 'AvgSS'],
                        self._get_values_of_evaluation_metrics(k))]))
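
# Usage sketch (illustrative only; `model`, `data_loader`, `conf`, `model_output`,
# and the 'foretell_' prefix below are hypothetical and assumed to be supplied by
# the surrounding training/evaluation loop, not defined in this module):
#
#     create_evaluator = BinForetellerEvaluator.standard_creator('foretell_')
#     evaluator = create_evaluator(model, data_loader, conf)
#     evaluator.update_summaries_based_on_model_output(model_output)
#     evaluator.print_evaluation_metrics('Validation')
#     evaluator.reset()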