from typing import Dict, Tuple
import warnings

import torch
from torch import nn

from ..utils.hooker import Hook
from ..models.model import Model
from ..configs.base_config import BaseConfig, PhaseType


class AuxOutput:
    """Expose intermediate layer outputs as extra entries of a model's output dict.

    Each layer in ``layers_by_name`` gets a hook that stashes its output under
    the associated name; a separate hook on ``model`` then moves every stashed
    value into the model's output dict. The hooks are attached to a config via
    :meth:`configure`.
    """

    def __init__(self, model: Model, layers_by_name: Dict[str, nn.Module]) -> None:
        """
        Args:
            model: Module whose output dict receives the auxiliary entries.
            layers_by_name: Maps the key to use in the model output to the
                module whose output should be captured under that key.
        """
        self._model = model
        self._layers_by_name = layers_by_name
        # Layer outputs captured by the layer hooks that have not yet been
        # merged into the model output by the output modifier.
        self._outputs: Dict[str, torch.Tensor] = {}

    def _generate_layer_hook(self, name: str) -> Hook:
        """Return a hook that stashes the hooked layer's output under ``name``.

        The returned hook raises AssertionError if a value for ``name`` is
        still pending, i.e. the layer produced an output again before the
        output modifier consumed the previous one.
        """

        def layer_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
            # NOTE(review): if the model's forward raises after this hook has
            # fired, the stashed value is never consumed and the next call
            # trips this assertion -- confirm whether that case needs handling.
            assert name not in self._outputs, f'{name} is already captured and not yet merged into model output'
            self._outputs[name] = output

        return layer_hook

    def _generate_output_modifier(self) -> Hook:
        """Return a hook that merges all pending layer outputs into the model output.

        The model output dict is mutated in place and also returned. A key that
        already exists in the model output triggers a warning followed by an
        AssertionError; when assertions are disabled (``python -O``) only the
        warning remains and the aux output overwrites the existing entry.
        """

        def output_modifier(_: nn.Module, __: Tuple[torch.Tensor, ...],
                            output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            # Take ownership of the pending outputs and reset the instance
            # state *before* validating, so a failed merge cannot leave stale
            # entries behind that would make every later layer hook fail.
            pending, self._outputs = self._outputs, {}
            if any(name in output for name in pending):
                warnings.warn('one of aux outputs is already in model output')
            while pending:
                name, value = pending.popitem()
                assert name not in output, f'{name} is already in model output'
                output[name] = value
            return output

        return output_modifier

    def configure(self, config: BaseConfig) -> None:
        """Register the capture hooks and the output modifier for TRAIN and EVAL.

        For each phase, one hook per entry of ``layers_by_name`` is appended to
        ``config.hooks_by_phase[phase]``, followed by the output modifier on
        the model. The layer hooks must fire before the model's hook for their
        values to be merged (presumably guaranteed because the layers run
        inside the model's forward -- confirm with the hooker utility).
        """
        for phase in [PhaseType.TRAIN, PhaseType.EVAL]:
            config.hooks_by_phase[phase] += [(l, self._generate_layer_hook(n)) for n, l in self._layers_by_name.items()]
            config.hooks_by_phase[phase].append((self._model, self._generate_output_modifier()))