
binary_faithfulness.py

from functools import partial
from typing import List, TYPE_CHECKING, Dict, Callable

import numpy as np
from torch import Tensor

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class BinFaithfulnessEvaluator(Evaluator):
    """ Compares each model output whose key starts with kw_prefix against the
    ground-truth output gt_kw, thresholding both at 0.5, and accumulates the
    per-key confusion-matrix counts behind the faithfulness metrics. """

    def __init__(self, kw_prefix: str, gt_kw: str, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)
        self._tp_by_kw: Dict[str, int] = dict()
        self._fp_by_kw: Dict[str, int] = dict()
        self._tn_by_kw: Dict[str, int] = dict()
        self._fn_by_kw: Dict[str, int] = dict()
        self._kw_prefix = kw_prefix
        self._gt_kw = gt_kw

    def reset(self):
        self._tp_by_kw = dict()
        self._fp_by_kw = dict()
        self._tn_by_kw = dict()
        self._fn_by_kw = dict()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, Tensor]) -> None:
        gt = (model_output[self._gt_kw].detach().cpu().numpy() >= 0.5).astype(int)
        # Update the confusion matrix of every output key carrying the prefix.
        for k in model_output.keys():
            if k.startswith(self._kw_prefix):
                if k not in self._tp_by_kw:
                    self._tp_by_kw[k] = 0
                    self._tn_by_kw[k] = 0
                    self._fp_by_kw[k] = 0
                    self._fn_by_kw[k] = 0
                # detach() before numpy(), so outputs that still require grad
                # do not raise; int(...) keeps the counters plain Python ints,
                # as annotated.
                pred = (model_output[k].detach().cpu().numpy() >= 0.5).astype(int)
                self._tp_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 1)))
                self._fp_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 1)))
                self._tn_by_kw[k] += int(np.sum(np.logical_and(gt == 0, pred == 0)))
                self._fn_by_kw[k] += int(np.sum(np.logical_and(gt == 1, pred == 0)))

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return [f'Loyalty_{k}_{metric}'
                for k in self._tp_by_kw.keys()
                for metric in ['Acc', 'Sens', 'Spec', 'AvgSS']]

    def _get_values_of_evaluation_metrics(self, kw: str) -> List[str]:
        tp = self._tp_by_kw[kw]
        tn = self._tn_by_kw[kw]
        fp = self._fp_by_kw[kw]
        fn = self._fn_by_kw[kw]
        p = tp + fn
        n = tn + fp
        # Each metric falls back to -1 when its denominator is empty.
        if p + n > 0:
            accuracy = 100.0 * (tp + tn) / (p + n)
        else:
            accuracy = -1
        if p > 0:
            sensitivity = 100.0 * tp / p
        else:
            sensitivity = -1
        if n > 0:
            specificity = 100.0 * tn / n
        else:
            specificity = -1
        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity
        return ['%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss]
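
    # Worked example with illustrative counts (not repository data):
    # tp=2, fp=1, tn=1, fn=1 gives p=3, n=2, so
    # Acc = 100*(2+1)/5 = 60.00, Sens = 100*2/3 = 66.67,
    # Spec = 100*1/2 = 50.00, AvgSS = (66.67 + 50.00)/2 = 58.33.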

    def get_values_of_evaluation_metrics(self) -> List[str]:
        return [
            v
            for k in self._tp_by_kw.keys()
            for v in self._get_values_of_evaluation_metrics(k)]

    @classmethod
    def standard_creator(cls, prefix_kw: str, pred_kw: str = 'positive_class_probability') -> Callable[[Model, DataLoader, 'BaseConfig'], 'BinFaithfulnessEvaluator']:
        # Binds prefix_kw and pred_kw to the kw_prefix and gt_kw constructor
        # arguments; using cls (not a hardcoded class name) keeps the factory
        # valid for subclasses.
        return partial(cls, prefix_kw, pred_kw)

    def print_evaluation_metrics(self, title: str) -> None:
        """ For more readable printing! """
        print(f'{title}:')
        for k in self._tp_by_kw.keys():
            print(
                f'\t{k}:: ' +
                ', '.join([f'{m_name}: {m_val}'
                           for m_name, m_val in zip(
                               ['Acc', 'Sens', 'Spec', 'AvgSS'],
                               self._get_values_of_evaluation_metrics(k))]))
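
For orientation, the snippet below is a minimal standalone sketch of how one batch flows through the evaluator's update step. The key names, the 'interpretation_' prefix, and the wiring shown in the trailing comments are illustrative assumptions, not values taken from the repository; only numpy and torch are needed to run it.

import numpy as np
import torch

# A fake batch: one ground-truth probability tensor plus two prefixed
# outputs, mirroring the dict update_summaries_based_on_model_output expects.
model_output = {
    'positive_class_probability': torch.tensor([0.9, 0.2, 0.7, 0.4]),
    'interpretation_a': torch.tensor([0.8, 0.1, 0.3, 0.6]),
    'interpretation_b': torch.tensor([0.95, 0.05, 0.65, 0.35]),
}
gt = (model_output['positive_class_probability'].numpy() >= 0.5).astype(int)
for k, v in model_output.items():
    if k.startswith('interpretation_'):
        pred = (v.numpy() >= 0.5).astype(int)
        agree = int(np.sum(pred == gt))
        print(f'{k}: matches the thresholded prediction on {agree}/{len(gt)} samples')

# Hypothetical wiring inside the package (model, data_loader, conf stand in
# for real Model / DataLoader / BaseConfig instances):
# create = BinFaithfulnessEvaluator.standard_creator('interpretation_')
# evaluator = create(model, data_loader, conf)
# evaluator.update_summaries_based_on_model_output(model_output)
# evaluator.print_evaluation_metrics('Validation faithfulness')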