from typing import Callable, Tuple, Dict

from torch import nn
import torch

from ..models.model import Model
from ..configs.base_config import BaseConfig, PhaseType


# Signature of the hook produced by ``OutputModifier``:
# (module, inputs, output_dict) -> output_dict
_OutputHook = Callable[
    [nn.Module, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]],
    Dict[str, torch.Tensor],
]


class OutputModifier:
    """Transform selected entries of a model's output dictionary via a hook.

    When :meth:`configure` is called, one hook per phase is appended to
    ``config.hooks_by_phase`` for ``PhaseType.TRAIN`` and ``PhaseType.EVAL``.
    Each entry is a ``(model, hook)`` pair. When invoked, the hook replaces
    ``out[key]`` with ``modifier(out[key])`` for every key passed at
    construction time, mutating the output dict in place.

    NOTE(review): the hook's ``(module, inputs, output)`` signature matches a
    PyTorch forward hook; presumably whatever consumes ``hooks_by_phase``
    registers it with ``register_forward_hook`` on ``model``. Confirm
    against that consumer.
    """

    def __init__(self, model: Model,
                 modifier: Callable[[torch.Tensor], torch.Tensor],
                 *keys: str) -> None:
        """Store the target model, the transform, and the keys to transform.

        Args:
            model: Model paired with the hook in ``config.hooks_by_phase``.
            modifier: Function applied to each selected output tensor.
            *keys: Keys of the output dict whose values are replaced. A key
                missing from the output raises ``KeyError`` when the hook runs.
        """
        self._model = model
        self._modifier = modifier
        self._keys = keys

    def configure(self, config: BaseConfig) -> None:
        """Append an output-modifying hook for the TRAIN and EVAL phases.

        A fresh hook closure is created for each phase.

        Args:
            config: Configuration whose ``hooks_by_phase`` mapping receives a
                ``(model, hook)`` entry per phase.
        """
        for phase in [PhaseType.TRAIN, PhaseType.EVAL]:
            config.hooks_by_phase[phase].append(
                (self._model, self._generate_output_modifier()))

    def _generate_output_modifier(self) -> _OutputHook:
        """Build the hook closure that applies ``self._modifier`` to outputs."""

        def output_modifier(_: nn.Module, __: Tuple[torch.Tensor, ...],
                            out: Dict[str, torch.Tensor]
                            ) -> Dict[str, torch.Tensor]:
            # All replacements are computed first, then written in one update.
            out.update({key: self._modifier(out[key]) for key in self._keys})
            # Return the (mutated) dict so the function honours its annotated
            # return type. For a PyTorch forward hook, returning the same
            # object is equivalent to returning None.
            return out

        return output_modifier