import math
from typing import Dict, List, Tuple, Union

import torch
from torch.nn import functional as F

from ..data.data_loader import DataLoader
from ..data.dataloader_context import DataloaderContext
from .aux_loss import AuxLoss


class BBLoss(AuxLoss):
    """Bounding-box auxiliary loss over intermediate activation maps.

    For each supervised layer, the layer's activation map is resized to
    the infection-mask resolution and trained so that, on positive
    samples, the strongest in-mask responses approach 1 while
    out-of-mask responses approach 0, and on negative samples the
    strongest responses approach 0. Per-layer losses are logged in the
    model output and combined with per-layer weights.
    """

    def __init__(
            self, title: str, model: torch.nn.Module,
            layer_by_name: Dict[str, torch.nn.Module],
            loss_weight: float,
            inf_mask_kw: str,
            layers_loss_weights: Union[float, List[float]],
            bb_fill_ratio: float):
        super().__init__(title, model, layer_by_name, loss_weight)

        self._inf_mask_kw = inf_mask_kw
        self._bb_fill_ratio = bb_fill_ratio

        # One weight per supervised layer; a scalar is broadcast to all
        # layers.
        if isinstance(layers_loss_weights, list):
            assert len(layers_loss_weights) == len(layer_by_name), (
                f'There should be one weight per layer: expected '
                f'{len(layer_by_name)}, found {len(layers_loss_weights)}.')
            self._layers_loss_weights = layers_loss_weights
        else:
            self._layers_loss_weights = [
                layers_loss_weights for _ in range(len(layer_by_name))]
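
    # Illustrative construction (a sketch, not from the original code: the
    # model, layer names, and mask keyword below are hypothetical):
    #
    #   criterion = BBLoss(
    #       title='bb',
    #       model=model,
    #       layer_by_name={'att4': model.att4, 'att5': model.att5},
    #       loss_weight=1.0,
    #       inf_mask_kw='infection_mask',
    #       layers_loss_weights=[0.5, 1.0],
    #       bb_fill_ratio=0.3,
    #   )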

    def _calculate_loss(
            self,
            layers_values: List[Tuple[str, torch.Tensor]],
            model_output: Dict[str, torch.Tensor]) -> torch.Tensor:

        # Gather the current batch's labels and infection masks from the
        # active dataloader.
        dl: DataLoader = DataloaderContext.instance.dataloader
        xray_y = torch.from_numpy(
            dl.get_current_batch_samples_labels()).to(dl._device)
        infection_mask = dl.get_current_batch_data(
            keyword=self._inf_mask_kw).to(dl._device)

        # Keep only the masks of the positive samples.
        xray_y = xray_y.bool()
        infection_mask = infection_mask[xray_y]

        PB, _, H, W = infection_mask.shape  # PB: number of positive samples

        # Flatten the spatial dimensions once, before the per-layer loop.
        infection_mask = infection_mask.flatten(start_dim=1)

        # requires_grad keeps the result differentiable even if no loss
        # terms end up being added.
        loss = torch.zeros([], dtype=torch.float32,
                           device=xray_y.device, requires_grad=True)

        for li, (ln, lo) in enumerate(layers_values):
            # Resize the layer's activation map to the mask resolution and
            # flatten it to (batch, H * W).
            out = F.interpolate(lo,
                                size=(H, W),
                                mode='bilinear',
                                align_corners=True)
            out = out.flatten(start_dim=1)

            neg_out = out[~xray_y]  # maps of the negative samples
            pos_out = out[xray_y]   # maps of the positive samples
            NB = neg_out.shape[0]

            # Zero every prediction that falls outside the infection mask.
            pos_out_infection = torch.where(infection_mask < 1,
                                            torch.zeros_like(pos_out),
                                            pos_out)

            neg_losses = []
            pos_losses = []

            if PB > 0:
                # k is derived from the median mask area: the fraction of
                # in-mask pixels that must fire on a positive sample.
                # Converted to int, since Tensor.topk expects a Python
                # integer.
                bb_area = infection_mask.sum(dim=-1)  # per-sample mask area
                k = int((self._bb_fill_ratio * bb_area.quantile(q=0.5))
                        .ceil())

                # The k strongest in-mask predictions must be positive.
                pos_infection_topk, pos_infection_indices = \
                    pos_out_infection.topk(k, dim=-1, sorted=False)
                pos_infection_batch_index = torch.arange(PB)\
                    .to(pos_infection_indices.device)\
                    .unsqueeze(1)\
                    .repeat_interleave(k, dim=1)
                # Floor the gathered mask values so partially covered pixels
                # (values in (0, 1)) get zero BCE weight.
                bce_weight = infection_mask[
                    pos_infection_batch_index, pos_infection_indices].floor()
                pos_losses.append(F.binary_cross_entropy(
                    pos_infection_topk,
                    torch.ones_like(pos_infection_topk),
                    bce_weight
                ))

            if (infection_mask == 0).any():
                # Every positive-sample prediction outside the mask must be
                # negative.
                pos_out_non_infection = pos_out[infection_mask == 0]
                neg_losses.append(F.binary_cross_entropy(
                    pos_out_non_infection,
                    torch.zeros_like(pos_out_non_infection)
                ))

            if NB > 0:
                if PB > 0:
                    # The strongest responses on negative samples must be
                    # negative; supervise roughly as many negative pixels as
                    # positive ones, clamped to the map size.
                    neg_k = min(int(math.ceil(PB * k / NB)),
                                neg_out.shape[1])
                    neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0]
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out_topk,
                        torch.zeros_like(neg_out_topk)
                    ))
                else:
                    # With no positive samples, supervise every negative
                    # pixel.
                    neg_losses.append(F.binary_cross_entropy(
                        neg_out,
                        torch.zeros_like(neg_out)
                    ))

            losses = []
            if len(neg_losses) > 0:
                losses.append(torch.stack(neg_losses).mean())
            if len(pos_losses) > 0:
                losses.append(torch.stack(pos_losses).mean())

            # Log the per-layer loss and accumulate it with its layer weight.
            l_loss = torch.stack(losses).mean()
            model_output[f'{self._title}_{ln}_loss'] = l_loss
            loss = loss + self._layers_loss_weights[li] * l_loss

        return loss
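

# ---------------------------------------------------------------------------
# Minimal sketch (not part of the original module): demonstrates the in-mask
# top-k supervision used in `_calculate_loss` on dummy tensors. Shapes and
# values are made up, and `gather` is used as an equivalent to the
# batch-index advanced indexing above. Run as a module so the relative
# imports at the top resolve.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)

    # Two "positive" samples with flattened 4x4 probability maps and binary
    # infection masks.
    pos_out = torch.rand(2, 16)                         # sigmoid-like outputs
    infection_mask = (torch.rand(2, 16) > 0.5).float()

    bb_fill_ratio = 0.5
    bb_area = infection_mask.sum(dim=-1)                # per-sample mask area
    k = int((bb_fill_ratio * bb_area.quantile(q=0.5)).ceil())

    # Zero predictions outside the mask, take the k strongest in-mask
    # predictions, and push them towards 1; out-of-mask picks get zero
    # weight, as in BBLoss.
    pos_in = torch.where(infection_mask < 1,
                         torch.zeros_like(pos_out), pos_out)
    topk_vals, topk_idx = pos_in.topk(k, dim=-1, sorted=False)
    weight = infection_mask.gather(-1, topk_idx).floor()
    demo_loss = F.binary_cross_entropy(
        topk_vals, torch.ones_like(topk_vals), weight)
    print(f'demo top-k positive loss: {demo_loss.item():.4f}')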