You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

loss_evaluator.py 2.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from typing import List, TYPE_CHECKING, OrderedDict as OrdDict, Dict
  2. from collections import OrderedDict
  3. import torch
  4. from ..data.data_loader import DataLoader
  5. from ..models.model import Model
  6. from ..model_evaluation.evaluator import Evaluator
  7. if TYPE_CHECKING:
  8. from ..configs.base_config import BaseConfig
  9. class LossEvaluator(Evaluator):
  10. def __init__(self, model: Model, data_loader: DataLoader, conf: 'BaseConfig'):
  11. super(LossEvaluator, self).__init__(model, data_loader, conf)
  12. self.phase = conf.phase_type
  13. # for summarized information
  14. self.avg_loss: float = 0
  15. self._avg_other_losses: OrdDict[str, float] = OrderedDict()
  16. self.n_received_samples: int = 0
  17. self._n_received_samples_other_losses: OrdDict[str, int] = OrderedDict()
  18. def reset(self):
  19. self.avg_loss = 0
  20. self._avg_other_losses = OrderedDict()
  21. self.n_received_samples = 0
  22. self._n_received_samples_other_losses = OrderedDict()
  23. def update_summaries_based_on_model_output(self, model_output: Dict[str, torch.Tensor]) -> None:
  24. n_batch = self.data_loader.get_current_batch_size()
  25. new_n = self.n_received_samples + n_batch
  26. self.avg_loss = self.avg_loss * (float(self.n_received_samples) / new_n) + \
  27. model_output.get('loss', 0.0) * (float(n_batch) / new_n)
  28. for kw in model_output.keys():
  29. if 'loss' in kw and kw != 'loss':
  30. old_avg = self._avg_other_losses.get(kw, 0.0)
  31. old_n = self._n_received_samples_other_losses.get(kw, 0)
  32. new_n = old_n + n_batch
  33. self._avg_other_losses[kw] = \
  34. old_avg * (float(old_n) / new_n) + \
  35. model_output[kw].detach().cpu() * (float(n_batch) / new_n)
  36. self._n_received_samples_other_losses[kw] = new_n
  37. def get_titles_of_evaluation_metrics(self) -> List[str]:
  38. return ['Loss'] + list(self._avg_other_losses.keys())
  39. def get_values_of_evaluation_metrics(self) -> List[str]:
  40. return ['%.4f' % self.avg_loss] + ['%.4f' % loss for loss in self._avg_other_losses.values()]