from typing import Dict, Union, List, Tuple

import numpy as np
import torch
from torch.nn import functional as F

from .aux_loss import AuxLoss
from ..data.dataloader_context import DataloaderContext
from ..data.data_loader import DataLoader


class PoolConcordanceLossCalculator(AuxLoss):
    """Auxiliary loss penalising disagreement between pooled maps of a model.

    The pool outputs are ordered from the largest to the smallest last
    dimension. Every adjacent pair is compared: the larger map is max-pooled
    down to the spatial size of the smaller one and a thresholded divergence
    (see ``jensen_shannon_divergence``) is computed per channel, using only
    the samples whose label is related to that channel.
    """

    def __init__(self, title: str, model: torch.nn.Module,
                 layer_by_name: Dict[str, torch.nn.Module],
                 loss_weight: float,
                 weights: Union[float, List[float]],
                 diff_thresholds: Union[float, List[float]],
                 labels_by_channel: Dict[int, List[int]]):
        """Store the configuration of the loss.

        Args:
            title: Name of this loss; used as prefix of the per-pair entries
                written into the model output.
            model: The model the pool layers belong to.
            layer_by_name: The pool layers to compare, keyed by name. At
                least two are required.
            loss_weight: Forwarded unchanged to ``AuxLoss.__init__``.
            weights: Weight of each adjacent pool pair. A single float is
                repeated for all ``len(layer_by_name) - 1`` pairs.
            diff_thresholds: Difference threshold of each adjacent pool pair.
                A single float is repeated for all pairs.
            labels_by_channel: For each channel index, the sample labels for
                which that channel takes part in the loss.

        Raises:
            AssertionError: If fewer than two layers are given, or a list
                argument does not have ``len(layer_by_name) - 1`` items.
        """
        super().__init__(title, model, layer_by_name, loss_weight)

        # at least two pools are needed
        assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!'

        self._title = title
        self._model = model
        self._loss_prefix = f'{title}_JS_loss_'
        self._labels_by_channel = labels_by_channel
        self._pool_names = list(layer_by_name.keys())

        # One weight / threshold per adjacent pair of pools.
        n_pairs = len(layer_by_name) - 1

        if isinstance(weights, list):
            assert len(weights) == n_pairs, 'Weights must have a length of pool_layers -1'
            self._weights = weights
        else:
            self._weights = [weights] * n_pairs

        if isinstance(diff_thresholds, list):
            assert len(diff_thresholds) == n_pairs, 'Diff thresholds must have a length of pool_layers -1'
            self._diff_thresholds = diff_thresholds
        else:
            self._diff_thresholds = [diff_thresholds] * n_pairs

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]],
                        model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the weighted sum of the losses of all adjacent pool pairs.

        Args:
            layers_values: ``(name, value)`` entries of the pool layers. Only
                the values are read here; names are taken from the keys of
                ``layer_by_name`` by position.
                NOTE(review): this assumes both share the same order - confirm.
            model_output: Output dictionary of the model. One entry per pool
                pair (``<title>_JS_loss_<pool_i>-<pool_j>``) is added to it.

        Returns:
            Scalar tensor holding the weighted sum of the pair losses.

        Raises:
            AssertionError: If a pair entry already exists in ``model_output``.
        """
        # reading the pool values (the names come from self._pool_names)
        pool_vals = [entry[1] for entry in layers_values]

        # getting samples labels of the current batch from the active loader
        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device)

        # sorting by the size of the last dimension from biggest to smallest
        last_dims = np.asarray([p.shape[-1] for p in pool_vals])
        order = np.argsort(-1 * last_dims)
        pool_vals = [pool_vals[i] for i in order]
        pool_names = [self._pool_names[i] for i in order]

        loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)

        # for each pair of adjacent pools, calculate loss!
        for i in range(len(pool_names) - 1):
            p_loss = self._cal_pool_pair_loss(
                pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels)

            key = f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'
            assert key not in model_output, 'Trying to add ' + key + ' to model output multiple times'

            # a clone is stored so the reported value is a separate tensor
            model_output[key] = p_loss.clone()
            loss = loss + self._weights[i] * p_loss

        return loss

    def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor,
                            diff_threshold: float,
                            labels: torch.Tensor) -> torch.Tensor:
        """Compute the concordance loss between two pool outputs.

        Args:
            p1: The pool output with the larger (or equal) last dimension.
                Indexed as ``[sample, channel, ...]``.
            p2: The pool output with the smaller last dimension.
            diff_threshold: Threshold passed to ``jensen_shannon_divergence``.
            labels: Labels of the samples in the batch.

        Returns:
            Sum over channels of the divergence between ``p1`` and ``p2``,
            each restricted to the samples related to that channel.
        """
        # down-sampling by max-pool till reaching the same shape as p2!
        if p1.shape[-1] > p2.shape[-1]:
            p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])

        # divergence loss, for each channel -> in the related class!
        loss = torch.tensor(0.0, requires_grad=True, device=p1.device)

        for channel, r_labels in self._labels_by_channel.items():
            on_mask = self._get_inclusion_mask(labels, r_labels)
            ip1 = p1[on_mask, channel, ...]
            ip2 = p2[on_mask, channel, ...]

            # no related sample in this batch: nothing to add for the channel
            if torch.numel(ip1) == 0:
                continue

            if ip1.shape != ip2.shape:
                # the divergence below asserts equal shapes, so this message
                # only adds context (which loss) right before that failure
                print(f'Problem in shape of concordance loss in {self._title}')

            # re-insert a channel axis of size 1 so the pair is treated as a
            # binary distribution by the divergence
            loss = loss + jensen_shannon_divergence(
                ip1[:, None, ...], ip2[:, None, ...], diff_threshold)

        return loss

    def _get_inclusion_mask(
            self, samples_labels: torch.Tensor,
            desired_labels: List[int]) -> torch.Tensor:
        """Return a boolean mask of the samples whose label is desired.

        Args:
            samples_labels: Labels of the samples.
            desired_labels: The labels to select. May be empty.

        Returns:
            Boolean tensor shaped like ``samples_labels``; ``True`` where the
            label equals any of ``desired_labels``. All ``False`` when
            ``desired_labels`` is empty.
        """
        with torch.no_grad():
            # torch.stack cannot handle an empty list, so an empty label set
            # is answered directly: no sample is included.
            if len(desired_labels) == 0:
                return torch.zeros_like(samples_labels, dtype=torch.bool)

            matches = torch.stack(
                [samples_labels == label for label in desired_labels], dim=0)
            return matches.any(dim=0)


def _safe_log(p: torch.Tensor) -> torch.Tensor:
    """Return ``log(max(p + 1e-4, 1e-6))`` element-wise.

    The offset and the lower bound make sure an error in ``p`` does not
    result in log(0) or log(negative).
    """
    return torch.log(torch.maximum(p + 1e-4, 1e-6 + torch.zeros_like(p)))


def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor,
                              diff_threshold: float) -> torch.Tensor:
    """Calculate a thresholded symmetric divergence between p1 and p2.

    Despite the name, the value computed is the symmetrised KL divergence
    ``0.5 * (KL(p1 || p2) + KL(p2 || p1))`` (with slightly offset logs), not
    the mixture-based Jensen-Shannon divergence. Each KL term is averaged
    over the positions where ``p1`` and ``p2`` differ by at least
    ``diff_threshold`` in some class.

    Args:
        p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1] that is
            the probabilities for each neuron in the 1st distribution.
        p2 (torch.Tensor): A tensor of the same shape as p1, in range [0, 1]
            that is the probabilities for each neuron in the 2nd distribution.
        diff_threshold (float): Threshold on the absolute difference between
            p1 and p2 to decide whether to consider one pixel in the loss.

    Returns:
        torch.Tensor: The calculated loss (a scalar). Zero when no position
        passes the threshold.

    Raises:
        AssertionError: If the shapes differ or ``diff_threshold`` is not in
            range [0, 1).
    """
    assert p1.shape == p2.shape, 'The tensors must have the same shape'
    assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)'

    # Reshaping tensors to (N, C): one row per position, one column per class
    p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1)
    p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1)

    # if binary, append class 0!
    if p1.shape[1] == 1:
        p1 = torch.cat([1 - p1, p1], dim=1)
        p2 = torch.cat([1 - p2, p2], dim=1)

    # a row takes part in the loss if any of its classes differs enough;
    # the mask is a selection only, so no gradient is tracked for it
    with torch.no_grad():
        mask = (torch.abs(p1 - p2) >= diff_threshold).any(dim=-1)

    lp1 = _safe_log(p1)
    lp2 = _safe_log(p2)

    kl_12 = torch.sum(p1 * (lp1 - lp2), dim=-1)
    kl_21 = torch.sum(p2 * (lp2 - lp1), dim=-1)

    return 0.5 * (_smean(mask, kl_12) + _smean(mask, kl_21))


def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    """Return the mean of ``val`` over the ``True`` entries of ``mask``.

    Args:
        mask: Selection mask for ``val``.
        val: The values to average.

    Returns:
        The mean of the selected values, or a fresh zero scalar on the device
        of ``mask`` when nothing is selected (avoids the NaN that the mean of
        an empty tensor would give).
    """
    if torch.any(mask.bool()):
        return torch.mean(val[mask])

    return torch.tensor(0.0, requires_grad=True, device=mask.device)