|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict
- from collections import OrderedDict
-
- import torch
-
- from ..data.data_loader import DataLoader
- from ..models.model import Model
- from ..model_evaluation.evaluator import Evaluator
- if TYPE_CHECKING:
- from ..configs.base_config import BaseConfig
-
-
class LossEvaluator(Evaluator):
    """Evaluator that reports sample-weighted running averages of the losses in a model's output.

    The value under the ``'loss'`` key is tracked as the main loss (``avg_loss``). Every other
    key whose name contains the substring ``'loss'`` (e.g. ``'loss_ce'``, ``'aux_loss'``) is
    tracked as an auxiliary loss with its own running average and sample count. Batches are
    weighted by ``self.data_loader.get_current_batch_size()``.
    """

    def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
        super().__init__(model, data_loader, conf)

        self.phase = conf.phase_type

        # for summarized information
        # Running average of the main 'loss' entry and the number of samples it covers.
        self.avg_loss: float = 0.0
        self.n_received_samples: int = 0
        # Auxiliary loss name -> running average / number of samples it covers. Counts are
        # kept per key because an auxiliary loss may be missing from some batches' outputs.
        self._avg_other_losses: OrdDict[str, float] = OrderedDict()
        self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict()

    def reset(self):
        """Discard all accumulated loss averages and sample counts."""
        self.avg_loss = 0.0
        self._avg_other_losses = OrderedDict()
        self.n_received_samples = 0
        self._n_received_samples_other_losses = OrderedDict()

    @staticmethod
    def _to_float(value) -> float:
        """Convert a scalar loss (single-element tensor or plain number) to a Python float.

        ``item()`` copies the value out as a Python number, so the stored averages keep no
        reference to the tensor, its device memory, or its autograd graph.
        Raises if a tensor holds more than one element.
        """
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)

    @staticmethod
    def _running_mean(old_avg: float, old_n: int, batch_value: float, n_batch: int) -> float:
        """Merge a batch mean covering ``n_batch`` samples into a mean covering ``old_n`` samples."""
        new_n = old_n + n_batch
        return old_avg * (old_n / new_n) + batch_value * (n_batch / new_n)

    def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
        """Fold the losses of the current batch into the running averages.

        :param model_output: the model's outputs keyed by name. ``'loss'`` is averaged into
            ``avg_loss``; when it is absent, 0.0 is used and the batch's samples are still
            counted. Every other key containing ``'loss'`` is averaged under its own name.
            Loss values are expected to be scalars (single-element tensors or numbers).
        """
        n_batch = self.data_loader.get_current_batch_size()
        if n_batch <= 0:
            # An empty batch contributes nothing, and with no prior samples the
            # weighting below would divide by zero.
            return

        self.avg_loss = self._running_mean(
            self.avg_loss, self.n_received_samples,
            self._to_float(model_output.get('loss', 0.0)), n_batch)
        # Persist the count so later batches are weighted against every sample seen so far.
        self.n_received_samples += n_batch

        for name, value in model_output.items():
            if name == 'loss' or 'loss' not in name:
                continue
            old_n = self._n_received_samples_other_losses.get(name, 0)
            self._avg_other_losses[name] = self._running_mean(
                self._avg_other_losses.get(name, 0.0), old_n,
                self._to_float(value), n_batch)
            self._n_received_samples_other_losses[name] = old_n + n_batch

    def get_titles_of_evaluation_metrics(self) -> List[str]:
        """Return column titles: ``'Loss'`` followed by auxiliary loss names in first-seen order."""
        return ['Loss'] + list(self._avg_other_losses.keys())

    def get_values_of_evaluation_metrics(self) -> List[str]:
        """Return the averages formatted to 4 decimals, aligned with the titles."""
        return ['%.4f' % self.avg_loss] + ['%.4f' % loss for loss in self._avg_other_losses.values()]
|