You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

aux_loss.py 3.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from abc import ABC, abstractmethod
  2. from functools import partial
  3. from typing import Dict, List, Tuple
  4. import torch
  5. from torch import nn
  6. from ..configs.base_config import BaseConfig, PhaseType
  7. class AuxLoss(ABC):
  8. def __init__(self, title: str, model: nn.Module,
  9. layer_by_name: Dict[str, nn.Module],
  10. loss_weight: float):
  11. """This class is used for calculating loss based on the middle layers of the network.
  12. Args:
  13. title (str): A unique title, the required inputs would be kept by this title in the output dictionary. Also the loss term would be added as title_loss
  14. model (nn.Module): The base model
  15. named_layers_list (List[Tuple[str, torch.nn.Module]]): A list of pairs, layer name and the layer. The names used for one instance of this class should be unique.
  16. loss_weight (float): The weight of the loss in the final loss. The loss term would be automatically added to the model's loss in model_output dictionary by this factor.
  17. """
  18. self._title = title
  19. self._model = model
  20. self._loss_weight = loss_weight
  21. self._layers_names = layer_by_name.keys()
  22. self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict()
  23. self.hook_pairs = [
  24. (module, partial(self._save_layer_val_hook, name))
  25. for name, module in layer_by_name.items()
  26. ] + [(model, self._add_loss)]
  27. def _save_layer_val_hook(self, layer_name, _, __, layer_out):
  28. self._layers_outputs_by_name[layer_name] = layer_out
  29. def configure(self, config: BaseConfig) -> None:
  30. """Must be called to add the necessary configs and hooks so the loss will be calculated!
  31. Args:
  32. config (BaseConfig): The base config!
  33. """
  34. for phase in PhaseType:
  35. config.hooks_by_phase[phase] += self.hook_pairs
  36. def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]):
  37. """ A hook for calculating and adding loss the the model's output dictionary
  38. Args:
  39. _ ([type]): Model, useless! (just the hook format)
  40. __ ([type]): Model input, also useless! (just the hook format)
  41. model_out (Dict[str, torch.Tensor]): The output dictionary of the model
  42. """
  43. # popping values of tensors
  44. layers_values = []
  45. for layer_name in self._layers_names:
  46. assert layer_name in self._layers_outputs_by_name, f'The output of {layer_name} was not found. It seems the forward function has not been called.'
  47. layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name)))
  48. loss = self._calculate_loss(layers_values, model_out)
  49. # adding loss to output dictionary
  50. assert f'{self._title}_loss' not in model_out, f'Trying to add {self._title}_loss to model\'s output multiple times!'
  51. model_out[f'{self._title}_loss'] = loss.clone()
  52. # adding loss to the main loss
  53. model_out['loss'] = model_out.get('loss', torch.zeros([], dtype=loss.dtype, device=loss.device)) + self._loss_weight * loss
  54. return model_out
  55. @abstractmethod
  56. def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
  57. """Calculates loss based on the output of the layers given in the constructor.
  58. Args:
  59. layers_values (List[Tuple[str, torch.Tensor]]): List of layers names and their output value
  60. model_output (Dict[str, torch.Tensor]): The output dictionary of the model, if something else is required for the loss calculation! or it can be used to add stuff into the output!
  61. Returns:
  62. torch.Tensor: The loss
  63. """