You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 8.5KB

  1. from typing import Dict, List, Tuple, Optional
  2. import numpy as np
  3. import torch
  4. from torch import nn
  5. from torch.nn import functional as F
  6. from import DataLoader
  7. from import DataloaderContext
  8. from .aux_loss import AuxLoss
  class DiscriminativeWeaklySupervisedLoss(AuxLoss):
      """Weakly supervised attention loss with an optional discrimination head.

      For every configured channel the batch is split into "on" samples (their
      label is listed for that channel) and "off" samples. A binary
      cross-entropy is then applied to selected pixels of the attention map,
      and, when a discrimination layer is given, to every pixel of the
      discrimination map.
      """

      def __init__(
              self, title: str, model: nn.Module,
              att_score_layer: nn.Module,
              loss_weight: float,
              neg_ratio: float, pos_ratio_range: Tuple[float, float],
              on_labels_by_channel: Dict[int, List[int]],
              discr_score_layer: Optional[nn.Module] = None,
              w_attention_in_ordering: float = 1,
              w_discr_in_ordering: float = 1):
          """
          Calculates binary weakly supervised score for an attention layer
          with extra discrimination head.

          Args:
              title (str): The title of the loss, must be unique!
              model (Model): the base model, so the output can be modified
              att_score_layer (torch.nn.Module): A layer that gives attention
                  score (B C ...)
              loss_weight (float): The weight of the loss
              neg_ratio (float): top ratio to apply loss for negative samples.
                  Required (no default); 0 disables this term. Earlier
                  documentation suggested 0.1.
              pos_ratio_range (Tuple[float, float]): low and top ratios to
                  apply loss for positive samples. Required (no default).
                  Earlier documentation suggested (0.033, 0.278), calculated
                  by distribution of positive bounding boxes.
              on_labels_by_channel (Dict[int, List[int]]): The dictionary that
                  specifies the samples related to which labels should be on
                  in each channel.
              discr_score_layer (torch.nn.Module, optional): A layer that
                  gives discriminative score (B C ...). When omitted, no
                  discrimination loss is computed and the attention scores
                  themselves are used wherever discriminative scores would be.
              w_attention_in_ordering (float): The weight of the attention
                  score used in ordering the pixels.
              w_discr_in_ordering (float): The weight of the reference score
                  used in ordering the pixels.
          """
          layers = dict(att=att_score_layer)
          if discr_score_layer is not None:
              layers['discr'] = discr_score_layer
          super().__init__(title, model, layers, loss_weight)

          self._has_discr = discr_score_layer is not None
          self._neg_ratio = neg_ratio
          self._pos_ratio_range = pos_ratio_range
          self._on_labels_by_channel = on_labels_by_channel
          self._w_attention_in_ordering = w_attention_in_ordering
          self._w_discr_in_ordering = w_discr_in_ordering

      @staticmethod
      def _zero_loss(device: torch.device) -> torch.Tensor:
          """Return a scalar zero on `device` that can take part in backward()."""
          return torch.zeros([], requires_grad=True, device=device)

      def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor:
          """Combine the discrimination and attention losses into one weighted loss.

          Args:
              layers_values: (layer name, layer output) pairs. The entry named
                  'att' is used as the attention scores; any other entry is
                  used as the discriminative scores.
              model_out: Model output dictionary. Clones of the unweighted
                  losses are stored under '<title>_discr_loss' (only when
                  discriminative scores exist) and '<title>_ws_loss'.

          Returns:
              loss_weight * (discrimination loss + attention loss).

          Raises:
              ValueError: if no 'att' entry is present in `layers_values`.
              AssertionError: if one of the output keys is already present.
          """
          # NOTE(review): self._title and self._loss_weight are not set in this
          # class; presumably AuxLoss.__init__ stores them -- confirm.
          discriminative_scores = None
          probabilities = None
          for ln, lv in layers_values:
              if ln == 'att':
                  probabilities = lv
              else:
                  discriminative_scores = lv

          if probabilities is None:
              raise ValueError("No 'att' layer value was provided to " + self._title)

          dl: DataLoader = DataloaderContext.instance.dataloader
          labels = dl.get_current_batch_samples_labels()

          if discriminative_scores is not None:
              discrimination_loss = self._calculate_discrimination_loss(
                  labels, discriminative_scores)
              discr_key = self._title + '_discr_loss'
              assert discr_key not in model_out, 'Trying to add ' + discr_key + ' to model output multiple times'
              model_out[discr_key] = discrimination_loss.clone()
          else:
              # The labels are converted with torch.from_numpy further down, so
              # they are a NumPy array and carry no usable torch device; take
              # the device from the attention scores instead.
              discrimination_loss = self._zero_loss(probabilities.device)

          # Without a discrimination head the attention scores double as the
          # ordering reference.
          attention_loss = self._calculate_attention_loss(
              labels, probabilities,
              (discriminative_scores if discriminative_scores is not None else probabilities))
          ws_key = self._title + '_ws_loss'
          assert ws_key not in model_out, 'Trying to add ' + ws_key + ' to model output multiple times'
          model_out[ws_key] = attention_loss.clone()

          loss = self._loss_weight * (discrimination_loss + attention_loss)
          return loss

      def _calculate_discrimination_loss(
              self,
              samples_labels: np.ndarray,
              discrimination_scores: torch.Tensor) -> torch.Tensor:
          """Binary cross-entropy of the discrimination map against sample labels.

          For each channel, every pixel of the "on" samples is pushed towards 1
          and every pixel of the remaining samples towards 0. The per-group
          losses are averaged with equal weight.

          Args:
              samples_labels: Label of every sample in the batch.
              discrimination_scores: Discriminative scores, indexed as
                  [sample, channel, ...].

          Returns:
              Scalar loss; zero when there is nothing to penalise.
          """
          losses = []
          for channel, labels in self._on_labels_by_channel.items():
              on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
              on_ps = discrimination_scores[on_mask, channel, ...]
              off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...]

              if torch.numel(on_ps) > 0:
                  losses.append(self._cal_loss(1, True, on_ps))
              if torch.numel(off_ps) > 0:
                  losses.append(self._cal_loss(1, False, off_ps))

          if not losses:
              # torch.stack rejects an empty list.
              return self._zero_loss(discrimination_scores.device)
          return torch.mean(torch.stack(losses))

      def _calculate_attention_loss(
              self,
              samples_labels: np.ndarray,
              attention_scores: torch.Tensor,
              discrimination_scores: torch.Tensor) -> torch.Tensor:
          """Weakly supervised loss on selected pixels of the attention map.

          Per channel:
            * "off" samples: the top `neg_ratio` fraction of pixels is pushed
              towards 0 (skipped when `neg_ratio` is 0);
            * "on" samples: the top `pos_ratio_range[0]` fraction is pushed
              towards 1, and the bottom `1 - pos_ratio_range[1]` fraction is
              pushed towards 0.
          Pixels are ranked by `_pixels_scores`; the discriminative scores are
          detached before being used for ranking.

          Args:
              samples_labels: Label of every sample in the batch.
              attention_scores: Attention scores, indexed as
                  [sample, channel, ...].
              discrimination_scores: Reference scores used for pixel ordering,
                  indexed the same way.

          Returns:
              Scalar loss; zero when no term is active.
          """
          losses = []
          for channel, labels in self._on_labels_by_channel.items():
              on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
              off_mask = torch.logical_not(on_mask)

              on_atts = attention_scores[on_mask, channel, ...]
              on_discr = discrimination_scores[on_mask, channel, ...].detach()
              off_atts = attention_scores[off_mask, channel, ...]
              off_discr = discrimination_scores[off_mask, channel, ...].detach()

              neg_losses = []
              pos_losses = []

              # loss injection to the model
              if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
                  neg_losses.append(self._cal_loss(
                      self._neg_ratio, False, off_atts, off_discr, largest=True
                  ))

              if torch.numel(on_atts) > 0:
                  # Calculate positive top k to be positive
                  if self._pos_ratio_range[0] > 0:
                      pos_losses.append(self._cal_loss(
                          self._pos_ratio_range[0], True, on_atts, on_discr, True
                      ))
                  # Calculate positive bottom k to be negative
                  if self._pos_ratio_range[1] < 1:
                      neg_losses.append(self._cal_loss(
                          1 - self._pos_ratio_range[1], False, on_atts, on_discr, False
                      ))

              # Negative and positive terms are averaged separately so that each
              # contributes one entry per channel.
              if len(neg_losses) > 0:
                  losses.append(torch.stack(neg_losses).mean())
              if len(pos_losses) > 0:
                  losses.append(torch.stack(pos_losses).mean())

          if not losses:
              # torch.stack rejects an empty list (e.g. neg_ratio == 0 together
              # with pos_ratio_range == (0, 1)).
              return self._zero_loss(attention_scores.device)
          return torch.stack(losses).mean()

      def _get_inclusion_mask(
              self,
              samples_labels: np.ndarray,
              desired_labels: List[int], device: torch.device) -> torch.Tensor:
          """Mark the samples whose label is one of `desired_labels`.

          Args:
              samples_labels: Label of every sample (a tensor is accepted too).
              desired_labels: Labels that count as a match.
              device: Device on which the mask is created.

          Returns:
              Boolean tensor shaped like `samples_labels`; all False when
              `desired_labels` is empty.
          """
          with torch.no_grad():
              samples_labels = torch.as_tensor(samples_labels, device=device)
              if len(desired_labels) == 0:
                  return torch.zeros_like(samples_labels, dtype=torch.bool)
              inclusion_mask = torch.stack(
                  [samples_labels == label for label in desired_labels], dim=0)
              return inclusion_mask.any(dim=0)

      def _cal_loss(
              self,
              ratio: float, positive_label: bool,
              att_scores: torch.Tensor,
              discr_scores: Optional[torch.Tensor] = None,
              largest: bool = True):
          """Binary cross-entropy of (a selection of) scores against one constant target.

          Args:
              ratio: Fraction of pixels per sample to include. Exactly 1 means
                  all pixels, with no ranking involved.
              positive_label: Target is 1 when True, 0 otherwise.
              att_scores: Scores the loss is applied to; they are fed to
                  `F.binary_cross_entropy` directly. When `ratio != 1` the last
                  two dimensions are taken as the pixel grid.
              discr_scores: Reference scores used for ranking when
                  `ratio != 1`.
              largest: Select the highest-ranked pixels when True, the lowest
                  otherwise.

          Returns:
              Scalar loss tensor.
          """
          if ratio == 1:
              ps = att_scores
          else:
              # Plain Python int: torch APIs expect a native integer for k.
              k = int(np.ceil(
                  ratio * att_scores.shape[-1] * att_scores.shape[-2]))
              ps = self._get_topk(att_scores, discr_scores, k, largest=largest)

          ps = ps.flatten()
          gt = torch.ones_like(ps) if positive_label else torch.zeros_like(ps)
          return F.binary_cross_entropy(ps, gt)

      def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor,
                    k: int, dim=-1, largest=True, return_indices=False) -> torch.Tensor:
          """Pick the attention values at the k best-ranked positions.

          Ranking uses `_pixels_scores`; the returned values are taken from
          `att_scores`, so gradients flow to the attention scores only.

          Args:
              att_scores: Attention scores, first dimension is the sample.
              discr_scores: Reference scores of the same shape.
              k: Number of positions to select.
              dim: Dimension of the flattened (sample, pixel) view to select
                  along.
              largest: Select the highest-ranked positions when True.
              return_indices: Also return the selected flat indices.

          Returns:
              The selected attention values, or a `(values, indices)` tuple
              when `return_indices` is True.
          """
          scores = self._pixels_scores(att_scores, discr_scores)
          top_inds = scores.flatten(1).topk(k, dim=dim, largest=largest, sorted=False).indices
          # B K
          ret_val = att_scores.flatten(1).gather(dim, top_inds)

          if return_indices:
              return ret_val, top_inds
          return ret_val

      def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: Optional[torch.Tensor]) -> torch.Tensor:
          """Weighted sum of attention and reference scores used to rank pixels.

          When `discr_scores` is None only the weighted attention term is used.
          """
          if discr_scores is None:
              return self._w_attention_in_ordering * attention_scores
          return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores