from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Tuple

import torch
from torch import nn

from ..configs.base_config import BaseConfig, PhaseType


class AuxLoss(ABC):

    def __init__(self, title: str, model: nn.Module,
                 layer_by_name: Dict[str, nn.Module],
                 loss_weight: float):
- """This class is used for calculating loss based on the middle layers of the network.
-
- Args:
- title (str): A unique title, the required inputs would be kept by this title in the output dictionary. Also the loss term would be added as title_loss
- model (nn.Module): The base model
- named_layers_list (List[Tuple[str, torch.nn.Module]]): A list of pairs, layer name and the layer. The names used for one instance of this class should be unique.
- loss_weight (float): The weight of the loss in the final loss. The loss term would be automatically added to the model's loss in model_output dictionary by this factor.
- """

        self._title = title
        self._model = model
        self._loss_weight = loss_weight
        self._layers_names = layer_by_name.keys()

        self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict()

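        # One forward hook per tracked layer to capture its output, plus one hook on the
        # model itself that computes the auxiliary loss and merges it into the model's output.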
        self.hook_pairs = [
            (module, partial(self._save_layer_val_hook, name))
            for name, module in layer_by_name.items()
        ] + [(model, self._add_loss)]

    def _save_layer_val_hook(self, layer_name, _, __, layer_out):
        """Forward hook that stores the output of the layer registered under `layer_name`."""
        self._layers_outputs_by_name[layer_name] = layer_out

    def configure(self, config: BaseConfig) -> None:
        """Must be called to register the hooks required for the loss to be calculated.

        Args:
            config (BaseConfig): The base config; the hooks are added to it for every phase.
        """

        for phase in PhaseType:
            config.hooks_by_phase[phase] += self.hook_pairs
    def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]):
        """A hook for calculating the loss and adding it to the model's output dictionary.

        Args:
            _: The model (unused; present only to match the hook signature).
            __: The model input (unused; present only to match the hook signature).
            model_out (Dict[str, torch.Tensor]): The output dictionary of the model.
        """

        # popping the captured layer outputs
        layers_values = []
        for layer_name in self._layers_names:
            assert layer_name in self._layers_outputs_by_name, f'The output of {layer_name} was not found. It seems the forward function has not been called.'
            layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name)))

        loss = self._calculate_loss(layers_values, model_out)

        # adding the loss to the output dictionary
        assert f'{self._title}_loss' not in model_out, f"Trying to add {self._title}_loss to the model's output multiple times!"
        model_out[f'{self._title}_loss'] = loss.clone()

        # adding the weighted loss to the main loss
        model_out['loss'] = model_out.get('loss', torch.zeros([], dtype=loss.dtype, device=loss.device)) + self._loss_weight * loss

        return model_out

    @abstractmethod
    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]],
                        model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculates the loss based on the outputs of the layers given in the constructor.

        Args:
            layers_values (List[Tuple[str, torch.Tensor]]): A list of (layer name, layer output) pairs.
            model_output (Dict[str, torch.Tensor]): The output dictionary of the model, in case
                anything else is required for the loss calculation; it can also be used to add
                extra entries to the output.

        Returns:
            torch.Tensor: The loss.
        """