from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict
from collections import OrderedDict

import torch

from ..data.data_loader import DataLoader
from ..models.model import Model
from ..model_evaluation.evaluator import Evaluator

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig


class LossEvaluator(Evaluator):
    """Evaluator that maintains running (sample-weighted) averages of the
    model's loss terms across batches.

    The key named exactly ``'loss'`` feeds the main average; every other
    output key containing the substring ``'loss'`` is tracked separately.
    Because an auxiliary loss may be absent from some batches, each one
    keeps its own sample count.
    """

    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)
        self.phase = conf.phase_type

        # Summarized information: running averages and the number of
        # samples each average is based on.
        self.avg_loss: float = 0.0
        self._avg_other_losses: OrdDict[str, float] = OrderedDict()
        self.n_received_samples: int = 0
        self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict()

    def reset(self) -> None:
        """Clear all accumulated statistics before a new evaluation pass."""
        self.avg_loss = 0.0
        self._avg_other_losses = OrderedDict()
        self.n_received_samples = 0
        self._n_received_samples_other_losses = OrderedDict()

    def update_summaries_based_on_model_output(
            self, model_output: Dict[str, torch.Tensor]) -> None:
        """Fold one batch's loss values into the running averages.

        Args:
            model_output: the model's output dictionary; the ``'loss'``
                entry (if any) updates the main average, and every other
                key containing ``'loss'`` updates its own average.
        """
        n_batch = self.data_loader.get_current_batch_size()

        # --- main loss: weighted running average over all samples so far ---
        new_total = self.n_received_samples + n_batch
        main_loss = model_output.get('loss', 0.0)
        if isinstance(main_loss, torch.Tensor):
            # Detach and move off-device so no autograd graph (or GPU
            # memory) is kept alive between batches — mirrors the
            # treatment of the auxiliary losses below.
            main_loss = float(main_loss.detach().cpu())
        self.avg_loss = \
            self.avg_loss * (float(self.n_received_samples) / new_total) + \
            main_loss * (float(n_batch) / new_total)
        # BUG FIX: the total count was never advanced, so the weighting
        # above always degenerated to the latest batch's loss.
        self.n_received_samples = new_total

        # --- auxiliary losses: each keeps its own sample count ---
        for kw, value in model_output.items():
            if 'loss' in kw and kw != 'loss':
                old_avg = self._avg_other_losses.get(kw, 0.0)
                old_n = self._n_received_samples_other_losses.get(kw, 0)
                kw_total = old_n + n_batch
                # float(...) keeps the stored average a plain Python float,
                # as the attribute annotation declares.
                self._avg_other_losses[kw] = \
                    old_avg * (float(old_n) / kw_total) + \
                    float(value.detach().cpu()) * (float(n_batch) / kw_total)
                self._n_received_samples_other_losses[kw] = kw_total

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """Return the report column titles, main 'Loss' first."""
        return ['Loss'] + list(self._avg_other_losses.keys())

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return the formatted averages, aligned with the titles."""
        return ['%.4f' % self.avg_loss] + \
            ['%.4f' % loss for loss in self._avg_other_losses.values()]