- """ The evaluator """
- from abc import abstractmethod, ABC
- from typing import List, TYPE_CHECKING, Dict
-
- import torch
- from tqdm import tqdm
-
- from ..data.dataflow import DataFlow
- if TYPE_CHECKING:
- from ..models.model import Model
- from ..data.data_loader import DataLoader
- from ..configs.base_config import BaseConfig
-
-
- class Evaluator(ABC):
- """ The evaluator """
-
- def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'):
- """ Model is a predictor, data_loader is a subclass of type data_loading.NormalDataLoader which contains information about the samples and how to
- iterate over them. """
- self.model = model
- self.data_loader = data_loader
- self.conf = conf
- self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device)

    def evaluate(self, max_iters: Optional[int] = None) -> None:
        """ CAUTION: resets the data_loader.
        Iterates over the samples (as many of them, and in whatever way, the
        data_loader specifies), updates the running summaries, and displays
        the current evaluation metrics in the progress bar. `max_iters` caps
        the number of iterations; None means no limit. """

        # no cap on iterations unless one was given
        max_iters = float('inf') if max_iters is None else max_iters

        with torch.no_grad():

            # running the model in evaluation mode
            self.model.eval()

            # initiating the variables holding the running evaluation summaries
            self.reset()

            with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar:
                for iters, model_output in pbar:
                    self.update_summaries_based_on_model_output(model_output)
                    del model_output

                    pbar.set_description(self.get_evaluation_metrics())

                    if iters + 1 >= max_iters:
                        break

    @abstractmethod
    def update_summaries_based_on_model_output(
            self, model_output: Dict[str, torch.Tensor]) -> None:
        """ Updates the internal variables responsible for keeping a summary
        over the data, based on the new outputs of the model, so that the
        evaluation metrics can be calculated from these summaries when needed.
        Mostly used in the train and validation phases, or in an eval phase in
        which only the evaluation metrics are needed. """

    @abstractmethod
    def reset(self) -> None:
        """ Resets the held information for a new evaluation round. """

    @abstractmethod
    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """ Returns a list of strings containing the titles of the evaluation metrics. """

    @abstractmethod
    def get_values_of_evaluation_metrics(self) -> List[str]:
        """ Returns a list of the values of the calculated evaluation metrics,
        converted to strings in the desired format. """

    def get_evaluation_metrics(self) -> str:
        """ Returns all evaluation metrics as one 'title: value' string. """
        return ', '.join(
            "%s: %s" % (eval_title, eval_value)
            for eval_title, eval_value in zip(
                self.get_titles_of_evaluation_metrics(),
                self.get_values_of_evaluation_metrics()))

    def print_evaluation_metrics(self, title: str) -> None:
        """ Prints the values of the evaluation metrics under the given title. """
        print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True)