123456789101112131415161718192021222324252627 |
- from typing import Callable, Tuple, Dict
-
- from torch import nn
- import torch
-
- from ..models.model import Model
- from ..configs.base_config import BaseConfig, PhaseType
-
-
class OutputModifier:
    """Transform selected entries of a model's dict output via a hook.

    ``configure`` attaches a hook to ``model`` for the TRAIN and EVAL phases.
    When it fires, the hook replaces ``out[key]`` with ``modifier(out[key])``
    for every key given at construction time.
    """

    def __init__(self, model: Model, modifier: Callable[[torch.Tensor], torch.Tensor], *keys: str) -> None:
        """Store the hook target, the transform, and the keys to transform.

        Args:
            model: Module the hook is paired with in ``config.hooks_by_phase``.
            modifier: Function applied to each selected output tensor.
            *keys: Names of the output-dict entries to transform.
        """
        self._model = model
        self._modifier = modifier
        self._keys = keys

    def configure(self, config: BaseConfig) -> None:
        """Register the output-modifying hook for the TRAIN and EVAL phases.

        For each phase, a ``(model, hook)`` pair is appended to
        ``config.hooks_by_phase[phase]``; a separate hook closure is built
        per phase.
        """
        for phase in (PhaseType.TRAIN, PhaseType.EVAL):
            config.hooks_by_phase[phase].append((self._model, self._generate_output_modifier()))

    def _generate_output_modifier(
        self,
    ) -> Callable[[nn.Module, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
        """Build the hook closure that rewrites the selected output entries.

        The returned callable takes ``(module, inputs, out)``. Only ``out`` is
        used, and it is presumably the output dict of a forward pass; confirm
        against how ``hooks_by_phase`` entries are registered. It raises
        ``KeyError`` if any configured key is missing from ``out``.
        """
        def output_modifier(
            _: nn.Module,
            __: Tuple[torch.Tensor, ...],
            out: Dict[str, torch.Tensor],
        ) -> Dict[str, torch.Tensor]:
            # Compute every new value before touching ``out`` so that each
            # modifier call sees the original tensor for its key.
            out.update({key: self._modifier(out[key]) for key in self._keys})
            # The dict is mutated in place, but it is also returned to honour the
            # annotation. If this is used as a PyTorch forward hook, returning
            # the same object is equivalent to returning None; if a caller
            # consumes the return value, it previously received None.
            return out

        return output_modifier
|