
binary_evaluator.py 3.0KB

from typing import List, TYPE_CHECKING, Dict

import numpy as np
import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class BinaryEvaluator(Evaluator):
    """Evaluator that accumulates binary-classification summaries
    (confusion counts and a running average loss) over batches."""

    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super(BinaryEvaluator, self).__init__(model, data_loader, conf)

        # summarized information accumulated across batches
        self.tp: int = 0
        self.tn: int = 0
        self.fp: int = 0
        self.fn: int = 0
        self.avg_loss: float = 0.0
        self.n_received_samples: int = 0

    def reset(self):
        self.tp = 0
        self.tn = 0
        self.fp = 0
        self.fn = 0
        self.avg_loss = 0.0
        self.n_received_samples = 0

    @property
    def _prob_key(self) -> str:
        return 'positive_class_probability'

    @property
    def _loss_key(self) -> str:
        return 'loss'

    def _get_current_batch_gt(self) -> np.ndarray:
        return self.data_loader.get_current_batch_samples_labels()

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        assert self._prob_key in model_output, \
            f"model's output must contain {self._prob_key}"

        gt = self._get_current_batch_gt()
        prediction = (model_output[self._prob_key].cpu().numpy() >= 0.5).astype(int)

        # per-batch confusion-matrix counts
        ntp = int(np.sum(np.logical_and(gt == 1, prediction == 1)))
        ntn = int(np.sum(np.logical_and(gt == 0, prediction == 0)))
        nfp = int(np.sum(np.logical_and(gt == 0, prediction == 1)))
        nfn = int(np.sum(np.logical_and(gt == 1, prediction == 0)))

        self.tp += ntp
        self.tn += ntn
        self.fp += nfp
        self.fn += nfn

        # sample-weighted running average of the loss; convert a tensor loss
        # to a plain float so avg_loss stays a Python number
        batch_loss = model_output.get(self._loss_key, 0.0)
        if isinstance(batch_loss, torch.Tensor):
            batch_loss = batch_loss.item()
        new_n = self.n_received_samples + len(gt)
        self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \
            batch_loss * (float(len(gt)) / new_n)
        self.n_received_samples = new_n

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        return ['Loss', 'Acc', 'Sens', 'Spec', 'BAcc', 'N']

    def get_values_of_evaluation_metrics(self) -> List[str]:
        p = self.tp + self.fn
        n = self.tn + self.fp

        if p + n > 0:
            accuracy = (self.tp + self.tn) * 100.0 / (p + n)
        else:
            accuracy = -1

        if p > 0:
            sensitivity = 100.0 * self.tp / p
        else:
            sensitivity = -1

        if n > 0:
            specificity = 100.0 * self.tn / n
        else:
            specificity = -1

        # balanced accuracy; falls back to whichever of the two is defined
        if sensitivity > -1 and specificity > -1:
            avg_ss = 0.5 * (sensitivity + specificity)
        elif sensitivity > -1:
            avg_ss = sensitivity
        else:
            avg_ss = specificity

        return ['%.4f' % self.avg_loss, '%.2f' % accuracy, '%.2f' % sensitivity,
                '%.2f' % specificity, '%.2f' % avg_ss, str(self.n_received_samples)]
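
As a rough, standalone illustration of the accumulation scheme used above (not part of the repository, and using made-up batch data), the following NumPy-only sketch reproduces the batch-wise confusion-count updates, the sample-weighted running loss average, and the final metrics without the Model/DataLoader abstractions:

import numpy as np

# Hypothetical ground-truth labels, predicted positive-class probabilities,
# and batch losses, split into two batches to mimic successive calls to
# update_summaries_based_on_model_output.
batches = [
    (np.array([1, 0, 1, 0]), np.array([0.9, 0.2, 0.4, 0.7]), 0.52),
    (np.array([1, 1, 0, 0]), np.array([0.8, 0.6, 0.1, 0.3]), 0.31),
]

tp = tn = fp = fn = 0
avg_loss = 0.0
n_received = 0

for gt, probs, batch_loss in batches:
    pred = (probs >= 0.5).astype(int)  # same 0.5 threshold as BinaryEvaluator
    tp += int(np.sum((gt == 1) & (pred == 1)))
    tn += int(np.sum((gt == 0) & (pred == 0)))
    fp += int(np.sum((gt == 0) & (pred == 1)))
    fn += int(np.sum((gt == 1) & (pred == 0)))

    # sample-weighted running average of the loss
    new_n = n_received + len(gt)
    avg_loss = avg_loss * (n_received / new_n) + batch_loss * (len(gt) / new_n)
    n_received = new_n

p, n = tp + fn, tn + fp
sensitivity = 100.0 * tp / p
specificity = 100.0 * tn / n
print('Loss %.4f  Acc %.2f  Sens %.2f  Spec %.2f  BAcc %.2f  N %d' % (
    avg_loss, 100.0 * (tp + tn) / (p + n), sensitivity, specificity,
    0.5 * (sensitivity + specificity), n_received))

Because the average loss is re-weighted by the running sample count at every batch, the final value matches the per-sample mean over all batches regardless of batch size.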