""" The evaluator """ from abc import abstractmethod, ABC from typing import List, TYPE_CHECKING, Dict import torch from tqdm import tqdm from ..data.dataflow import DataFlow if TYPE_CHECKING: from ..models.model import Model from ..data.data_loader import DataLoader from ..configs.base_config import BaseConfig class Evaluator(ABC): """ The evaluator """ def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'): """ Model is a predictor, data_loader is a subclass of type data_loading.NormalDataLoader which contains information about the samples and how to iterate over them. """ self.model = model self.data_loader = data_loader self.conf = conf self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device) def evaluate(self, max_iters: int = None): """ CAUTION: Resets the data_loader, Iterates over the samples (as much as and how data_loader specifies), calculates the overall evaluation requirements and prints them. Title specified the title of the string for printing evaluation metrics. classes_to_use specifies the labels of the samples to do the function on them, None means all""" # setting in dataflow max_iters = float('inf') if max_iters is None else max_iters with torch.no_grad(): # Running the model in evaluation mode self.model.eval() # initiating variables for running evaluations self.reset() with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar: for iters, model_output in pbar: self.update_summaries_based_on_model_output(model_output) del model_output pbar.set_description(self.get_evaluation_metrics()) if iters + 1 >= max_iters: break @abstractmethod def update_summaries_based_on_model_output( self, model_output: Dict[str, torch.Tensor]) -> None: """ Updates the inner variables responsible of keeping a summary over data based on the new outputs of the model, so when needed evaluation metrics can be calculated based on these summaries. Mostly used in train and validation phase or eval phase in which we only need evaluation metrics.""" @abstractmethod def reset(self): """ Resets the held information for a new evaluation round!""" @abstractmethod def get_titles_of_evaluation_metrics(self) -> List[str]: """ Returns a list of strings containing the titles of the evaluation metrics""" @abstractmethod def get_values_of_evaluation_metrics(self) -> List[str]: """ Returns a list of values for the calculated evaluation metrics, converted to string with the desired format!""" def get_evaluation_metrics(self) -> None: return ', '.join(["%s: %s" % (eval_title, eval_value) for (eval_title, eval_value) in zip( self.get_titles_of_evaluation_metrics(), self.get_values_of_evaluation_metrics())]) def print_evaluation_metrics(self, title: str) -> None: """ Prints the values of the evaluation metrics""" print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True)