You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

evaluator.py 3.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. """ The evaluator """
  2. from abc import abstractmethod, ABC
  3. from typing import List, TYPE_CHECKING, Dict
  4. import torch
  5. from tqdm import tqdm
  6. from ..data.dataflow import DataFlow
  7. if TYPE_CHECKING:
  8. from ..models.model import Model
  9. from ..data.data_loader import DataLoader
  10. from ..configs.base_config import BaseConfig
  11. class Evaluator(ABC):
  12. """ The evaluator """
  13. def __init__(self, model: 'Model', data_loader: 'DataLoader', conf: 'BaseConfig'):
  14. """ Model is a predictor, data_loader is a subclass of type data_loading.NormalDataLoader which contains information about the samples and how to
  15. iterate over them. """
  16. self.model = model
  17. self.data_loader = data_loader
  18. self.conf = conf
  19. self.dataflow = DataFlow[Dict[str, torch.Tensor]](model, data_loader, conf.device)
  20. def evaluate(self, max_iters: int = None):
  21. """ CAUTION: Resets the data_loader,
  22. Iterates over the samples (as much as and how data_loader specifies),
  23. calculates the overall evaluation requirements and prints them.
  24. Title specified the title of the string for printing evaluation metrics.
  25. classes_to_use specifies the labels of the samples to do the function on them,
  26. None means all"""
  27. # setting in dataflow
  28. max_iters = float('inf') if max_iters is None else max_iters
  29. with torch.no_grad():
  30. # Running the model in evaluation mode
  31. self.model.eval()
  32. # initiating variables for running evaluations
  33. self.reset()
  34. with self.dataflow, tqdm(enumerate(self.dataflow.iterate())) as pbar:
  35. for iters, model_output in pbar:
  36. self.update_summaries_based_on_model_output(model_output)
  37. del model_output
  38. pbar.set_description(self.get_evaluation_metrics())
  39. if iters + 1 >= max_iters:
  40. break
  41. @abstractmethod
  42. def update_summaries_based_on_model_output(
  43. self, model_output: Dict[str, torch.Tensor]) -> None:
  44. """ Updates the inner variables responsible of keeping a summary over data
  45. based on the new outputs of the model, so when needed evaluation metrics
  46. can be calculated based on these summaries. Mostly used in train and validation phase
  47. or eval phase in which we only need evaluation metrics."""
  48. @abstractmethod
  49. def reset(self):
  50. """ Resets the held information for a new evaluation round!"""
  51. @abstractmethod
  52. def get_titles_of_evaluation_metrics(self) -> List[str]:
  53. """ Returns a list of strings containing the titles of the evaluation metrics"""
  54. @abstractmethod
  55. def get_values_of_evaluation_metrics(self) -> List[str]:
  56. """ Returns a list of values for the calculated evaluation metrics,
  57. converted to string with the desired format!"""
  58. def get_evaluation_metrics(self) -> None:
  59. return ', '.join(["%s: %s" % (eval_title, eval_value) for (eval_title, eval_value) in
  60. zip(
  61. self.get_titles_of_evaluation_metrics(),
  62. self.get_values_of_evaluation_metrics())])
  63. def print_evaluation_metrics(self, title: str) -> None:
  64. """ Prints the values of the evaluation metrics"""
  65. print("%s: %s" % (title, self.get_evaluation_metrics()), flush=True)