from typing import List, TYPE_CHECKING, Dict

import numpy as np
import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class BinaryEvaluator(Evaluator):

    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)

        # running confusion-matrix counts and sample-weighted loss,
        # accumulated across batches until reset() is called
        self.tp: int = 0
        self.tn: int = 0
        self.fp: int = 0
        self.fn: int = 0
        self.avg_loss: float = 0.0
        self.n_received_samples: int = 0

    def reset(self):
        self.tp = 0
        self.tn = 0
        self.fp = 0
        self.fn = 0
        self.avg_loss = 0.0
        self.n_received_samples = 0

    @property
    def _prob_key(self) -> str:
        return 'positive_class_probability'

    @property
    def _loss_key(self) -> str:
        return 'loss'

    def _get_current_batch_gt(self) -> np.ndarray:
        return self.data_loader.get_current_batch_samples_labels()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        assert self._prob_key in model_output, \
            f"model's output must contain {self._prob_key}"

        gt = self._get_current_batch_gt()
        # threshold the predicted positive-class probabilities at 0.5
        prediction = (model_output[self._prob_key].cpu().numpy() >= 0.5).astype(int)

        ntp = int(np.sum(np.logical_and(gt == 1, prediction == 1)))
        ntn = int(np.sum(np.logical_and(gt == 0, prediction == 0)))
        nfp = int(np.sum(np.logical_and(gt == 0, prediction == 1)))
        nfn = int(np.sum(np.logical_and(gt == 1, prediction == 0)))

        self.tp += ntp
        self.tn += ntn
        self.fp += nfp
        self.fn += nfn

        # sample-weighted running average of the loss; cast to float so no
        # reference to the (possibly GPU- or graph-attached) tensor is kept
        batch_loss = float(model_output.get(self._loss_key, 0.0))
        new_n = self.n_received_samples + len(gt)
        self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \
            batch_loss * (float(len(gt)) / new_n)

        self.n_received_samples = new_n

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return ['Loss', 'Acc', 'Sens', 'Spec', 'BAcc', 'N']

    def get_values_of_evaluation_metrics(self) -> List[str]:

        # p / n: numbers of actually-positive / actually-negative samples seen;
        # -1 marks a metric that is undefined because no such samples arrived
        p = self.tp + self.fn
        n = self.tn + self.fp

        if p + n > 0:
            accuracy = (self.tp + self.tn) * 100.0 / (p + n)
        else:
            accuracy = -1

        if p > 0:
            sensitivity = 100.0 * self.tp / p
        else:
            sensitivity = -1

        if n > 0:
            specificity = 100.0 * self.tn / n
        else:
            specificity = -1

        # balanced accuracy: mean of sensitivity and specificity when both are
        # defined, otherwise whichever of the two is available
        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return ['%.4f' % self.avg_loss, '%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss, str(self.n_received_samples)]
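

# --------------------------------------------------------------------------
# Minimal self-contained sketch (not part of the project API): it exercises
# the same 0.5-thresholding and confusion-matrix arithmetic that
# update_summaries_based_on_model_output / get_values_of_evaluation_metrics
# implement, on small synthetic tensors. The Model / DataLoader / BaseConfig
# objects needed to construct a real BinaryEvaluator are defined elsewhere in
# the project and are deliberately left out here.
if __name__ == '__main__':
    gt = np.array([1, 0, 1, 1, 0, 0])
    probs = torch.tensor([0.9, 0.2, 0.4, 0.8, 0.6, 0.1])

    # threshold at 0.5, exactly as the evaluator does
    pred = (probs.cpu().numpy() >= 0.5).astype(int)

    tp = int(np.sum(np.logical_and(gt == 1, pred == 1)))
    tn = int(np.sum(np.logical_and(gt == 0, pred == 0)))
    fp = int(np.sum(np.logical_and(gt == 0, pred == 1)))
    fn = int(np.sum(np.logical_and(gt == 1, pred == 0)))

    sensitivity = 100.0 * tp / (tp + fn)
    specificity = 100.0 * tn / (tn + fp)
    balanced_acc = 0.5 * (sensitivity + specificity)
    # expected output: Sens 66.67  Spec 66.67  BAcc 66.67
    print('Sens %.2f  Spec %.2f  BAcc %.2f' % (sensitivity, specificity, balanced_acc))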