You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

bb_supervised.py 4.9KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import List, Tuple, Dict, Union
  2. import math
  3. import torch
  4. from torch.nn import functional as F
  5. from ..data.data_loader import DataLoader
  6. from ..data.dataloader_context import DataloaderContext
  7. from .aux_loss import AuxLoss
  8. class BBLoss(AuxLoss):
  9. def __init__(
  10. self, title: str, model: torch.nn.Module,
  11. layer_by_name: Dict[str, torch.nn.Module],
  12. loss_weight: float,
  13. inf_mask_kw: str,
  14. layers_loss_weights: Union[float, List[float]],
  15. bb_fill_ratio: float):
  16. super().__init__(title, model, layer_by_name, loss_weight)
  17. self._inf_mask_kw = inf_mask_kw
  18. self._bb_fill_ratio = bb_fill_ratio
  19. if isinstance(layers_loss_weights, list):
  20. assert len(layers_loss_weights) == len(layer_by_name), f'There should be as many weights as layers, one per layer. Expected {len(layer_by_name)} found {len(layers_loss_weights)}'
  21. self._layers_loss_weights = layers_loss_weights
  22. else:
  23. self._layers_loss_weights = [layers_loss_weights for _ in range(len(layer_by_name))]
  24. def _calculate_loss(
  25. self,
  26. layers_values: List[Tuple[str, torch.Tensor]],
  27. model_output: Dict[str, torch.Tensor]) -> torch.Tensor:
  28. # gathering extra information
  29. dl: DataLoader = DataloaderContext.instance.dataloader
  30. xray_y = torch.from_numpy(dl.get_current_batch_samples_labels()).to(dl._device)
  31. infection_mask = dl.get_current_batch_data(keyword=self._inf_mask_kw).to(dl._device)
  32. xray_y = xray_y.bool()
  33. infection_mask = infection_mask[xray_y]
  34. PB, _, H, W = infection_mask.shape
  35. loss = torch.zeros([], dtype=torch.float32,
  36. device=xray_y.device, requires_grad=True)
  37. for li, (ln, lo) in enumerate(layers_values):
  38. out = F.interpolate(lo,
  39. size=(H, W),
  40. mode='bilinear',
  41. align_corners=True)
  42. out = out.flatten(start_dim=1)
  43. neg_out = out[~xray_y]
  44. pos_out = out[xray_y]
  45. NB = neg_out.shape[0]
  46. infection_mask = infection_mask.flatten(start_dim=1)
  47. pos_out_infection = torch.where(infection_mask < 1,
  48. torch.zeros_like(pos_out, requires_grad=True),
  49. pos_out)
  50. neg_losses = []
  51. pos_losses = []
  52. if PB > 0:
  53. bb_area = infection_mask.sum(dim=-1) # B
  54. k = (self._bb_fill_ratio * bb_area.quantile(q=0.5))\
  55. .ceil()\
  56. .long()
  57. # Top positive in bb pixels must be positive
  58. pos_infection_topk, pos_infection_indices = pos_out_infection\
  59. .topk(k, dim=-1, sorted=False)
  60. pos_infection_batch_index = torch.arange(PB)\
  61. .to(pos_infection_indices.device)\
  62. .unsqueeze(1)\
  63. .repeat_interleave(k, dim=1)
  64. pos_weight = infection_mask[pos_infection_batch_index, pos_infection_indices].floor() # make non ones to be zero
  65. pos_losses.append(F.binary_cross_entropy(
  66. pos_infection_topk,
  67. torch.ones_like(pos_infection_topk),
  68. pos_weight
  69. ))
  70. if (infection_mask == 0).any():
  71. # All positive out bb pixels must be negative
  72. pos_out_non_infection = pos_out[infection_mask == 0]
  73. neg_losses.append(F.binary_cross_entropy(
  74. pos_out_non_infection,
  75. torch.zeros_like(pos_out_non_infection)
  76. ))
  77. if NB > 0:
  78. if PB > 0:
  79. # Top negative pixels must be negative
  80. neg_k = int(math.ceil(PB * k * 1.0 / NB))
  81. neg_out_topk = neg_out.topk(neg_k, dim=-1, sorted=False)[0]
  82. neg_losses.append(F.binary_cross_entropy(
  83. neg_out_topk,
  84. torch.zeros_like(neg_out_topk)
  85. ))
  86. else:
  87. # All negative pixels must be negative
  88. neg_losses.append(F.binary_cross_entropy(
  89. neg_out,
  90. torch.zeros_like(neg_out)
  91. ))
  92. losses = []
  93. if len(neg_losses) > 0:
  94. losses.append(torch.stack(neg_losses).mean())
  95. if len(pos_losses) > 0:
  96. losses.append(torch.stack(pos_losses).mean())
  97. l_loss = torch.stack(losses).mean()
  98. model_output[f'{self._title}_{ln}_loss'] = l_loss
  99. loss = loss + self._layers_loss_weights[li] * l_loss
  100. return loss