- from typing import Any, Callable, List, Dict, TYPE_CHECKING, Optional
- from functools import partial
-
- import numpy as np
- import torch
- from sklearn.metrics import confusion_matrix
-
- from ..data.data_loader import DataLoader
- from ..models.model import Model
- from ..model_evaluation.evaluator import Evaluator
- if TYPE_CHECKING:
-     from ..configs.base_config import BaseConfig
-
-
- def _precision_score(cm: np.ndarray) -> float:
-     # Macro-averaged precision: per-class TP / predicted-positives (column sums),
-     # ignoring classes that were never predicted (zero column sum).
-     numerator = np.diag(cm)
-     denominator = cm.sum(axis=0)
-     return (numerator[denominator != 0] / denominator[denominator != 0]).mean()
-
-
- def _recall_score(cm: np.ndarray) -> float:
-     # Macro-averaged recall: per-class TP / actual-positives (row sums),
-     # ignoring classes with no ground-truth samples (zero row sum).
-     numerator = np.diag(cm)
-     denominator = cm.sum(axis=1)
-     return (numerator[denominator != 0] / denominator[denominator != 0]).mean()
-
-
- def _accuracy_score(cm: np.ndarray) -> float:
-     # Overall accuracy: correctly classified samples (diagonal) over all samples.
-     return np.diag(cm).sum() / cm.sum()
-
-
- class MulticlassEvaluator(Evaluator):
-
-     @classmethod
-     def standard_creator(cls, class_probability_key: str = 'categorical_probability',
-                          include_top5: bool = False) -> Callable[[Model, DataLoader, 'BaseConfig'], 'MulticlassEvaluator']:
-         """Returns a factory that builds the evaluator from (model, data_loader, conf)."""
-         return partial(cls, class_probability_key=class_probability_key, include_top5=include_top5)
-
-     def __init__(self, model: Model,
-                  data_loader: DataLoader,
-                  conf: 'BaseConfig',
-                  class_probability_key: str = 'categorical_probability',
-                  include_top5: bool = False):
-         super().__init__(model, data_loader, conf)
-
-         # for summarized information
-         self._avg_loss: float = 0.0
-         self._n_received_samples: int = 0
-         self._trues: Dict[int, int] = {}   # per-class true positives
-         self._falses: Dict[int, int] = {}  # per-class false negatives
-         self._top5_trues: int = 0
-         self._top5_falses: int = 0
-         self._num_iters = 0
-         self._predictions: Dict[int, np.ndarray] = {}
-         self._prediction_weights: Dict[int, np.ndarray] = {}
-         self._sample_details: Dict[str, Dict[str, Any]] = {}
-         self._cm: Optional[np.ndarray] = None  # accumulated confusion matrix
-         self._c = None
-         self._class_probability_key = class_probability_key
-         self._include_top5 = include_top5
-
-     def reset(self):
-         # for summarized information
-         self._avg_loss = 0.0
-         self._n_received_samples = 0
-         self._trues = {}
-         self._falses = {}
-         self._top5_trues = 0
-         self._top5_falses = 0
-         self._num_iters = 0
-         self._cm = None
-
-     def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
-
-         assert self._class_probability_key in model_output, \
-             f"model's output dictionary must contain {self._class_probability_key}"
-
-         prediction: np.ndarray = model_output[self._class_probability_key]\
-             .detach().cpu().numpy()  # shape: (B, C)
-         c = prediction.shape[1]
-         ground_truth = self.data_loader.get_current_batch_samples_labels()  # shape: (B,)
-
-         if self._include_top5:
-             # a sample counts as a top-5 hit if its label is among the 5 highest-scoring classes
-             top5_p = prediction.argsort(axis=1)[:, -5:]
-             trues = int((top5_p == ground_truth[:, None]).any(axis=1).sum())
-             self._top5_trues += trues
-             self._top5_falses += len(ground_truth) - trues
-
-         prediction = prediction.argmax(axis=1)  # shape: (B,)
-
-         # accumulate per-class true positives and false negatives
-         for i in range(c):
-             nt = int(np.sum(np.logical_and(ground_truth == i, prediction == i)))
-             nf = int(np.sum(np.logical_and(ground_truth == i, prediction != i)))
-             self._trues[i] = self._trues.get(i, 0) + nt
-             self._falses[i] = self._falses.get(i, 0) + nf
-
-         self._cm = (0 if self._cm is None else self._cm)\
-             + confusion_matrix(ground_truth, prediction,
-                                labels=np.arange(c)).astype(float)
-         self._num_iters += 1
-
-         # running, sample-weighted average of the loss
-         loss = float(model_output['loss']) if 'loss' in model_output else 0.0
-         new_n = self._n_received_samples + len(ground_truth)
-         self._avg_loss = self._avg_loss * (float(self._n_received_samples) / new_n) + \
-             loss * (float(len(ground_truth)) / new_n)
-
-         self._n_received_samples = new_n
-
-     def get_titles_of_evaluation_metrics(self) -> List[str]:
-         return ['Loss', 'Accuracy', 'Precision', 'Recall'] + \
-             (['Top5Acc'] if self._include_top5 else [])
-
-     def get_values_of_evaluation_metrics(self) -> List[str]:
-         assert self._cm is not None, 'no batch has been evaluated yet'
-         accuracy = _accuracy_score(self._cm) * 100
-         precision = _precision_score(self._cm) * 100
-         recall = _recall_score(self._cm) * 100
-         top5_acc = self._top5_trues * 100.0 / (self._top5_trues + self._top5_falses) \
-             if self._include_top5 else None
-
-         return [f'{self._avg_loss:.4e}', f'{accuracy:8.4f}', f'{precision:8.4f}', f'{recall:8.4f}'] + \
-             ([f'{top5_acc:8.4f}'] if self._include_top5 else [])
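
As a reference for reviewers, here is a minimal, self-contained sketch of what the private metric helpers compute, using a toy 3-class confusion matrix; the arrays below are illustrative only, and the snippet depends on nothing from this module beyond numpy and scikit-learn:

import numpy as np
from sklearn.metrics import confusion_matrix

# Toy 3-class batch: rows of cm are ground-truth classes, columns are predictions.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
cm = confusion_matrix(y_true, y_pred, labels=np.arange(3)).astype(float)

# _accuracy_score: correct samples (diagonal) over all samples -> 4/6
accuracy = np.diag(cm).sum() / cm.sum()

# _precision_score: per-class TP / column sum, macro-averaged over classes
# that were predicted at least once -> mean(1/2, 2/3, 1/1)
col_sums = cm.sum(axis=0)
precision = (np.diag(cm)[col_sums != 0] / col_sums[col_sums != 0]).mean()

# _recall_score: per-class TP / row sum, macro-averaged over classes
# that actually occur in the labels -> mean(1/2, 2/2, 1/2)
row_sums = cm.sum(axis=1)
recall = (np.diag(cm)[row_sums != 0] / row_sums[row_sums != 0]).mean()

print(f'{accuracy:.3f} {precision:.3f} {recall:.3f}')  # 0.667 0.722 0.667

get_values_of_evaluation_metrics reports exactly these quantities (scaled by 100), computed on the confusion matrix the evaluator accumulates across batches.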