You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

weakly_supervised.py 8.5KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. from typing import Dict, List, Tuple, Optional
  2. import numpy as np
  3. import torch
  4. from torch import nn
  5. from torch.nn import functional as F
  6. from ..data.data_loader import DataLoader
  7. from ..data.dataloader_context import DataloaderContext
  8. from .aux_loss import AuxLoss
  9. class DiscriminativeWeaklySupervisedLoss(AuxLoss):
  10. def __init__(
  11. self, title: str, model: nn.Module,
  12. att_score_layer: nn.Module,
  13. loss_weight: float,
  14. neg_ratio: float, pos_ratio_range: Tuple[float, float],
  15. on_labels_by_channel: Dict[int, List[int]],
  16. discr_score_layer: nn.Module = None,
  17. w_attention_in_ordering: float = 1,
  18. w_discr_in_ordering: float = 1):
  19. """
  20. Calculates binary weakly supervised score for an attention layer
  21. with extra discrimination head.
  22. Args:
  23. title (str): The title of the loss, must be unique!
  24. model (Model): the base model, so the output can be modified
  25. att_score_layer (torch.nn.Module): A layer that gives attention score (B C ...)
  26. loss_weight (float): The weight of the loss
  27. neg_ratio (float, optional): top ratio to apply loss for negative samples. Defaults to 0.1.
  28. pos_ratio_range (Tuple[float, float], optional): low and top ratios to apply loss for positive samples. Defaults to (0.033, 0.278). Calculated by distribution of positive bounding boxes.
  29. on_labels_by_channel (Dict[int, List[int]]): The dictionary that specifies the samples related to which labels should be on in each channel.
  30. w_attention_in_ordering (float): The weight of the attention score used in ordering the pixels.
  31. w_discr_in_ordering (float): The weight of the reference score used in ordering the pixels.
  32. discr_score_layer (torch.nn.Module): A layer that gives discriminative score (B C ...)
  33. """
  34. layers = dict(
  35. att=att_score_layer,
  36. )
  37. if discr_score_layer is not None:
  38. layers['discr'] = discr_score_layer
  39. super().__init__(title, model, layers, loss_weight)
  40. self._has_discr = discr_score_layer is not None
  41. self._neg_ratio = neg_ratio
  42. self._pos_ratio_range = pos_ratio_range
  43. self._on_labels_by_channel = on_labels_by_channel
  44. self._w_attention_in_ordering = w_attention_in_ordering
  45. self._w_discr_in_ordering = w_discr_in_ordering
  46. def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor:
  47. discriminative_scores = None
  48. probabilities = None
  49. for ln, lv in layers_values:
  50. if ln == 'att':
  51. probabilities = lv
  52. else:
  53. discriminative_scores = lv
  54. dl: DataLoader = DataloaderContext.instance.dataloader
  55. labels = dl.get_current_batch_samples_labels()
  56. if discriminative_scores is not None:
  57. discrimination_loss = self._calculate_discrimination_loss(
  58. labels, discriminative_scores)
  59. assert (self._title + '_discr_loss') not in model_out, 'Trying to add ' + (self._title + '_discr_loss') + ' to model output multiple times'
  60. model_out[self._title + '_discr_loss'] = discrimination_loss.clone()
  61. else:
  62. discrimination_loss = torch.zeros([], requires_grad=True, device=labels.device)
  63. attention_loss = self._calculate_attention_loss(
  64. labels, probabilities, (discriminative_scores if discriminative_scores is not None else probabilities))
  65. assert (self._title + '_ws_loss') not in model_out, 'Trying to add ' + (self._title + '_ws_loss') + ' to model output multiple times'
  66. model_out[self._title + '_ws_loss'] = attention_loss.clone()
  67. loss = self._loss_weight * (discrimination_loss + attention_loss)
  68. return loss
  69. def _calculate_discrimination_loss(
  70. self,
  71. samples_labels: torch.Tensor,
  72. discrimination_scores: torch.Tensor) -> torch.Tensor:
  73. losses = []
  74. for channel, labels in self._on_labels_by_channel.items():
  75. on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
  76. on_ps = discrimination_scores[on_mask, channel, ...]
  77. off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...]
  78. if torch.numel(on_ps) > 0:
  79. losses.append(self._cal_loss(1, True, on_ps))
  80. if torch.numel(off_ps) > 0:
  81. losses.append(self._cal_loss(1, False, off_ps))
  82. return torch.mean(torch.stack(losses))
  83. def _calculate_attention_loss(
  84. self,
  85. samples_labels: torch.Tensor,
  86. attention_scores: torch.Tensor,
  87. discrimination_scores: torch.Tensor) -> torch.Tensor:
  88. losses = []
  89. for channel, labels in self._on_labels_by_channel.items():
  90. on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
  91. on_atts = attention_scores[on_mask, channel, ...]
  92. on_discr = discrimination_scores[on_mask, channel, ...].detach()
  93. off_atts = attention_scores[torch.logical_not(on_mask), channel, ...]
  94. off_discr = discrimination_scores[torch.logical_not(on_mask), channel, ...].detach()
  95. neg_losses = []
  96. pos_losses = []
  97. # loss injection to the model
  98. if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
  99. neg_losses.append(self._cal_loss(
  100. self._neg_ratio, False, off_atts, off_discr, largest=True
  101. ))
  102. if torch.numel(on_atts) > 0:
  103. # Calculate positive top k to be positive
  104. if self._pos_ratio_range[0] > 0:
  105. pos_losses.append(self._cal_loss(
  106. self._pos_ratio_range[0], True, on_atts, on_discr, True
  107. ))
  108. # Calculate positive bottom k to be negative
  109. if self._pos_ratio_range[1] < 1:
  110. neg_losses.append(self._cal_loss(
  111. 1 - self._pos_ratio_range[1], False, on_atts, on_discr, False
  112. ))
  113. if len(neg_losses) > 0:
  114. losses.append(torch.stack(neg_losses).mean())
  115. if len(pos_losses) > 0:
  116. losses.append(torch.stack(pos_losses).mean())
  117. return torch.stack(losses).mean()
  118. def _get_inclusion_mask(
  119. self,
  120. samples_labels: np.ndarray,
  121. desired_labels: List[int], device: torch.device) -> torch.Tensor:
  122. with torch.no_grad():
  123. samples_labels = torch.from_numpy(samples_labels).to(device)
  124. inclusion_mask = torch.stack([samples_labels == l for l in desired_labels], dim=0)
  125. aggregation = torch.sum(inclusion_mask.float(), dim=0)
  126. return torch.greater(aggregation, 0)
  127. def _cal_loss(
  128. self,
  129. ratio: float, positive_label: bool,
  130. att_scores: torch.Tensor,
  131. discr_scores: Optional[torch.Tensor] = None,
  132. largest: bool = True):
  133. if ratio == 1:
  134. ps = att_scores
  135. else:
  136. k = np.ceil(
  137. ratio * att_scores.shape[-1] * att_scores.shape[-2]).astype(int)
  138. ps = self._get_topk(att_scores, discr_scores, k, largest=largest)
  139. ps = ps.flatten()
  140. if positive_label:
  141. gt = torch.ones_like(ps)
  142. else:
  143. gt = torch.zeros_like(ps)
  144. return F.binary_cross_entropy(ps, gt)
  145. def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor,
  146. k: int, dim=-1, largest=True, return_indices=False) -> torch.Tensor:
  147. scores = self._pixels_scores(att_scores, discr_scores)
  148. b = att_scores.shape[0]
  149. top_inds = (scores.flatten(1)).topk(k, dim=dim, largest=largest, sorted=False).indices
  150. # B K
  151. ret_val = att_scores.flatten(1)[
  152. torch.repeat_interleave(
  153. torch.arange(b, device=att_scores.device), k).reshape(b, k),
  154. top_inds] # B K
  155. if not return_indices:
  156. return ret_val
  157. else:
  158. return ret_val, top_inds
  159. def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: torch.Tensor) -> torch.Tensor:
  160. return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores