from typing import Dict, Union, List, Tuple

import numpy as np
import torch
from torch.nn import functional as F

from .aux_loss import AuxLoss
from ..data.dataloader_context import DataloaderContext
from ..data.data_loader import DataLoader


class PoolConcordanceLossCalculator(AuxLoss):
    """Auxiliary loss that encourages concordance between pooling layers.

    The hooked pool layers are ordered from the largest to the smallest spatial
    size; for each adjacent pair, the larger map is down-sampled to the size of
    the smaller one and a Jensen-Shannon style divergence is accumulated per
    channel, using only the samples whose labels are listed for that channel in
    labels_by_channel.
    """

    def __init__(self,
                 title: str, model: torch.nn.Module,
                 layer_by_name: Dict[str, torch.nn.Module],
                 loss_weight: float, weights: Union[float, List[float]],
                 diff_thresholds: Union[float, List[float]],
                 labels_by_channel: Dict[int, List[int]]):
        super().__init__(title, model, layer_by_name, loss_weight)

        # at least two pool layers are needed to form a pair
        assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!'

        self._title = title
        self._model = model
        self._loss_prefix = f'{title}_JS_loss_'

        self._labels_by_channel = labels_by_channel

        self._pool_names = list(layer_by_name.keys())

        # one weight per adjacent pair of pools; a scalar is broadcast to all pairs
        self._weights = weights if isinstance(weights, list) else \
            [weights for _ in range(len(layer_by_name) - 1)]
        if isinstance(weights, list):
            assert len(weights) == len(layer_by_name) - 1, \
                'weights must have a length of len(pool_layers) - 1'

        # one difference threshold per adjacent pair; a scalar is broadcast to all pairs
        self._diff_thresholds = diff_thresholds if isinstance(diff_thresholds, list) else \
            [diff_thresholds for _ in range(len(layer_by_name) - 1)]
        if isinstance(diff_thresholds, list):
            assert len(diff_thresholds) == len(layer_by_name) - 1, \
                'diff_thresholds must have a length of len(pool_layers) - 1'
    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]],
                        model_output: Dict[str, torch.Tensor]) -> torch.Tensor:

        # reading the pool outputs captured from the hooked layers
        pool_vals = [x[1] for x in layers_values]

        # getting the labels of the samples in the current batch
        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = dl.get_current_batch_samples_labels()
        labels = torch.from_numpy(labels).to(dl._device)

        # sorting the pools by spatial size, from the biggest to the smallest
        p_shapes = np.asarray([p.shape[-1] for p in pool_vals])
        sorted_inds = np.argsort(-1 * p_shapes)
        pool_vals = [pool_vals[i] for i in sorted_inds]
        pool_names = [self._pool_names[i] for i in sorted_inds]

        loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)

        # accumulating the weighted loss of each adjacent pair of pools
        for i in range(len(pool_names) - 1):
            p_loss = self._cal_pool_pair_loss(
                pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels)

            loss_name = f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'
            assert loss_name not in model_output, \
                f'Trying to add {loss_name} to the model output multiple times'

            model_output[loss_name] = p_loss.clone()
            loss = loss + self._weights[i] * p_loss

        return loss

    def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor,
                            diff_threshold: float, labels: torch.Tensor) -> torch.Tensor:

        # down-sampling p1 by max-pooling until it has the same spatial size as p2
        if p1.shape[-1] > p2.shape[-1]:
            p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])

        # the divergence is accumulated per channel, over the samples related to that channel
        loss = torch.tensor(0.0, requires_grad=True, device=p1.device)

        for channel, r_labels in self._labels_by_channel.items():
            on_mask = self._get_inclusion_mask(labels, r_labels)

            ip1 = p1[on_mask, channel, ...]
            ip2 = p2[on_mask, channel, ...]

            if torch.numel(ip1) > 0:

                if ip1.shape != ip2.shape:
                    print(f'Problem in the shape of concordance loss in {self._title}')

                loss = loss + jensen_shannon_divergence(
                    ip1[:, None, ...], ip2[:, None, ...], diff_threshold)

        return loss

    def _get_inclusion_mask(
            self,
            samples_labels: torch.Tensor,
            desired_labels: List[int]) -> torch.Tensor:

        # a sample is included if its label matches any of the desired labels
        with torch.no_grad():
            inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
            aggregation = torch.sum(inclusion_mask.float(), dim=0)
            return torch.greater(aggregation, 0)


def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor:
    """
    Calculates a Jensen-Shannon style loss (a symmetrized KL divergence over the masked neurons) between two distributions p1 and p2.

    Args:
        p1 (torch.Tensor): A tensor of shape (B, C, ...) with values in [0, 1], the probabilities of each neuron in the 1st distribution.
        p2 (torch.Tensor): A tensor of the same shape as p1, with values in [0, 1], the probabilities of each neuron in the 2nd distribution.
        diff_threshold (float): The minimum absolute difference between p1 and p2 required for a neuron to be considered in the loss.
    Returns:
        torch.Tensor: The calculated loss.
    """

    assert p1.shape == p2.shape, 'The tensors must have the same shape'
    assert 0 <= diff_threshold < 1, 'The difference threshold should be in the range [0, 1)'

    # reshaping the tensors to (B * ..., C): one row per spatial position of each sample
    p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1)
    p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1)

    # if binary, append the probability of class 0
    if p1.shape[1] == 1:
        p1 = torch.cat([1 - p1, p1], dim=1)
        p2 = torch.cat([1 - p2, p2], dim=1)

    # only the neurons whose probabilities differ by at least diff_threshold contribute to the loss
    with torch.no_grad():
        mask = torch.abs(p1 - p2).detach() >= diff_threshold
        mask = mask.max(dim=-1)[0]

    # clamping before the log so that numerical error cannot produce the log of a non-positive value
    lp1 = torch.log(torch.maximum(p1 + 1e-4, 1e-6 + torch.zeros_like(p1)))
    lp2 = torch.log(torch.maximum(p2 + 1e-4, 1e-6 + torch.zeros_like(p2)))

    # symmetrized KL divergence, averaged over the masked neurons
    loss = 0.5 * (
        _smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1)) +
        _smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1))
    )

    return loss


def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    # mean of val over the masked elements; a zero (grad-enabled) tensor if the mask is empty
    if torch.any(mask.bool()):
        return torch.mean(val[mask])
    else:
        return torch.tensor(0.0, requires_grad=True, device=mask.device)
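

# A minimal usage sketch for jensen_shannon_divergence: random probability maps
# are compared under different thresholds. The shapes and threshold values below
# are arbitrary choices made only for illustration.
if __name__ == '__main__':
    _p1 = torch.rand(4, 1, 8, 8)  # 4 samples, 1 channel, 8x8 probability maps in [0, 1]
    _p2 = torch.rand(4, 1, 8, 8)

    # identical maps: no neuron passes the threshold and the divergence itself is zero
    print('identical maps ->', jensen_shannon_divergence(_p1, _p1.clone(), 0.1).item())

    # independent random maps give a positive divergence
    print('different maps ->', jensen_shannon_divergence(_p1, _p2, 0.1).item())

    # with a threshold close to 1, few (if any) neurons pass the mask, so the loss shrinks toward zero
    print('high threshold ->', jensen_shannon_divergence(_p1, _p2, 0.99).item())

    # A hypothetical configuration of the calculator itself (the model, its pool layers
    # and the active DataloaderContext are assumed to exist elsewhere, so it is left commented):
    #
    #   calc = PoolConcordanceLossCalculator(
    #       title='att', model=model,
    #       layer_by_name={'pool2': model.pool2, 'pool3': model.pool3},
    #       loss_weight=1.0, weights=0.5, diff_thresholds=0.1,
    #       labels_by_channel={0: [1]})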