from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .aux_loss import AuxLoss


class DiscriminativeWeaklySupervisedLoss(AuxLoss):

    def __init__(
            self,
            title: str,
            model: nn.Module,
            att_score_layer: nn.Module,
            loss_weight: float,
            neg_ratio: float,
            pos_ratio_range: Tuple[float, float],
            on_labels_by_channel: Dict[int, List[int]],
            discr_score_layer: Optional[nn.Module] = None,
            w_attention_in_ordering: float = 1,
            w_discr_in_ordering: float = 1):
        """
        Calculates a binary weakly supervised loss for an attention layer,
        optionally combined with an extra discrimination head.

        Args:
            title (str): The title of the loss; must be unique.
            model (nn.Module): The base model, so its output can be modified.
            att_score_layer (nn.Module): A layer that gives an attention score (B C ...).
            loss_weight (float): The weight of the loss.
            neg_ratio (float): Top ratio of pixels the loss is applied to for
                negative samples, e.g. 0.1.
            pos_ratio_range (Tuple[float, float]): Low and high ratios of pixels
                the loss is applied to for positive samples, e.g. (0.033, 0.278),
                calculated from the distribution of positive bounding boxes.
            on_labels_by_channel (Dict[int, List[int]]): For each channel, the
                labels whose samples should be "on" in that channel.
            discr_score_layer (nn.Module, optional): A layer that gives a
                discriminative score (B C ...).
            w_attention_in_ordering (float): The weight of the attention score
                used in ordering the pixels.
            w_discr_in_ordering (float): The weight of the discrimination score
                used in ordering the pixels.
        """
        layers = dict(
            att=att_score_layer,
        )
        if discr_score_layer is not None:
            layers['discr'] = discr_score_layer
        super().__init__(title, model, layers, loss_weight)

        self._has_discr = discr_score_layer is not None
        self._neg_ratio = neg_ratio
        self._pos_ratio_range = pos_ratio_range
        self._on_labels_by_channel = on_labels_by_channel
        self._w_attention_in_ordering = w_attention_in_ordering
        self._w_discr_in_ordering = w_discr_in_ordering

    def _calculate_loss(
            self,
            layers_values: List[Tuple[str, torch.Tensor]],
            model_out: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Separate the attention scores from the discrimination scores
        # collected from the hooked layers.
        discriminative_scores = None
        probabilities = None
        for ln, lv in layers_values:
            if ln == 'att':
                probabilities = lv
            else:
                discriminative_scores = lv

        dl: DataLoader = DataloaderContext.instance.dataloader
        labels = dl.get_current_batch_samples_labels()  # numpy array of per-sample labels

        if discriminative_scores is not None:
            discrimination_loss = self._calculate_discrimination_loss(
                labels, discriminative_scores)
            assert (self._title + '_discr_loss') not in model_out, \
                'Trying to add ' + (self._title + '_discr_loss') + \
                ' to model output multiple times'
            model_out[self._title + '_discr_loss'] = discrimination_loss.clone()
        else:
            discrimination_loss = torch.zeros(
                [], requires_grad=True, device=probabilities.device)

        attention_loss = self._calculate_attention_loss(
            labels, probabilities,
            (discriminative_scores
             if discriminative_scores is not None
             else probabilities))
        assert (self._title + '_ws_loss') not in model_out, \
            'Trying to add ' + (self._title + '_ws_loss') + \
            ' to model output multiple times'
        model_out[self._title + '_ws_loss'] = attention_loss.clone()

        loss = self._loss_weight * (discrimination_loss + attention_loss)
        return loss

    def _calculate_discrimination_loss(
            self,
            samples_labels: np.ndarray,
            discrimination_scores: torch.Tensor) -> torch.Tensor:
        losses = []
        for channel, labels in self._on_labels_by_channel.items():
            on_mask = self._get_inclusion_mask(
                samples_labels, labels, discrimination_scores.device)

            on_ps = discrimination_scores[on_mask, channel, ...]
            off_ps = discrimination_scores[torch.logical_not(on_mask), channel, ...]

            # All pixels of "on" samples should be positive, all pixels of
            # "off" samples should be negative.
            if torch.numel(on_ps) > 0:
                losses.append(self._cal_loss(1, True, on_ps))
            if torch.numel(off_ps) > 0:
                losses.append(self._cal_loss(1, False, off_ps))

        return torch.mean(torch.stack(losses))

    def _calculate_attention_loss(
            self,
            samples_labels: np.ndarray,
            attention_scores: torch.Tensor,
            discrimination_scores: torch.Tensor) -> torch.Tensor:
        losses = []
        for channel, labels in self._on_labels_by_channel.items():
            on_mask = self._get_inclusion_mask(
                samples_labels, labels, discrimination_scores.device)

            on_atts = attention_scores[on_mask, channel, ...]
            on_discr = discrimination_scores[on_mask, channel, ...].detach()

            off_atts = attention_scores[torch.logical_not(on_mask), channel, ...]
            off_discr = discrimination_scores[
                torch.logical_not(on_mask), channel, ...].detach()

            neg_losses = []
            pos_losses = []

            # loss injection to the model
            if torch.numel(off_atts) > 0 and self._neg_ratio > 0:
                neg_losses.append(self._cal_loss(
                    self._neg_ratio, False, off_atts, off_discr, largest=True))

            if torch.numel(on_atts) > 0:
                # Calculate positive top k to be positive
                if self._pos_ratio_range[0] > 0:
                    pos_losses.append(self._cal_loss(
                        self._pos_ratio_range[0], True, on_atts, on_discr, True))

                # Calculate positive bottom k to be negative
                if self._pos_ratio_range[1] < 1:
                    neg_losses.append(self._cal_loss(
                        1 - self._pos_ratio_range[1], False, on_atts, on_discr, False))

            if len(neg_losses) > 0:
                losses.append(torch.stack(neg_losses).mean())
            if len(pos_losses) > 0:
                losses.append(torch.stack(pos_losses).mean())

        return torch.stack(losses).mean()

    def _get_inclusion_mask(
            self,
            samples_labels: np.ndarray,
            desired_labels: List[int],
            device: torch.device) -> torch.Tensor:
        # Boolean mask of the samples whose label is in `desired_labels`.
        with torch.no_grad():
            samples_labels = torch.from_numpy(samples_labels).to(device)
            inclusion_mask = torch.stack(
                [samples_labels == l for l in desired_labels], dim=0)
            aggregation = torch.sum(inclusion_mask.float(), dim=0)
            return torch.greater(aggregation, 0)

    def _cal_loss(
            self,
            ratio: float,
            positive_label: bool,
            att_scores: torch.Tensor,
            discr_scores: Optional[torch.Tensor] = None,
            largest: bool = True):
        if ratio == 1:
            ps = att_scores
        else:
            k = np.ceil(
                ratio * att_scores.shape[-1] * att_scores.shape[-2]).astype(int)
            ps = self._get_topk(att_scores, discr_scores, k, largest=largest)
        ps = ps.flatten()
        if positive_label:
            gt = torch.ones_like(ps)
        else:
            gt = torch.zeros_like(ps)
        return F.binary_cross_entropy(ps, gt)

    def _get_topk(self, att_scores: torch.Tensor, discr_scores: torch.Tensor,
                  k: int, dim=-1, largest=True,
                  return_indices=False) -> torch.Tensor:
        # Rank the pixels by the combined attention/discrimination score, then
        # return the attention scores of the selected pixels.
        scores = self._pixels_scores(att_scores, discr_scores)
        b = att_scores.shape[0]

        top_inds = scores.flatten(1).topk(
            k, dim=dim, largest=largest, sorted=False).indices  # B K

        ret_val = att_scores.flatten(1)[
            torch.repeat_interleave(
                torch.arange(b, device=att_scores.device), k).reshape(b, k),
            top_inds]  # B K

        if not return_indices:
            return ret_val
        else:
            return ret_val, top_inds

    def _pixels_scores(
            self,
            attention_scores: torch.Tensor,
            discr_scores: torch.Tensor) -> torch.Tensor:
        return (self._w_attention_in_ordering * attention_scores +
                self._w_discr_in_ordering * discr_scores)
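# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as a comment so it does not affect the
# module). It assumes a hypothetical model whose `attention_head` and
# `discrimination_head` sub-modules emit per-pixel scores in [0, 1] with shape
# (B, C, H, W) and are hooked by `AuxLoss`; those attribute names are
# placeholders, and the example ratios are the ones quoted in the docstring.
#
#   ws_loss = DiscriminativeWeaklySupervisedLoss(
#       title='attention_ws',
#       model=model,
#       att_score_layer=model.attention_head,
#       loss_weight=1.0,
#       neg_ratio=0.1,
#       pos_ratio_range=(0.033, 0.278),
#       on_labels_by_channel={0: [1]},  # channel 0 should be "on" for label 1
#       discr_score_layer=model.discrimination_head,
#   )
# ---------------------------------------------------------------------------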