
binary_fortelling.py

from typing import List, TYPE_CHECKING, Dict, Callable
from functools import partial

import numpy as np
from torch import Tensor

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class BinForetellerEvaluator(Evaluator):
    """Evaluator for binary "foreteller" model outputs.

    For every model output whose name starts with `kw_prefix`, it accumulates
    a confusion matrix and reports accuracy, sensitivity, specificity, and the
    average of sensitivity and specificity.
    """

    def __init__(self, kw_prefix: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)
        # Per-output confusion counts, keyed by the output's name.
        self._tp_by_kw: Dict[str, int] = dict()
        self._fp_by_kw: Dict[str, int] = dict()
        self._tn_by_kw: Dict[str, int] = dict()
        self._fn_by_kw: Dict[str, int] = dict()
        self._kw_prefix = kw_prefix

    def reset(self):
        # Drop all accumulated confusion counts.
        self._tp_by_kw = dict()
        self._fp_by_kw = dict()
        self._tn_by_kw = dict()
        self._fn_by_kw = dict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:
        gt = self.data_loader.get_current_batch_samples_labels()
        # Update the confusion counts of every output whose name matches the prefix.
        for k in model_output.keys():
            if k.startswith(self._kw_prefix):
                if k not in self._tp_by_kw:
                    self._tp_by_kw[k] = 0
                    self._tn_by_kw[k] = 0
                    self._fp_by_kw[k] = 0
                    self._fn_by_kw[k] = 0
                # Threshold the (assumed-probability) outputs at 0.5; detach in
                # case the tensor still carries gradient history.
                pred = (model_output[k].detach().cpu().numpy() >= 0.5).astype(int)
                self._tp_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 1)))
                self._fp_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 1)))
                self._tn_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 0)))
                self._fn_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 0)))
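
    # Illustrative example (values not from the codebase): with labels
    # gt = [1, 0, 1, 0] and head outputs [0.9, 0.2, 0.4, 0.7], thresholding at
    # 0.5 gives pred = [1, 0, 0, 1], so this batch adds
    # tp += 1, tn += 1, fn += 1, fp += 1 to that head's counters.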

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        # One title per (output, metric) pair, in the same order that
        # get_values_of_evaluation_metrics emits the values.
        return [f'{k}_{metric}' for k in self._tp_by_kw.keys() for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

    def _get_values_of_evaluation_metrics(self, kw: str) -> List[str]:
        tp = self._tp_by_kw[kw]
        tn = self._tn_by_kw[kw]
        fp = self._fp_by_kw[kw]
        fn = self._fn_by_kw[kw]
        p = tp + fn
        n = tn + fp
        # -1 marks a metric that is undefined because no samples of the
        # required class have been seen yet.
        if p + n > 0:
            accuracy = (tp + tn) * 100.0 / (p + n)
        else:
            accuracy = -1
        if p > 0:
            sensitivity = 100.0 * tp / p
        else:
            sensitivity = -1
        if n > 0:
            specificity = 100.0 * tn / n
        else:
            specificity = -1
        # Average of sensitivity and specificity, falling back to whichever
        # one is defined.
        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity
        return ['%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss]
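
    # Worked example (illustrative numbers): tp=8, fn=2, tn=85, fp=5 gives
    # p=10 and n=90, so accuracy = 100*93/100 = 93.00, sensitivity =
    # 100*8/10 = 80.00, specificity = 100*85/90 = 94.44, and
    # avg_ss = 0.5 * (80.00 + 94.44) = 87.22.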

    def get_values_of_evaluation_metrics(self) -> List[str]:
        # Flatten the per-output metric values in title order.
        return [
            v
            for k in self._tp_by_kw.keys()
            for v in self._get_values_of_evaluation_metrics(k)]

    @classmethod
    def standard_creator(cls, prefix_kw: str) -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinForetellerEvaluator']:
        # Bind the keyword prefix, returning a factory with the
        # (model, data_loader, conf) signature the framework expects.
        return partial(cls, prefix_kw)
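
    # Hypothetical usage (the 'foretell_' prefix is illustrative): a config
    # expecting an evaluator factory can be handed
    # BinForetellerEvaluator.standard_creator('foretell_'), which is later
    # called as creator(model, data_loader, conf).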

    def print_evaluation_metrics(self, title: str) -> None:
        """For more readable printing!"""
        print(f'{title}:')
        for k in self._tp_by_kw.keys():
            print(
                f'\t{k}:: ' +
                ', '.join([f'{m_name}: {m_val}'
                           for m_name, m_val in zip(['Acc', 'Sens', 'Spec', 'AvgSS'],
                                                    self._get_values_of_evaluation_metrics(k))]))
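
A minimal usage sketch, assuming the framework's model, data_loader, and conf objects are already built and the model returns a dict containing at least one probability tensor whose name starts with the chosen prefix (the 'foretell_' prefix and the batch loop below are illustrative, not part of this module):

creator = BinForetellerEvaluator.standard_creator('foretell_')
evaluator = creator(model, data_loader, conf)

evaluator.reset()
for batch in batches:  # batch iteration is framework-specific
    model_output = model(batch)  # e.g. {'foretell_risk': tensor of probabilities}
    evaluator.update_summaries_based_on_model_output(model_output)
evaluator.print_evaluation_metrics('Validation')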