Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

losses.py 2.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. class IOUloss(nn.Module):
  8. def __init__(self, reduction="none", loss_type="iou"):
  9. super(IOUloss, self).__init__()
  10. self.reduction = reduction
  11. self.loss_type = loss_type
  12. def forward(self, pred, target):
  13. assert pred.shape[0] == target.shape[0]
  14. pred = pred.view(-1, 4)
  15. target = target.view(-1, 4)
  16. tl = torch.max(
  17. (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
  18. )
  19. br = torch.min(
  20. (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
  21. )
  22. area_p = torch.prod(pred[:, 2:], 1)
  23. area_g = torch.prod(target[:, 2:], 1)
  24. en = (tl < br).type(tl.type()).prod(dim=1)
  25. area_i = torch.prod(br - tl, 1) * en
  26. iou = (area_i) / (area_p + area_g - area_i + 1e-16)
  27. if self.loss_type == "iou":
  28. loss = 1 - iou ** 2
  29. elif self.loss_type == "giou":
  30. c_tl = torch.min(
  31. (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
  32. )
  33. c_br = torch.max(
  34. (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
  35. )
  36. area_c = torch.prod(c_br - c_tl, 1)
  37. giou = iou - (area_c - area_i) / area_c.clamp(1e-16)
  38. loss = 1 - giou.clamp(min=-1.0, max=1.0)
  39. if self.reduction == "mean":
  40. loss = loss.mean()
  41. elif self.reduction == "sum":
  42. loss = loss.sum()
  43. return loss
  44. def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
  45. """
  46. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  47. Args:
  48. inputs: A float tensor of arbitrary shape.
  49. The predictions for each example.
  50. targets: A float tensor with the same shape as inputs. Stores the binary
  51. classification label for each element in inputs
  52. (0 for the negative class and 1 for the positive class).
  53. alpha: (optional) Weighting factor in range (0,1) to balance
  54. positive vs negative examples. Default = -1 (no weighting).
  55. gamma: Exponent of the modulating factor (1 - p_t) to
  56. balance easy vs hard examples.
  57. Returns:
  58. Loss tensor
  59. """
  60. prob = inputs.sigmoid()
  61. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  62. p_t = prob * targets + (1 - prob) * (1 - targets)
  63. loss = ce_loss * ((1 - p_t) ** gamma)
  64. if alpha >= 0:
  65. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  66. loss = alpha_t * loss
  67. #return loss.mean(0).sum() / num_boxes
  68. return loss.sum() / num_boxes