|
- from typing import Dict, List, Tuple, Optional
-
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from ..data.data_loader import DataLoader
- from ..data.dataloader_context import DataloaderContext
- from .aux_loss import AuxLoss
-
-
class DiscriminativeWeaklySupervisedLoss(AuxLoss):
    """Weakly supervised binary loss for an attention layer, with an optional
    discrimination head.

    For every channel listed in ``on_labels_by_channel``, the samples of the
    current batch are split into "on" samples (label among the channel's
    labels) and "off" samples (all other labels). Selected attention pixels
    are pushed towards 1 or 0 with binary cross-entropy. Pixels are ranked
    by a weighted sum of the attention and discrimination scores.

    If a discrimination layer is given, its scores are also pushed:
    towards 1 on every pixel of "on" samples, and towards 0 on every pixel
    of "off" samples.
    """

    def __init__(
            self, title: str, model: nn.Module,
            att_score_layer: nn.Module,
            loss_weight: float,
            neg_ratio: float, pos_ratio_range: Tuple[float, float],
            on_labels_by_channel: Dict[int, List[int]],
            discr_score_layer: Optional[nn.Module] = None,
            w_attention_in_ordering: float = 1,
            w_discr_in_ordering: float = 1):
        """
        Calculates binary weakly supervised score for an attention layer
        with extra discrimination head.

        Args:
            title (str): The title of the loss, must be unique! It prefixes
                the keys this loss writes into the model output.
            model (nn.Module): The base model, so the output can be modified.
            att_score_layer (torch.nn.Module): A layer that gives attention
                scores (B C ...). They are fed to binary cross-entropy, so
                they must be probabilities in [0, 1].
            loss_weight (float): The weight of the loss.
            neg_ratio (float): Fraction of the highest-ranked pixels of each
                "off" sample whose attention is pushed towards 0. A value
                <= 0 disables this term.
            pos_ratio_range (Tuple[float, float]): ``(low, high)`` ratios for
                "on" samples. The top ``low`` fraction of pixels is pushed
                towards 1 (skipped when ``low`` <= 0). The bottom
                ``1 - high`` fraction is pushed towards 0 (skipped when
                ``high`` >= 1). The original note suggests (0.033, 0.278),
                calculated from the distribution of positive bounding boxes.
            on_labels_by_channel (Dict[int, List[int]]): For each channel,
                the sample labels for which that channel should be on.
            discr_score_layer (torch.nn.Module, optional): A layer that gives
                discriminative scores (B C ...), also probabilities in [0, 1].
                Defaults to None (no discrimination head).
            w_attention_in_ordering (float): The weight of the attention
                score used in ordering the pixels.
            w_discr_in_ordering (float): The weight of the discriminative
                score used in ordering the pixels.
        """
        layers = dict(
            att=att_score_layer,
        )
        if discr_score_layer is not None:
            layers['discr'] = discr_score_layer
        super().__init__(title, model, layers, loss_weight)

        self._has_discr = discr_score_layer is not None

        self._neg_ratio = neg_ratio
        self._pos_ratio_range = pos_ratio_range

        self._on_labels_by_channel = on_labels_by_channel

        self._w_attention_in_ordering = w_attention_in_ordering
        self._w_discr_in_ordering = w_discr_in_ordering

    def _calculate_loss(self, layers_values: List[Tuple[str, torch.Tensor]], model_out: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the weighted total loss for the current batch.

        The batch labels are read from the global ``DataloaderContext``.
        The individual terms are stored in ``model_out`` as
        ``<title>_discr_loss`` (only when a discrimination layer exists)
        and ``<title>_ws_loss``.

        Args:
            layers_values: ``(name, value)`` pairs captured from the
                registered layers: ``'att'`` is the attention layer, any
                other name the discrimination layer.
            model_out: The model output dict; it is mutated in place.

        Returns:
            ``loss_weight * (discrimination_loss + attention_loss)``, or
            ``loss_weight * attention_loss`` without a discrimination layer.
        """
        discriminative_scores = None
        probabilities = None

        for ln, lv in layers_values:
            if ln == 'att':
                probabilities = lv
            else:
                discriminative_scores = lv

        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = dl.get_current_batch_samples_labels()

        discrimination_loss = None
        if discriminative_scores is not None:
            discrimination_loss = self._calculate_discrimination_loss(
                labels, discriminative_scores)
            discr_key = self._title + '_discr_loss'
            assert discr_key not in model_out, f'Trying to add {discr_key} to model output multiple times'
            model_out[discr_key] = discrimination_loss.clone()

        # Without a discrimination head, the attention scores themselves
        # serve as the "discrimination" part of the pixel ordering.
        attention_loss = self._calculate_attention_loss(
            labels, probabilities,
            discriminative_scores if discriminative_scores is not None else probabilities)

        ws_key = self._title + '_ws_loss'
        assert ws_key not in model_out, f'Trying to add {ws_key} to model output multiple times'
        model_out[ws_key] = attention_loss.clone()

        # NOTE: the labels are a numpy array (no torch device), so no zero
        # placeholder is built from them; the missing term is simply omitted.
        total = attention_loss if discrimination_loss is None \
            else discrimination_loss + attention_loss

        return self._loss_weight * total

    def _calculate_discrimination_loss(
            self,
            samples_labels: np.ndarray,
            discrimination_scores: torch.Tensor) -> torch.Tensor:
        """BCE pushing every discrimination pixel of "on" samples to 1 and of
        "off" samples to 0, averaged over all (channel, on/off) terms.

        Returns a zero scalar (requiring grad) if no term could be computed.
        """
        losses = []

        for channel, labels in self._on_labels_by_channel.items():

            on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)

            on_ps = discrimination_scores[on_mask, channel, ...]
            off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...]

            if torch.numel(on_ps) > 0:
                losses.append(self._cal_loss(1, True, on_ps))

            if torch.numel(off_ps) > 0:
                losses.append(self._cal_loss(1, False, off_ps))

        if not losses:
            # e.g. empty batch or no configured channel: stacking [] would crash.
            return discrimination_scores.new_zeros((), requires_grad=True)

        return torch.stack(losses).mean()

    def _calculate_attention_loss(
            self,
            samples_labels: np.ndarray,
            attention_scores: torch.Tensor,
            discrimination_scores: torch.Tensor) -> torch.Tensor:
        """Weakly supervised BCE on the attention scores.

        Per channel:
          * "off" samples: the top ``neg_ratio`` pixels are pushed to 0;
          * "on" samples: the top ``pos_ratio_range[0]`` pixels are pushed
            to 1, and the bottom ``1 - pos_ratio_range[1]`` pixels to 0.

        Pixels are ranked by ``_pixels_scores``. The discrimination scores
        are detached, so they only influence the ranking. For each channel,
        the negative terms and the positive terms are each averaged, and
        the result is the mean of those averages.

        Returns a zero scalar (requiring grad) if no term could be computed.
        """
        losses = []

        for channel, labels in self._on_labels_by_channel.items():

            on_mask = self._get_inclusion_mask(samples_labels, labels, discrimination_scores.device)
            off_mask = torch.logical_not(on_mask)

            on_atts = attention_scores[on_mask, channel, ...]
            on_discr = discrimination_scores[on_mask, channel, ...].detach()

            off_atts = attention_scores[off_mask, channel, ...]
            off_discr = discrimination_scores[off_mask, channel, ...].detach()

            neg_losses = []
            pos_losses = []

            # Off samples: highest-ranked pixels should be negative.
            if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
                neg_losses.append(self._cal_loss(
                    self._neg_ratio, False, off_atts, off_discr, largest=True
                ))

            if torch.numel(on_atts) > 0:

                # On samples: top-k ranked pixels should be positive.
                if self._pos_ratio_range[0] > 0:
                    pos_losses.append(self._cal_loss(
                        self._pos_ratio_range[0], True, on_atts, on_discr, largest=True
                    ))

                # On samples: bottom-k ranked pixels should be negative.
                if self._pos_ratio_range[1] < 1:
                    neg_losses.append(self._cal_loss(
                        1 - self._pos_ratio_range[1], False, on_atts, on_discr, largest=False
                    ))

            if neg_losses:
                losses.append(torch.stack(neg_losses).mean())

            if pos_losses:
                losses.append(torch.stack(pos_losses).mean())

        if not losses:
            # e.g. empty batch or every term disabled: stacking [] would crash.
            return attention_scores.new_zeros((), requires_grad=True)

        return torch.stack(losses).mean()

    def _get_inclusion_mask(
            self,
            samples_labels: np.ndarray,
            desired_labels: List[int], device: torch.device) -> torch.Tensor:
        """Boolean mask (same shape as ``samples_labels``) that is True where
        a sample's label is one of ``desired_labels``.

        Accepts a numpy array (or tensor) of labels. An empty
        ``desired_labels`` gives an all-False mask.
        """
        with torch.no_grad():
            labels = torch.as_tensor(samples_labels, device=device)
            mask = torch.zeros(labels.shape, dtype=torch.bool, device=device)
            for label in desired_labels:
                mask |= labels == label
            return mask

    def _cal_loss(
            self,
            ratio: float, positive_label: bool,
            att_scores: torch.Tensor,
            discr_scores: Optional[torch.Tensor] = None,
            largest: bool = True):
        """BCE between selected attention scores and a constant target.

        Args:
            ratio: Fraction of each sample's elements (all dims after the
                first) to select. 1 selects every element without ranking.
            positive_label: Target is 1 if True, 0 otherwise.
            att_scores: Probabilities with the sample dimension first.
            discr_scores: Ranking companion of ``att_scores`` (same shape).
                If None, the attention scores alone determine the ranking.
            largest: Select the highest-ranked (True) or lowest-ranked
                (False) elements.
        """
        if ratio == 1:
            ps = att_scores
        else:
            # Count every per-sample element, matching the flatten(1) in
            # _get_topk (the old shape[-1] * shape[-2] was only right for
            # exactly two trailing dims).
            n_per_sample = int(np.prod(att_scores.shape[1:]))
            k = min(int(np.ceil(ratio * n_per_sample)), n_per_sample)
            ps = self._get_topk(att_scores, discr_scores, k, largest=largest)

        ps = ps.flatten()
        gt = torch.ones_like(ps) if positive_label else torch.zeros_like(ps)

        return F.binary_cross_entropy(ps, gt)

    def _get_topk(self, att_scores: torch.Tensor, discr_scores: Optional[torch.Tensor],
                  k: int, dim=-1, largest=True, return_indices=False) -> torch.Tensor:
        """Pick, per sample, the attention values of the ``k`` elements with
        the largest (or smallest) ordering score.

        Returns a (B, K) tensor of attention values, plus the (B, K) flat
        indices when ``return_indices`` is True. The order within K is
        unspecified (``sorted=False``).
        """
        scores = self._pixels_scores(att_scores, discr_scores)

        top_inds = scores.flatten(1).topk(k, dim=dim, largest=largest, sorted=False).indices
        # B K

        # Equivalent to indexing [arange(B)[:, None], top_inds].
        ret_val = torch.gather(att_scores.flatten(1), dim, top_inds)  # B K

        if not return_indices:
            return ret_val
        return ret_val, top_inds

    def _pixels_scores(self, attention_scores: torch.Tensor, discr_scores: Optional[torch.Tensor]) -> torch.Tensor:
        """Ordering score: weighted sum of attention and discrimination
        scores. A missing ``discr_scores`` falls back to the attention
        scores, as ``_calculate_loss`` does."""
        if discr_scores is None:
            discr_scores = attention_scores
        return self._w_attention_in_ordering * attention_scores + self._w_discr_in_ordering * discr_scores
-
|