from typing import List, Tuple, Dict, Union
import math

import torch
from torch.nn import functional as F

from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext

from .aux_loss import AuxLoss


class BBLoss(AuxLoss):
    """Auxiliary loss that encourages each selected layer's activation map to
    cover the infection bounding-box mask on positive samples (at least
    `bb_fill_ratio` of the median box area) while staying low outside the box
    and on negative samples."""

    def __init__(
            self, title: str, model: torch.nn.Module,
            layer_by_name: Dict[str, torch.nn.Module],
            loss_weight: float,
            inf_mask_kw: str,
            layers_loss_weights: Union[float, List[float]],
            bb_fill_ratio: float):
        super().__init__(title, model, layer_by_name, loss_weight)

        self._inf_mask_kw = inf_mask_kw
        self._bb_fill_ratio = bb_fill_ratio

        if isinstance(layers_loss_weights, list):
            assert len(layers_loss_weights) == len(layer_by_name), \
                f'There should be as many weights as layers, one per layer. ' \
                f'Expected {len(layer_by_name)}, found {len(layers_loss_weights)}'
            self._layers_loss_weights = layers_loss_weights
        else:
            # Broadcast a single weight to all layers.
            self._layers_loss_weights = \
                [layers_loss_weights for _ in range(len(layer_by_name))]

    def _calculate_loss(
            self, layers_values: List[Tuple[str, torch.Tensor]],
            model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Gather extra information from the current batch.
        dl: DataLoader = DataloaderContext.instance.dataloader
        xray_y = torch.from_numpy(
            dl.get_current_batch_samples_labels()).to(dl._device)
        infection_mask = dl.get_current_batch_data(
            keyword=self._inf_mask_kw).to(dl._device)

        # Keep the infection masks of the positive samples only.
        xray_y = xray_y.bool()
        infection_mask = infection_mask[xray_y]
        PB, _, H, W = infection_mask.shape  # PB: number of positive samples

        loss = torch.zeros(
            [], dtype=torch.float32, device=xray_y.device, requires_grad=True)

        for li, (ln, lo) in enumerate(layers_values):
            # Resize the layer's output to the mask resolution and flatten it.
            out = F.interpolate(
                lo, size=(H, W), mode='bilinear', align_corners=True)
            out = out.flatten(start_dim=1)
            neg_out = out[~xray_y]
            pos_out = out[xray_y]
            NB = neg_out.shape[0]  # number of negative samples

            infection_mask = infection_mask.flatten(start_dim=1)
            # Zero out activations that fall outside the bounding-box mask.
            pos_out_infection = torch.where(
                infection_mask < 1,
                torch.zeros_like(pos_out, requires_grad=True),
                pos_out)

            neg_losses = []
            pos_losses = []

            if PB > 0:
                bb_area = infection_mask.sum(dim=-1)  # per-sample box area
                k = (self._bb_fill_ratio * bb_area.quantile(q=0.5))\
                    .ceil()\
                    .long()

                # The top-k activations inside the box must be positive.
                pos_infection_topk, pos_infection_indices = pos_out_infection\
                    .topk(k, dim=-1, sorted=False)
                pos_infection_batch_index = torch.arange(PB)\
                    .to(pos_infection_indices.device)\
                    .unsqueeze(1)\
                    .repeat_interleave(k, dim=1)
                # floor() turns non-one mask values into zero, so only pixels
                # fully inside the box contribute to the loss.
                pos_weight = infection_mask[
                    pos_infection_batch_index, pos_infection_indices].floor()
                pos_losses.append(F.binary_cross_entropy(
                    pos_infection_topk,
                    torch.ones_like(pos_infection_topk),
                    pos_weight))

                if (infection_mask == 0).any():
                    # All positive-sample pixels outside the box must be negative.
                    pos_out_non_infection = pos_out[infection_mask == 0]
                    neg_losses.append(F.binary_cross_entropy(
                        pos_out_non_infection,
                        torch.zeros_like(pos_out_non_infection)))

            if NB > 0:
                if PB > 0:
                    # The top negative-sample pixels must be negative; match the
                    # number of penalized pixels to the positives' top-k budget.
                    neg_k = int(math.ceil(PB * k * 1.0 / NB))
                    neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0]
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out_topk,
                        torch.zeros_like(neg_out_topk)))
                else:
                    # All negative-sample pixels must be negative.
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out,
                        torch.zeros_like(neg_out)))

            losses = []
            if len(neg_losses) > 0:
                losses.append(torch.stack(neg_losses).mean())
            if len(pos_losses) > 0:
                losses.append(torch.stack(pos_losses).mean())
            l_loss = torch.stack(losses).mean()

            model_output[f'{self._title}_{ln}_loss'] = l_loss
            loss = loss + self._layers_loss_weights[li] * l_loss

        return loss
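

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the loss implementation): how
# BBLoss might be constructed in a training script. The model, layer name,
# batch keyword and weight values below are hypothetical placeholders; the
# actual wiring (how AuxLoss hooks the named layers and when _calculate_loss
# runs) is defined by the surrounding framework, not shown here.
#
#   model = build_model()                              # hypothetical factory
#   bb_loss = BBLoss(
#       title='bb',
#       model=model,
#       layer_by_name={'attention': model.attention},  # hypothetical layer
#       loss_weight=1.0,
#       inf_mask_kw='infection_mask',                  # hypothetical keyword
#       layers_loss_weights=1.0,                       # one weight, broadcast
#       bb_fill_ratio=0.3)
# ---------------------------------------------------------------------------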