123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- from typing import List, TYPE_CHECKING, Dict, OrderedDict, Type
- from collections import OrderedDict as ODict
-
- import torch
-
- from ..data.data_loader import DataLoader
- from ..models.model import Model
- from ..model_evaluation.evaluator import Evaluator
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
-
# Use this holder when a single model should be analyzed from several points of view by different evaluators.
-
-
class MultiEvaluatorEvaluator(Evaluator):
    """Composite evaluator that fans work out to several named sub-evaluators.

    Each registered evaluator class is instantiated once with the same
    (model, data_loader, conf) triple; metric titles are prefixed with the
    evaluator's registered name so the combined metric lists stay distinct.
    Insertion order of ``evaluators_cls_by_name`` is preserved throughout.
    """

    def __init__(
            self, model: Model, data_loader: DataLoader, conf: 'BaseConfig',
            evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
        """
        evaluators_cls_by_name: maps an arbitrary display name for each
        sub-evaluator to the class (constructor) used to instantiate it.
        """
        super().__init__(model, data_loader, conf)

        # Instantiate every registered sub-evaluator, keeping registration order.
        self._evaluators_by_name: OrderedDict[str, Evaluator] = ODict(
            (sub_name, sub_cls(model, data_loader, conf))
            for sub_name, sub_cls in evaluators_cls_by_name.items())

    def reset(self):
        """Reset the accumulated state of every sub-evaluator."""
        for sub_evaluator in self._evaluators_by_name.values():
            sub_evaluator.reset()

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """Return all metric titles, each prefixed with its evaluator's name."""
        return [
            sub_name + '_' + metric_title
            for sub_name, sub_evaluator in self._evaluators_by_name.items()
            for metric_title in sub_evaluator.get_titles_of_evaluation_metrics()]

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return the metric values of all sub-evaluators, in title order."""
        return [
            metric_value
            for sub_evaluator in self._evaluators_by_name.values()
            for metric_value in sub_evaluator.get_values_of_evaluation_metrics()]

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        """Forward the model's output to every sub-evaluator's summary update."""
        for sub_evaluator in self._evaluators_by_name.values():
            sub_evaluator.update_summaries_based_on_model_output(model_output)

    @staticmethod
    def create_standard_multi_evaluator_evaluator_maker(evaluators_cls_by_name: OrderedDict[str, Type[Evaluator]]):
        """Build a factory with the standard ``(model, data_loader, conf)``
        signature expected wherever a plain Evaluator constructor is used."""
        def make_evaluator(model: Model, data_loader: DataLoader, conf: 'BaseConfig') -> MultiEvaluatorEvaluator:
            return MultiEvaluatorEvaluator(model, data_loader, conf, evaluators_cls_by_name)
        return make_evaluator

    def print_evaluation_metrics(self, title: str) -> None:
        """ For more readable printing! """

        print(f'{title}:')
        for sub_name, sub_evaluator in self._evaluators_by_name.items():
            sub_evaluator.print_evaluation_metrics('\t' + sub_name)
|