1234567891011121314151617181920212223242526272829303132333435 |
- from typing import TYPE_CHECKING, Type
-
- import numpy as np
-
- from ..data.data_loader import DataLoader
- from .binary_evaluator import BinaryEvaluator
- from ..models.model import Model
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
- from ..data.content_loaders.content_loader import ContentLoader
-
-
class TagBinaryEvaluator(BinaryEvaluator):
    """Binary evaluator bound to a single named tag.

    The tag name is used directly as the ground-truth key and as the prefix of
    the probability and loss keys. Ground truth for the current batch is read
    from the content loader of type ``cl_type`` obtained from the data loader.
    """

    def __init__(
        self,
        model: Model,
        data_loader: DataLoader,
        conf: 'BaseConfig',
        tag: str,
        # Quoted on purpose: ContentLoader is imported only under TYPE_CHECKING,
        # and an unquoted ``Type[ContentLoader]`` is evaluated when this ``def``
        # executes, raising NameError as soon as the module is imported.
        cl_type: Type['ContentLoader'],
    ):
        """Create an evaluator for ``tag``.

        Args:
            model: Passed through to ``BinaryEvaluator.__init__``.
            data_loader: Passed through to ``BinaryEvaluator.__init__`` and
                queried (via ``get_content_loader_of_interest``) for the
                content loader of type ``cl_type``.
            conf: Passed through to ``BinaryEvaluator.__init__``.
            tag: Tag name; used as the ground-truth key and as the prefix of
                the probability and loss keys.
            cl_type: Content-loader class to look up on ``data_loader``.
        """
        super().__init__(model, data_loader, conf)
        self._tag = tag
        self._content_loader = data_loader.get_content_loader_of_interest(cl_type)

    # NOTE(review): the three key properties below presumably override hooks
    # declared on BinaryEvaluator — confirm against the base class.

    @property
    def _prob_key(self) -> str:
        """Key ``'<tag>_positive_class_probability'``."""
        return f'{self._tag}_positive_class_probability'

    @property
    def _gt_key(self) -> str:
        """Ground-truth key; this is the tag name itself."""
        return self._tag

    @property
    def _loss_key(self) -> str:
        """Key ``'<tag>_loss'``."""
        return f'{self._tag}_loss'

    def _get_current_batch_gt(self) -> np.ndarray:
        """Return ground truth for the samples in the current batch.

        Fetches the fill function registered under ``self._gt_key`` in the
        content loader's placeholder-name-to-fill-function mapping and calls
        it with the current batch's sample indices.

        Raises:
            KeyError: if the content loader has no fill function for the tag.
        """
        fill_functions = self._content_loader.get_placeholder_name_to_fill_function_dict()
        fill_gt = fill_functions[self._gt_key]
        sample_indices = self.data_loader.get_current_batch_sample_indices()
        # NOTE(review): the second positional argument is always None here;
        # its meaning is defined by the fill-function signature elsewhere —
        # confirm.
        return fill_gt(sample_indices, None)
|