#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F


class IOUloss(nn.Module):
    """IoU / GIoU loss for boxes given in (cx, cy, w, h) format."""

    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Top-left and bottom-right corners of the intersection rectangle.
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # `en` is 1 only where the boxes overlap on both axes, so pairs with
        # no overlap contribute zero intersection area.
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = area_i / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            # Corners of the smallest box enclosing both pred and target.
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            # GIoU = IoU - |C \ (A ∪ B)| / |C|: the penalty is the fraction of the
            # enclosing box not covered by the union (not by the intersection).
            giou = iou - (area_c - area_u) / area_c.clamp(min=1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type}")

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

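For reference, here is a minimal usage sketch that is not part of the original listing: the boxes and the printed values are illustrative and assume the (cx, cy, w, h) layout expected by `IOUloss` above.

```python
import torch

# Hypothetical example: a unit box and the same box shifted right by 0.5.
pred = torch.tensor([[0.5, 0.5, 1.0, 1.0]])    # (cx, cy, w, h)
target = torch.tensor([[1.0, 0.5, 1.0, 1.0]])

# Intersection 0.5, union 1.5 -> IoU = 1/3, so the "iou" loss is 1 - (1/3)**2 ≈ 0.889.
print(IOUloss(loss_type="iou")(pred, target))   # tensor([0.8889])
# The enclosing box equals the union here, so GIoU = IoU and the loss is ≈ 0.667.
print(IOUloss(loss_type="giou")(pred, target))  # tensor([0.6667])
```

With the default `reduction="none"` the loss is returned per box pair, which lets the caller apply its own per-anchor weighting before reducing.
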
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_boxes: Normalization factor for the summed loss, typically the
                   number of target boxes.
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs. negative examples. Default = 0.25; a negative
               value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) used to
               balance easy vs. hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the ground-truth class for each element.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # return loss.mean(0).sum() / num_boxes
    return loss.sum() / num_boxes
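
A similarly hypothetical sketch for `sigmoid_focal_loss` (the logits, labels, and `num_boxes` value below are made up for illustration): each element is treated as an independent binary classification, and the summed loss is normalized by `num_boxes`.

```python
import torch

logits = torch.tensor([[2.0, -1.0, -2.0, 0.5]])   # raw scores, before sigmoid
targets = torch.tensor([[1.0, 0.0, 0.0, 0.0]])    # binary label per element
num_boxes = 1                                      # e.g. number of GT boxes in the batch

loss = sigmoid_focal_loss(logits, targets, num_boxes, alpha=0.25, gamma=2)
print(loss)  # scalar; confident, correct predictions are down-weighted by (1 - p_t)**gamma
```

Passing a negative `alpha` skips the positive/negative class-balancing term while keeping the focal modulation.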