
multieval_evaluator.py

from typing import List, TYPE_CHECKING, Dict, OrderedDict, Type
from collections import OrderedDict as ODict

import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class MultiEvaluatorEvaluator(Evaluator):
    """Holder evaluator: use this if you want your model to be analyzed from
    different points of view by different evaluators at the same time."""

    def __init__(
            self, model: Model, data_loader: DataLoader, conf: 'BaseConfig',
            evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
        """
        evaluators_cls_by_name: the key is an arbitrary name for the instance,
        the value is the evaluator class (constructor).
        """
        super(MultiEvaluatorEvaluator, self).__init__(model, data_loader, conf)

        # Instantiate all the evaluators
        self._evaluators_by_name: OrderedDict[str, Evaluator] = ODict()
        for eval_name, eval_cls in evaluators_cls_by_name.items():
            self._evaluators_by_name[eval_name] = eval_cls(model, data_loader, conf)

    def reset(self):
        for evaluator in self._evaluators_by_name.values():
            evaluator.reset()

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        # Each title is prefixed with its evaluator's name to keep metrics distinguishable
        titles = []
        for eval_name, evaluator in self._evaluators_by_name.items():
            titles += [eval_name + '_' + t for t in evaluator.get_titles_of_evaluation_metrics()]
        return titles

    def get_values_of_evaluation_metrics(self) -> List[str]:
        # Values are concatenated in the same order as the titles
        metrics = []
        for evaluator in self._evaluators_by_name.values():
            metrics += evaluator.get_values_of_evaluation_metrics()
        return metrics

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        for evaluator in self._evaluators_by_name.values():
            evaluator.update_summaries_based_on_model_output(model_output)

    @staticmethod
    def create_standard_multi_evaluator_evaluator_maker(evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
        """Makes a constructor consistent with the standard Evaluator signature."""
        def typical_maker(model: Model, data_loader: DataLoader, conf: 'BaseConfig') -> MultiEvaluatorEvaluator:
            return MultiEvaluatorEvaluator(model, data_loader, conf, evaluators_cls_by_name)
        return typical_maker

    def print_evaluation_metrics(self, title: str) -> None:
        """Prints each wrapped evaluator's metrics under its own name, for more readable output."""
        print(f'{title}:')
        for e_name, e_obj in self._evaluators_by_name.items():
            e_obj.print_evaluation_metrics('\t' + e_name)
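
A minimal usage sketch, not part of the module above: it assumes two hypothetical Evaluator subclasses (LossEvaluator, AccuracyEvaluator) and existing model, data_loader, and conf objects, and shows how the maker produces a constructor with the standard (model, data_loader, conf) signature.

# Usage sketch -- LossEvaluator and AccuracyEvaluator are assumed, illustrative
# Evaluator subclasses; model, data_loader, and conf come from the surrounding code.
from collections import OrderedDict

evaluator_maker = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(
    OrderedDict([
        ('loss', LossEvaluator),          # hypothetical Evaluator subclass
        ('accuracy', AccuracyEvaluator),  # hypothetical Evaluator subclass
    ]))

# The maker can be passed wherever a single Evaluator constructor is expected.
evaluator = evaluator_maker(model, data_loader, conf)
evaluator.reset()

# Metric titles come back prefixed with the evaluator names, e.g. 'loss_...', 'accuracy_...'
print(evaluator.get_titles_of_evaluation_metrics())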