
multiclass_evaluator.py (4.7 KB)

from typing import Any, Callable, List, Dict, TYPE_CHECKING, Optional
from functools import partial

import numpy as np
import torch
from sklearn.metrics import confusion_matrix

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


def _precision_score(cm: np.ndarray):
    # Macro-averaged precision over the classes that received at least one prediction.
    numerator = np.diag(cm)
    denominator = cm.sum(axis=0)
    return (numerator[denominator != 0] / denominator[denominator != 0]).mean()


def _recall_score(cm: np.ndarray):
    # Macro-averaged recall over the classes with at least one ground-truth sample.
    numerator = np.diag(cm)
    denominator = cm.sum(axis=1)
    return (numerator[denominator != 0] / denominator[denominator != 0]).mean()


def _accuracy_score(cm: np.ndarray):
    return np.diag(cm).sum() / cm.sum()


class MulticlassEvaluator(Evaluator):

    @classmethod
    def standard_creator(cls, class_probability_key: str = 'categorical_probability',
                         include_top5: bool = False) -> Callable[[Model, DataLoader, 'BaseConfig'], 'MulticlassEvaluator']:
        return partial(MulticlassEvaluator, class_probability_key=class_probability_key,
                       include_top5=include_top5)

    def __init__(self, model: Model,
                 data_loader: DataLoader,
                 conf: 'BaseConfig',
                 class_probability_key: str = 'categorical_probability',
                 include_top5: bool = False):
        super().__init__(model, data_loader, conf)

        # for summarized information
        self._avg_loss: float = 0
        self._n_received_samples: int = 0
        self._trues = {}
        self._falses = {}
        self._top5_trues: int = 0
        self._top5_falses: int = 0
        self._num_iters = 0

        self._predictions: Dict[int, np.ndarray] = {}
        self._prediction_weights: Dict[int, np.ndarray] = {}
        self._sample_details: Dict[str, Dict[str, Any]] = {}
        self._cm: Optional[np.ndarray] = None
        self._c = None

        self._class_probability_key = class_probability_key
        self._include_top5 = include_top5

    def reset(self):
        # for summarized information
        self._avg_loss: float = 0
        self._n_received_samples: int = 0
        self._trues = {}
        self._falses = {}
        self._top5_trues: int = 0
        self._top5_falses: int = 0
        self._num_iters = 0
        self._cm: Optional[np.ndarray] = None

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        assert self._class_probability_key in model_output, \
            f"model's output dictionary must contain {self._class_probability_key}"

        prediction: np.ndarray = model_output[self._class_probability_key]\
            .cpu().numpy()  # B C
        c = prediction.shape[1]
        ground_truth = self.data_loader.get_current_batch_samples_labels()  # B

        if self._include_top5:
            # A sample counts as a top-5 hit if its label is among the 5 highest-probability classes.
            top5_p = prediction.argsort(axis=1)[:, -5:]
            trues = (top5_p == ground_truth[:, None]).any(axis=1).sum()
            self._top5_trues += trues
            self._top5_falses += len(ground_truth) - trues

        prediction = prediction.argmax(axis=1)  # B

        # per-class counts of correctly and incorrectly classified ground-truth samples
        for i in range(c):
            nt = int(np.sum(np.logical_and(ground_truth == i, prediction == i)))
            nf = int(np.sum(np.logical_and(ground_truth == i, prediction != i)))
            if i not in self._trues:
                self._trues[i] = 0
                self._falses[i] = 0
            self._trues[i] += nt
            self._falses[i] += nf

        # accumulate the confusion matrix across batches
        self._cm = (0 if self._cm is None else self._cm) \
            + confusion_matrix(ground_truth, prediction,
                               labels=np.arange(c)).astype(float)

        self._num_iters += 1

        # running average of the loss, weighted by the number of samples seen so far
        new_n = self._n_received_samples + len(ground_truth)
        self._avg_loss = self._avg_loss * (float(self._n_received_samples) / new_n) + \
            model_output.get('loss', 0.0) * (float(len(ground_truth)) / new_n)
        self._n_received_samples = new_n

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return ['Loss', 'Accuracy', 'Precision', 'Recall'] + \
            (['Top5Acc'] if self._include_top5 else [])

    def get_values_of_evaluation_metrics(self) -> List[str]:
        accuracy = _accuracy_score(self._cm) * 100
        precision = _precision_score(self._cm) * 100
        recall = _recall_score(self._cm) * 100
        top5_acc = self._top5_trues * 100.0 / (self._top5_trues + self._top5_falses) \
            if self._include_top5 else None
        return [f'{self._avg_loss:.4e}', f'{accuracy:8.4f}', f'{precision:8.4f}', f'{recall:8.4f}'] + \
            ([f'{top5_acc:8.4f}'] if self._include_top5 else [])
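
For reference, here is a minimal standalone sketch (not part of the file above, with hand-picked illustrative numbers) of the arithmetic the metric helpers perform on an accumulated confusion matrix. It skips the zero-denominator masking that _precision_score and _recall_score apply, since no row or column is empty in this toy example.

    import numpy as np

    # Toy 3-class confusion matrix; rows are ground-truth classes, columns are predictions.
    cm = np.array([[5., 1., 0.],
                   [2., 7., 1.],
                   [0., 0., 4.]])

    accuracy = np.diag(cm).sum() / cm.sum()            # (5 + 7 + 4) / 20 = 0.80
    precision = (np.diag(cm) / cm.sum(axis=0)).mean()  # macro average over predicted-class columns
    recall = (np.diag(cm) / cm.sum(axis=1)).mean()     # macro average over ground-truth rows

    print(f'accuracy={accuracy:.4f}  precision={precision:.4f}  recall={recall:.4f}')
    # accuracy=0.8000  precision=0.7964  recall=0.8444

The helpers additionally drop classes whose column or row sums to zero, so classes that never occur in the evaluated data do not drag the macro averages down. Note also that standard_creator returns a functools.partial rather than an instance: the framework can later call it with (model, data_loader, conf) to build the evaluator while class_probability_key and include_top5 are fixed up front.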