Browse Source

Add files via upload

main
rucv 1 year ago
parent
commit
97d543ee44
No account linked to committer's email address
1 changed files with 148 additions and 0 deletions
  1. 148
    0
      utils/loss.py

+ 148
- 0
utils/loss.py View File

@@ -0,0 +1,148 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt import opt

"""BCE loss"""


class BCELoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(BCELoss, self).__init__()
self.bceloss = nn.BCELoss(weight=weight, size_average=size_average)

def forward(self, pred, target):
size = pred.size(0)
pred_flat = pred.view(size, -1)
target_flat = target.view(size, -1)

loss = self.bceloss(pred_flat, target_flat)

return loss


"""Dice loss"""


class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()

def forward(self, pred, target):
smooth = 1

size = pred.size(0)

pred_flat = pred.view(size, -1)
target_flat = target.view(size, -1)

intersection = pred_flat * target_flat
dice_score = (2 * intersection.sum(1) + smooth)/(pred_flat.sum(1) + target_flat.sum(1) + smooth)
dice_loss = 1 - dice_score.sum()/size

return dice_loss


"""BCE + DICE Loss + IoULoss"""


class BceDiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(BceDiceLoss, self).__init__()
#self.bce = BCELoss(weight, size_average)
self.fl = FocalLoss()
self.dice = DiceLoss()
self.iou = IoULoss()

def forward(self, pred, target):
fcloss = self.fl(pred, target)
diceloss = self.dice(pred, target)
iouloss = self.iou(pred, target)
loss = fcloss + diceloss + iouloss #Use the obmination of loss

return loss


""" Deep Supervision Loss"""


def DeepSupervisionLoss(pred, gt):
d0, d1, d2, d3, d4 = pred[0:]

criterion = BceDiceLoss()

loss0 = criterion(d0, gt)
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True)
loss1 = criterion(d1, gt)
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True)
loss2 = criterion(d2, gt)
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True)
loss3 = criterion(d3, gt)
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True)
loss4 = criterion(d4, gt)

return loss0 + loss1 + loss2 + loss3 + loss4


class FocalLoss(nn.modules.loss._WeightedLoss):
def __init__(self, weight=None, gamma=4,reduction='mean'):
super(FocalLoss, self).__init__(weight,reduction=reduction)
self.gamma = gamma
self.weight = weight #weight parameter will act as the alpha parameter to balance class weights

def forward(self, input, target):

ce_loss = F.binary_cross_entropy(input, target,reduction=self.reduction,weight=self.weight)
pt = torch.exp(-ce_loss)
focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
return focal_loss


class IoULoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(IoULoss, self).__init__()

def forward(self, inputs, targets, smooth=1):

#comment out if your model contains a sigmoid or equivalent activation layer
inputs = torch.sigmoid(inputs)

#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)

#intersection is equivalent to True Positive count
#union is the mutually inclusive area of all labels & predictions
intersection = (inputs * targets).sum()
total = (inputs + targets).sum()
union = total - intersection

IoU = (intersection + smooth)/(union + smooth)

return 1 - IoU




# Code from PraNet

class IoULossW(nn.Module):
def __init__(self, weight=None, size_average=True):
super(IoULossW, self).__init__()

def forward(self, inputs, targets, smooth=1):

#comment out if your model contains a sigmoid or equivalent activation layer
pred = torch.sigmoid(inputs)
weit = 1 + 5*torch.abs(F.avg_pool2d(targets, kernel_size=31, stride=1, padding=15) - targets)
inter = ((pred * targets)* weit).sum(dim=(2, 3))
#print(inter.shape)
union = ((pred + targets)* weit).sum(dim=(2, 3))
wiou = 1 - (inter + 1)/(union - inter+1)
wiou = wiou.mean()*opt.weight_const

return wiou

Loading…
Cancel
Save