from typing import Dict, Tuple
import warnings

import torch
from torch import nn

from ..utils.hooker import Hook
from ..models.model import Model
from ..configs.base_config import BaseConfig, PhaseType


class AuxOutput:
    """Collects the outputs of named intermediate layers via forward hooks
    and merges them into the model's output dictionary."""

    def __init__(self, model: Model, layers_by_name: Dict[str, nn.Module]) -> None:
        self._model = model
        self._layers_by_name = layers_by_name
        # Buffer for the layer outputs captured during the current forward pass.
        self._outputs: Dict[str, torch.Tensor] = {}

    def _generate_layer_hook(self, name: str) -> Hook:
        """Returns a forward hook that stores the hooked layer's output under `name`."""
        def layer_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
            assert name not in self._outputs, f'{name} is already in aux outputs'
            self._outputs[name] = output
        return layer_hook

    def _generate_output_modifier(self) -> Hook:
        """Returns a forward hook on the model itself that moves the buffered
        aux outputs into the model's output dictionary."""
        def output_modifier(
            _: nn.Module,
            __: Tuple[torch.Tensor, ...],
            output: Dict[str, torch.Tensor],
        ) -> Dict[str, torch.Tensor]:
            # Soft warning first: the asserts below are stripped under `python -O`.
            if any(name in output for name in self._outputs):
                warnings.warn('one of the aux outputs is already in the model output')
            # Drain the buffer so it starts empty on the next forward pass.
            while len(self._outputs) > 0:
                name, value = self._outputs.popitem()
                assert name not in output, f'{name} is already in model output'
                output[name] = value
            return output
        return output_modifier

    def configure(self, config: BaseConfig) -> None:
        """Registers the layer hooks and the output modifier for both phases."""
        for phase in [PhaseType.TRAIN, PhaseType.EVAL]:
            config.hooks_by_phase[phase] += [
                (layer, self._generate_layer_hook(name))
                for name, layer in self._layers_by_name.items()
            ]
            # The modifier is attached to the model itself so it runs after all layer hooks.
            config.hooks_by_phase[phase].append((self._model, self._generate_output_modifier()))
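
# Usage sketch (hypothetical; `Model`, `BaseConfig`, and the runner that
# consumes `config.hooks_by_phase` are assumed to follow this repo's API):
#
#     model = Model(...)
#     aux = AuxOutput(model, {'backbone_feats': model.backbone.layer4})
#     config = BaseConfig(...)
#     aux.configure(config)
#
# The runner is expected to attach each (module, hook) pair, e.g. via
# `module.register_forward_hook(hook)`; since `output_modifier` returns the
# modified dict, a forward pass then yields the model's usual outputs plus
# 'backbone_feats'.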