You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 6.2KB

  1. from typing import Dict, Union, List, Tuple
  2. import numpy as np
  3. import torch
  4. from torch.nn import functional as F
  5. from .aux_loss import AuxLoss
  6. from import DataloaderContext
  7. from import DataLoader
  8. class PoolConcordanceLossCalculator(AuxLoss):
  9. def __init__(self,
  10. title: str, model: torch.nn.Module,
  11. layer_by_name: Dict[str, torch.nn.Module],
  12. loss_weight: float, weights: Union[float, List[float]],
  13. diff_thresholds: Union[float, List[float]],
  14. labels_by_channel: Dict[int, List[int]]):
  15. super().__init__(title, model, layer_by_name, loss_weight)
  16. # at least two pools are needed
  17. assert len(layer_by_name) > 1, 'At least two pool layers are required to calculate this loss!'
  18. self._title = title
  19. self._model = model
  20. self._loss_prefix = f'{title}_JS_loss_'
  21. self._labels_by_channel = labels_by_channel
  22. self._pool_names = list(layer_by_name.keys())
  23. self._weights = weights if isinstance(weights, list) else \
  24. [weights for _ in range(len(layer_by_name) - 1)]
  25. if isinstance(weights, list):
  26. assert len(weights) == len(layer_by_name) - 1, 'Weights must have a length of pool_layers -1'
  27. self._diff_thresholds = diff_thresholds if isinstance(diff_thresholds, list) else \
  28. [diff_thresholds for _ in range(len(layer_by_name) - 1)]
  29. if isinstance(diff_thresholds, list):
  30. assert len(diff_thresholds) == len(layer_by_name) - 1, 'Diff thresholds must have a length of pool_layers -1'
  31. def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
  32. # reading pools from the output and deleting them
  33. pool_vals = [x[1] for x in layers_values]
  34. # getting samples labels
  35. dl: DataLoader = DataloaderContext.instance.dataloader
  36. labels = dl.get_current_batch_samples_labels()
  37. labels = torch.from_numpy(labels).to(dl._device)
  38. # sorting by shape from biggest to smallest
  39. p_shapes = np.asarray([p.shape[-1] for p in pool_vals])
  40. sorted_inds = np.argsort(-1 * p_shapes)
  41. pool_vals = [pool_vals[i] for i in sorted_inds]
  42. pool_names = [self._pool_names[i] for i in sorted_inds]
  43. loss = torch.zeros([], dtype=torch.float32, device=labels.device, requires_grad=True)
  44. # for each pair of pools, calculate loss!
  45. for i in range(len(pool_names) - 1):
  46. p_loss = self._cal_pool_pair_loss(pool_vals[i], pool_vals[i + 1], self._diff_thresholds[i], labels)
  47. assert (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') not in model_output, 'Trying to add ' + (f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}') + ' to model output multiple times'
  48. model_output[f'{self._loss_prefix}{pool_names[i]}-{pool_names[i + 1]}'] = p_loss.clone()
  49. loss = loss + self._weights[i] * p_loss
  50. return loss
  51. def _cal_pool_pair_loss(self, p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float, labels: torch.Tensor) -> torch.Tensor:
  52. # down-sampling by max-pool till reaching the same shape as p2!
  53. if p1.shape[-1] > p2.shape[-1]:
  54. p1 = F.adaptive_max_pool2d(p1, p2.shape[-2:])
  55. # jensen shannon loss, for each channel -> in the related class!
  56. loss = torch.tensor(0.0, requires_grad=True, device=p1.device)
  57. for channel, r_labels in self._labels_by_channel.items():
  58. on_mask = self._get_inclusion_mask(labels, r_labels)
  59. ip1 = p1[on_mask, channel, ...]
  60. ip2 = p2[on_mask, channel, ...]
  61. if torch.numel(ip1) > 0:
  62. if ip1.shape != ip2.shape:
  63. print(f'Problem in shape of concordance loss in {self._title}')
  64. loss = loss + jensen_shannon_divergence(
  65. ip1[:, None, ...], ip2[:, None, ...], diff_threshold)
  66. return loss
  67. def _get_inclusion_mask(
  68. self,
  69. samples_labels: torch.Tensor,
  70. desired_labels: List[int]) -> torch.Tensor:
  71. with torch.no_grad():
  72. inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
  73. aggregation = torch.sum(inclusion_mask.float(), dim=0)
  74. return torch.greater(aggregation, 0)
  75. def jensen_shannon_divergence(p1: torch.Tensor, p2: torch.Tensor, diff_threshold: float) -> torch.Tensor:
  76. """
  77. Calculates the jensen shannon loss between two distributions p1 and p2
  78. Args:
  79. p1 (torch.Tensor): A tensor of shape B C ... in range [0, 1] that is the probabilities for each neuron in the 1st distribution.
  80. p2 (torch.Tensor): A tensor of the same shape as src, in range [0, 1] that is the probabilities for each neuron in the 2nd distribution.
  81. diff_threshold (float): Threshold between p1 and p2 to decide whether to consider one pixel in loss
  82. Returns:
  83. torch.Tensor: The calculated loss.
  84. """
  85. assert p1.shape == p2.shape, 'The tensors must have the same shape'
  86. assert 0 <= diff_threshold < 1, 'The difference threshold should be in range [0, 1)'
  87. # Reshaping tensors
  88. p1 = torch.transpose(p1, 0, 1).flatten(1).transpose(0, 1)
  89. p2 = torch.transpose(p2, 0, 1).flatten(1).transpose(0, 1)
  90. # if binary, append class 0!
  91. if p1.shape[1] == 1:
  92. p1 =[1 - p1, p1], dim=1)
  93. p2 =[1 - p2, p2], dim=1)
  94. with torch.no_grad():
  95. mask = torch.abs(p1 - p2).detach() >= diff_threshold
  96. mask = mask.max(dim=-1)[0]
  97. # to make sure error does not result in log(negative)!
  98. lp1 = torch.log(torch.maximum(p1 + 1e-4, 1e-6 + torch.zeros_like(p1)))
  99. lp2 = torch.log(torch.maximum(p2 + 1e-4, 1e-6 + torch.zeros_like(p2)))
  100. loss = 0.5 * (
  101. _smean(mask, torch.sum(p1 * (lp1 - lp2), dim=-1)) +
  102. _smean(mask, torch.sum(p2 * (lp2 - lp1), dim=-1))
  103. )
  104. return loss
  105. def _smean(mask: torch.Tensor, val: torch.Tensor) -> torch.Tensor:
  106. if torch.any(mask.bool()):
  107. return torch.mean(val[mask])
  108. else:
  109. return torch.tensor(0.0, requires_grad=True, device=mask.device)