from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Tuple

import torch
from torch import nn

from ..configs.base_config import BaseConfig, PhaseType


class AuxLoss(ABC):

    def __init__(self, title: str, model: nn.Module,
                 layer_by_name: Dict[str, nn.Module], loss_weight: float):
        """Calculates a loss term based on the outputs of intermediate layers of the network.

        Args:
            title (str): A unique title. The required inputs are kept under this title in the
                output dictionary, and the loss term is added as ``{title}_loss``.
            model (nn.Module): The base model.
            layer_by_name (Dict[str, nn.Module]): A mapping from layer name to layer.
                The names used for one instance of this class must be unique.
            loss_weight (float): The weight of this term in the final loss. The term is
                automatically added to the model's ``loss`` entry in the output dictionary,
                scaled by this factor.
        """
        self._title = title
        self._model = model
        self._loss_weight = loss_weight
        self._layers_names = list(layer_by_name.keys())
        self._layers_outputs_by_name: Dict[str, torch.Tensor] = dict()

        self.hook_pairs = [
            (module, partial(self._save_layer_val_hook, name))
            for name, module in layer_by_name.items()
        ] + [(model, self._add_loss)]

    def _save_layer_val_hook(self, layer_name, _, __, layer_out):
        self._layers_outputs_by_name[layer_name] = layer_out

    def configure(self, config: BaseConfig) -> None:
        """Must be called to register the necessary configs and hooks so the loss is calculated.

        Args:
            config (BaseConfig): The base config.
        """
        for phase in PhaseType:
            config.hooks_by_phase[phase] += self.hook_pairs

    def _add_loss(self, _, __, model_out: Dict[str, torch.Tensor]):
        """A hook for calculating the loss and adding it to the model's output dictionary.

        Args:
            _: The model, unused (kept only for the hook signature).
            __: The model input, also unused (kept only for the hook signature).
            model_out (Dict[str, torch.Tensor]): The output dictionary of the model.
        """
        # Pop the saved layer outputs so stale values cannot leak into the next forward pass.
        layers_values = []
        for layer_name in self._layers_names:
            assert layer_name in self._layers_outputs_by_name, \
                f'The output of {layer_name} was not found. It seems the forward function has not been called.'
            layers_values.append((layer_name, self._layers_outputs_by_name.pop(layer_name)))

        loss = self._calculate_loss(layers_values, model_out)

        # Add the loss term to the output dictionary.
        assert f'{self._title}_loss' not in model_out, \
            f"Trying to add {self._title}_loss to the model's output multiple times!"
        model_out[f'{self._title}_loss'] = loss.clone()

        # Add the weighted term to the main loss.
        model_out['loss'] = model_out.get(
            'loss', torch.zeros([], dtype=loss.dtype, device=loss.device)
        ) + self._loss_weight * loss

        return model_out

    @abstractmethod
    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]],
                        model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Calculates the loss based on the outputs of the layers given in the constructor.

        Args:
            layers_values (List[Tuple[str, torch.Tensor]]): List of (layer name, layer output) pairs.
            model_output (Dict[str, torch.Tensor]): The output dictionary of the model, in case
                anything else is required for the loss calculation; it can also be used to add
                extra entries to the output.

        Returns:
            torch.Tensor: The loss.
        """
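

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal AuxLoss subclass
# that penalizes the mean squared activation of the registered layers. The
# subclass name, the example model, and the layer names below are hypothetical;
# only AuxLoss, BaseConfig, and PhaseType come from this repository.
# ---------------------------------------------------------------------------


class ActivationNormLoss(AuxLoss):
    """Adds ``{title}_loss`` equal to the mean squared activation of the given layers."""

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]],
                        model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Average the squared activations over all registered layers.
        return torch.stack([out.pow(2).mean() for _, out in layers_values]).mean()


# Hypothetical usage, assuming `model` exposes a `backbone.layer4` submodule and
# `config` is a BaseConfig whose hooks are run by the training loop:
#
#   aux = ActivationNormLoss(
#       title='act_norm',
#       model=model,
#       layer_by_name={'layer4': model.backbone.layer4},
#       loss_weight=0.1,
#   )
#   aux.configure(config)
#
# After each forward pass the hooks populate ``model_out['act_norm_loss']`` and
# add ``0.1 * act_norm_loss`` to ``model_out['loss']``.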