You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

output_modifier.py 906B

123456789101112131415161718192021222324252627
  1. from typing import Callable, Tuple, Dict
  2. from torch import nn
  3. import torch
  4. from ..models.model import Model
  5. from ..configs.base_config import BaseConfig, PhaseType
  6. class OutputModifier:
  7. def __init__(self, model: Model, modifier: Callable[[torch.Tensor], torch.Tensor], *keys: str) -> None:
  8. self._model = model
  9. self._modifier = modifier
  10. self._keys = keys
  11. def configure(self, config: BaseConfig) -> None:
  12. for phase in [PhaseType.TRAIN, PhaseType.EVAL]:
  13. config.hooks_by_phase[phase].append((self._model, self._generate_output_modifier()))
  14. def _generate_output_modifier(self):
  15. def output_modifier(_: nn.Module, __: Tuple[torch.Tensor, ...], out: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  16. out.update({
  17. key: self._modifier(out[key]) for key in self._keys
  18. })
  19. return output_modifier