You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

binary_tag_evaluator.py 1.1KB

1234567891011121314151617181920212223242526272829303132333435
  1. from typing import TYPE_CHECKING, Type
  2. import numpy as np
  3. from ..data.data_loader import DataLoader
  4. from .binary_evaluator import BinaryEvaluator
  5. from ..models.model import Model
  6. if TYPE_CHECKING:
  7. from ..configs.base_config import BaseConfig
  8. from ..data.content_loaders.content_loader import ContentLoader
  class TagBinaryEvaluator(BinaryEvaluator):
      """Binary evaluator whose lookup keys are all derived from one tag name.

      The probability, ground-truth and loss keys are built from ``tag``, and
      the ground truth of the current batch is fetched through the content
      loader that ``data_loader`` returns for ``cl_type``.
      """

      def __init__(
          self,
          model: Model,
          data_loader: DataLoader,
          conf: 'BaseConfig',
          tag: str,
          # Quoted, like 'BaseConfig' above: the annotation is then not
          # evaluated when the function is defined, so the class still imports
          # if ContentLoader is only bound under ``TYPE_CHECKING``.
          cl_type: Type['ContentLoader'],
      ):
          """Set up the evaluator for ``tag``.

          Args:
              model: Forwarded unchanged to ``BinaryEvaluator.__init__``.
              data_loader: Forwarded to ``BinaryEvaluator.__init__`` and also
                  asked for the content loader matching ``cl_type``.
              conf: Forwarded unchanged to ``BinaryEvaluator.__init__``.
              tag: Name from which the probability, ground-truth and loss keys
                  are derived.
              cl_type: Content loader class handed to
                  ``data_loader.get_content_loader_of_interest``.
          """
          super().__init__(model, data_loader, conf)
          self._tag = tag
          self._content_loader = data_loader.get_content_loader_of_interest(cl_type)

      @property
      def _prob_key(self) -> str:
          """Key ``'<tag>_positive_class_probability'``."""
          return f'{self._tag}_positive_class_probability'

      @property
      def _gt_key(self) -> str:
          """Ground-truth key, which is the tag itself."""
          return self._tag

      @property
      def _loss_key(self) -> str:
          """Key ``'<tag>_loss'``."""
          return f'{self._tag}_loss'

      def _get_current_batch_gt(self) -> np.ndarray:
          """Return the ground truth of the current batch for this tag.

          Looks up the fill function stored under ``self._gt_key`` in the
          content loader's placeholder-name-to-fill-function dict and calls it
          with the sample indices of the current batch.
          """
          fill_functions = self._content_loader.get_placeholder_name_to_fill_function_dict()
          fill_gt = fill_functions[self._gt_key]
          # NOTE(review): ``self.data_loader`` is presumably set by
          # ``BinaryEvaluator.__init__`` - confirm against the base class.
          batch_indices = self.data_loader.get_current_batch_sample_indices()
          # The meaning of the second positional argument (``None``) is defined
          # by the fill function's signature, which is not visible here.
          return fill_gt(batch_indices, None)