You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

aux_output.py 1.6KB

1 year ago
12345678910111213141516171819202122232425262728293031323334353637383940
  1. from typing import Dict, Tuple
  2. import warnings
  3. import torch
  4. from torch import nn
  5. from ..utils.hooker import Hook
  6. from ..models.model import Model
  7. from ..configs.base_config import BaseConfig, PhaseType
  8. class AuxOutput:
  9. def __init__(self, model: Model, layers_by_name: Dict[str, nn.Module]) -> None:
  10. self._model = model
  11. self._layers_by_name = layers_by_name
  12. self._outputs: Dict[str, torch.Tensor] = {}
  13. def _generate_layer_hook(self, name: str) -> Hook:
  14. def layer_hook(_: nn.Module, __: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
  15. assert name not in self._outputs, f'{name} is already in model output'
  16. self._outputs[name] = output
  17. return layer_hook
  18. def _generate_output_modifier(self) -> Hook:
  19. def output_modifier(_: nn.Module, __: Tuple[torch.Tensor, ...], output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  20. if any(name in output for name in self._outputs):
  21. warnings.warn('one of aux outputs is already in model output')
  22. while len(self._outputs) > 0:
  23. name, value = self._outputs.popitem()
  24. assert name not in output, f'{name} is already in model output'
  25. output[name] = value
  26. return output
  27. return output_modifier
  28. def configure(self, config: BaseConfig) -> None:
  29. for phase in [PhaseType.TRAIN, PhaseType.EVAL]:
  30. config.hooks_by_phase[phase] += [(l, self._generate_layer_hook(n)) for n, l in self._layers_by_name.items()]
  31. config.hooks_by_phase[phase].append((self._model, self._generate_output_modifier()))