
cw_concordance_loss.py

from typing import Dict, Union, List, Tuple

import numpy as np
import torch
from torch.nn import functional as F

from .aux_loss import AuxLoss
from ..data.dataloader_context import DataloaderContext
from ..data.data_loader import DataLoader


class PoolConcordanceLossCalculator(AuxLoss):

    def __init__(self,
                 title: str, model: torch.nn.Module,
                 layer_by_name: Dict[str, torch.nn.Module],
                 loss_weight: float, weights: Union[float, List[float]],
                 diff_thresholds: Union[float, List[float]],
                 labels_by_channel: Dict[int, List[int]]):
        super().__init__(title, model, layer_by_name, loss_weight)

        # at least two pool layers are needed
        assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!'

        self._title = title
        self._model = model
        self._loss_prefix = f'{title}_JS_loss_'

        self._labels_by_channel = labels_by_channel
        self._pool_names = list(layer_by_name.keys())

        # one weight per consecutive pair of pool layers
        self._weights = weights if isinstance(weights, list) else \
            [weights for _ in range(len(layer_by_name) - 1)]
        if isinstance(weights, list):
            assert len(weights) == len(layer_by_name) - 1, \
                'Weights must have a length of pool_layers - 1'

        # one difference threshold per consecutive pair of pool layers
        self._diff_thresholds = diff_thresholds if isinstance(diff_thresholds, list) else \
            [diff_thresholds for _ in range(len(layer_by_name) - 1)]
        if isinstance(diff_thresholds, list):
            assert len(diff_thresholds) == len(layer_by_name) - 1, \
                'Diff thresholds must have a length of pool_layers - 1'

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]],
                        model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        # reading the pooled values captured from the hooked layers
        pool_vals = [x[1] for x in layers_values]

        # getting the current batch samples' labels
        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = dl.get_current_batch_samples_labels()
        labels = torch.from_numpy(labels).to(dl._device)

        # sorting by spatial size, from biggest to smallest
        p_shapes = np.asarray([p.shape[-1] for p in pool_vals])
        sorted_inds = np.argsort(-1 * p_shapes)
        pool_vals = [pool_vals[i] for i in sorted_inds]
        pool_names = [self._pool_names[i] for i in sorted_inds]

        loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)

        # for each pair of consecutive pools, calculate the loss
        for i in range(len(pool_names) - 1):
            p_loss = self._cal_pool_pair_loss(
                pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels)
            loss_name = f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'
            assert loss_name not in model_output, \
                f'Trying to add {loss_name} to model output multiple times'
            model_output[loss_name] = p_loss.clone()
            loss = loss + self._weights[i] * p_loss

        return loss

    def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor,
                            diff_threshold: float, labels: torch.Tensor) -> torch.Tensor:
        # down-sampling p1 by max-pooling until it has the same spatial shape as p2
        if p1.shape[-1] > p2.shape[-1]:
            p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])

        # Jensen-Shannon loss per channel, restricted to the samples with the related labels
        loss = torch.tensor(0.0, requires_grad=True, device=p1.device)
        for channel, r_labels in self._labels_by_channel.items():
            on_mask = self._get_inclusion_mask(labels, r_labels)
            ip1 = p1[on_mask, channel, ...]
            ip2 = p2[on_mask, channel, ...]
            if torch.numel(ip1) > 0:
                if ip1.shape != ip2.shape:
                    print(f'Problem in shape of concordance loss in {self._title}')
                loss = loss + jensen_shannon_divergence(
                    ip1[:, None, ...], ip2[:, None, ...], diff_threshold)

        return loss

    def _get_inclusion_mask(
            self,
            samples_labels: torch.Tensor,
            desired_labels: List[int]) -> torch.Tensor:
        # boolean mask marking the samples whose label is in desired_labels
        with torch.no_grad():
            inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
            aggregation = torch.sum(inclusion_mask.float(), dim=0)
            return torch.greater(aggregation, 0)


def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor:
    """
    Calculates the Jensen-Shannon loss between two distributions p1 and p2.

    Only the positions whose probabilities differ by at least diff_threshold contribute;
    over those positions the returned value is 0.5 * (KL(p1 || p2) + KL(p2 || p1)).

    Args:
        p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1], the probabilities for each neuron in the 1st distribution.
        p2 (torch.Tensor): A tensor of the same shape as p1, in range [0, 1], the probabilities for each neuron in the 2nd distribution.
        diff_threshold (float): Threshold on |p1 - p2| deciding whether a position is included in the loss.

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert p1.shape == p2.shape, 'The tensors must have the same shape'
    assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)'

    # reshaping tensors to (N, C): one row per spatial position per sample
    p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1)
    p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1)

    # if binary, append the probability of class 0
    if p1.shape[1] == 1:
        p1 = torch.cat([1 - p1, p1], dim=1)
        p2 = torch.cat([1 - p2, p2], dim=1)

    # positions whose probabilities differ by at least diff_threshold in some class
    with torch.no_grad():
        mask = torch.abs(p1 - p2).detach() >= diff_threshold
        mask = mask.max(dim=-1)[0]

    # to make sure numerical error does not result in log(negative)
    lp1 = torch.log(torch.maximum(p1 + 1e-4, 1e-6 + torch.zeros_like(p1)))
    lp2 = torch.log(torch.maximum(p2 + 1e-4, 1e-6 + torch.zeros_like(p2)))

    loss = 0.5 * (
        _smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1)) +
        _smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1))
    )

    return loss


def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
    # mean of val over the masked positions; a differentiable zero if the mask is empty
    if torch.any(mask.bool()):
        return torch.mean(val[mask])
    else:
        return torch.tensor(0.0, requires_grad=True, device=mask.device)
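
For reference, jensen_shannon_divergence can also be exercised on its own, outside the calculator. The sketch below is illustrative only and not part of the module; the shapes and the 0.1 threshold are arbitrary choices, matching the B x 1 x H x W per-channel slices the calculator passes in.

# illustrative usage sketch (assumed shapes), not part of cw_concordance_loss.py
import torch
from cw_concordance_loss import jensen_shannon_divergence  # hypothetical import path

# two probability maps derived from leaf logits so the loss can back-propagate
logits1 = torch.randn(4, 1, 8, 8, requires_grad=True)
logits2 = torch.randn(4, 1, 8, 8, requires_grad=True)
p1 = torch.sigmoid(logits1)
p2 = torch.sigmoid(logits2)

# positions with |p1 - p2| < 0.1 are ignored; the rest contribute to the symmetrised KL
loss = jensen_shannon_divergence(p1, p2, diff_threshold=0.1)
loss.backward()
print(float(loss))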