You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

loss.py 4.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. import torch
  4. from torch.autograd import Variable
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from opt import opt
  9. """BCE loss"""
  10. class BCELoss(nn.Module):
  11. def __init__(self, weight=None, size_average=True):
  12. super(BCELoss, self).__init__()
  13. self.bceloss = nn.BCELoss(weight=weight, size_average=size_average)
  14. def forward(self, pred, target):
  15. size = pred.size(0)
  16. pred_flat = pred.view(size, -1)
  17. target_flat = target.view(size, -1)
  18. loss = self.bceloss(pred_flat, target_flat)
  19. return loss
  20. """Dice loss"""
  21. class DiceLoss(nn.Module):
  22. def __init__(self):
  23. super(DiceLoss, self).__init__()
  24. def forward(self, pred, target):
  25. smooth = 1
  26. size = pred.size(0)
  27. pred_flat = pred.view(size, -1)
  28. target_flat = target.view(size, -1)
  29. intersection = pred_flat * target_flat
  30. dice_score = (2 * intersection.sum(1) + smooth)/(pred_flat.sum(1) + target_flat.sum(1) + smooth)
  31. dice_loss = 1 - dice_score.sum()/size
  32. return dice_loss
  33. """BCE + DICE Loss + IoULoss"""
  34. class BceDiceLoss(nn.Module):
  35. def __init__(self, weight=None, size_average=True):
  36. super(BceDiceLoss, self).__init__()
  37. #self.bce = BCELoss(weight, size_average)
  38. self.fl = FocalLoss()
  39. self.dice = DiceLoss()
  40. self.iou = IoULoss()
  41. def forward(self, pred, target):
  42. fcloss = self.fl(pred, target)
  43. diceloss = self.dice(pred, target)
  44. iouloss = self.iou(pred, target)
  45. loss = fcloss + diceloss + iouloss #Use the obmination of loss
  46. return loss
  47. """ Deep Supervision Loss"""
  48. def DeepSupervisionLoss(pred, gt):
  49. d0, d1, d2, d3, d4 = pred[0:]
  50. criterion = BceDiceLoss()
  51. loss0 = criterion(d0, gt)
  52. gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
  53. loss1 = criterion(d1, gt)
  54. gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
  55. loss2 = criterion(d2, gt)
  56. gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
  57. loss3 = criterion(d3, gt)
  58. gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
  59. loss4 = criterion(d4, gt)
  60. return loss0 + loss1 + loss2 + loss3 + loss4
  61. class FocalLoss(nn.modules.loss._WeightedLoss):
  62. def __init__(self, weight=None, gamma=4,reduction='mean'):
  63. super(FocalLoss, self).__init__(weight,reduction=reduction)
  64. self.gamma = gamma
  65. self.weight = weight #weight parameter will act as the alpha parameter to balance class weights
  66. def forward(self, input, target):
  67. ce_loss = F.binary_cross_entropy(input, target,reduction=self.reduction,weight=self.weight)
  68. pt = torch.exp(-ce_loss)
  69. focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
  70. return focal_loss
  71. class IoULoss(nn.Module):
  72. def __init__(self, weight=None, size_average=True):
  73. super(IoULoss, self).__init__()
  74. def forward(self, inputs, targets, smooth=1):
  75. #comment out if your model contains a sigmoid or equivalent activation layer
  76. inputs = torch.sigmoid(inputs)
  77. #flatten label and prediction tensors
  78. inputs = inputs.view(-1)
  79. targets = targets.view(-1)
  80. #intersection is equivalent to True Positive count
  81. #union is the mutually inclusive area of all labels & predictions
  82. intersection = (inputs * targets).sum()
  83. total = (inputs + targets).sum()
  84. union = total - intersection
  85. IoU = (intersection + smooth)/(union + smooth)
  86. return 1 - IoU
  87. # Code from PraNet
  88. class IoULossW(nn.Module):
  89. def __init__(self, weight=None, size_average=True):
  90. super(IoULossW, self).__init__()
  91. def forward(self, inputs, targets, smooth=1):
  92. #comment out if your model contains a sigmoid or equivalent activation layer
  93. pred = torch.sigmoid(inputs)
  94. weit = 1 + 5*torch.abs(F.avg_pool2d(targets, kernel_size=31, stride=1, padding=15) - targets)
  95. inter = ((pred * targets)* weit).sum(dim=(2, 3))
  96. #print(inter.shape)
  97. union = ((pred + targets)* weit).sum(dim=(2, 3))
  98. wiou = 1 - (inter + 1)/(union - inter+1)
  99. wiou = wiou.mean()*opt.weight_const
  100. return wiou