|
|
@@ -0,0 +1,148 @@ |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
import torch |
|
|
|
from torch.autograd import Variable |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
from opt import opt |
|
|
|
|
|
|
|
"""BCE loss""" |
|
|
|
|
|
|
|
|
|
|
|
class BCELoss(nn.Module): |
|
|
|
def __init__(self, weight=None, size_average=True): |
|
|
|
super(BCELoss, self).__init__() |
|
|
|
self.bceloss = nn.BCELoss(weight=weight, size_average=size_average) |
|
|
|
|
|
|
|
def forward(self, pred, target): |
|
|
|
size = pred.size(0) |
|
|
|
pred_flat = pred.view(size, -1) |
|
|
|
target_flat = target.view(size, -1) |
|
|
|
|
|
|
|
loss = self.bceloss(pred_flat, target_flat) |
|
|
|
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
"""Dice loss""" |
|
|
|
|
|
|
|
|
|
|
|
class DiceLoss(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super(DiceLoss, self).__init__() |
|
|
|
|
|
|
|
def forward(self, pred, target): |
|
|
|
smooth = 1 |
|
|
|
|
|
|
|
size = pred.size(0) |
|
|
|
|
|
|
|
pred_flat = pred.view(size, -1) |
|
|
|
target_flat = target.view(size, -1) |
|
|
|
|
|
|
|
intersection = pred_flat * target_flat |
|
|
|
dice_score = (2 * intersection.sum(1) + smooth)/(pred_flat.sum(1) + target_flat.sum(1) + smooth) |
|
|
|
dice_loss = 1 - dice_score.sum()/size |
|
|
|
|
|
|
|
return dice_loss |
|
|
|
|
|
|
|
|
|
|
|
"""BCE + DICE Loss + IoULoss""" |
|
|
|
|
|
|
|
|
|
|
|
class BceDiceLoss(nn.Module): |
|
|
|
def __init__(self, weight=None, size_average=True): |
|
|
|
super(BceDiceLoss, self).__init__() |
|
|
|
#self.bce = BCELoss(weight, size_average) |
|
|
|
self.fl = FocalLoss() |
|
|
|
self.dice = DiceLoss() |
|
|
|
self.iou = IoULoss() |
|
|
|
|
|
|
|
def forward(self, pred, target): |
|
|
|
fcloss = self.fl(pred, target) |
|
|
|
diceloss = self.dice(pred, target) |
|
|
|
iouloss = self.iou(pred, target) |
|
|
|
loss = fcloss + diceloss + iouloss #Use the obmination of loss |
|
|
|
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
""" Deep Supervision Loss""" |
|
|
|
|
|
|
|
|
|
|
|
def DeepSupervisionLoss(pred, gt): |
|
|
|
d0, d1, d2, d3, d4 = pred[0:] |
|
|
|
|
|
|
|
criterion = BceDiceLoss() |
|
|
|
|
|
|
|
loss0 = criterion(d0, gt) |
|
|
|
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) |
|
|
|
loss1 = criterion(d1, gt) |
|
|
|
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) |
|
|
|
loss2 = criterion(d2, gt) |
|
|
|
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) |
|
|
|
loss3 = criterion(d3, gt) |
|
|
|
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) |
|
|
|
loss4 = criterion(d4, gt) |
|
|
|
|
|
|
|
return loss0 + loss1 + loss2 + loss3 + loss4 |
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(nn.modules.loss._WeightedLoss): |
|
|
|
def __init__(self, weight=None, gamma=4,reduction='mean'): |
|
|
|
super(FocalLoss, self).__init__(weight,reduction=reduction) |
|
|
|
self.gamma = gamma |
|
|
|
self.weight = weight #weight parameter will act as the alpha parameter to balance class weights |
|
|
|
|
|
|
|
def forward(self, input, target): |
|
|
|
|
|
|
|
ce_loss = F.binary_cross_entropy(input, target,reduction=self.reduction,weight=self.weight) |
|
|
|
pt = torch.exp(-ce_loss) |
|
|
|
focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean() |
|
|
|
return focal_loss |
|
|
|
|
|
|
|
|
|
|
|
class IoULoss(nn.Module): |
|
|
|
def __init__(self, weight=None, size_average=True): |
|
|
|
super(IoULoss, self).__init__() |
|
|
|
|
|
|
|
def forward(self, inputs, targets, smooth=1): |
|
|
|
|
|
|
|
#comment out if your model contains a sigmoid or equivalent activation layer |
|
|
|
inputs = torch.sigmoid(inputs) |
|
|
|
|
|
|
|
#flatten label and prediction tensors |
|
|
|
inputs = inputs.view(-1) |
|
|
|
targets = targets.view(-1) |
|
|
|
|
|
|
|
#intersection is equivalent to True Positive count |
|
|
|
#union is the mutually inclusive area of all labels & predictions |
|
|
|
intersection = (inputs * targets).sum() |
|
|
|
total = (inputs + targets).sum() |
|
|
|
union = total - intersection |
|
|
|
|
|
|
|
IoU = (intersection + smooth)/(union + smooth) |
|
|
|
|
|
|
|
return 1 - IoU |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Code from PraNet |
|
|
|
|
|
|
|
class IoULossW(nn.Module): |
|
|
|
def __init__(self, weight=None, size_average=True): |
|
|
|
super(IoULossW, self).__init__() |
|
|
|
|
|
|
|
def forward(self, inputs, targets, smooth=1): |
|
|
|
|
|
|
|
#comment out if your model contains a sigmoid or equivalent activation layer |
|
|
|
pred = torch.sigmoid(inputs) |
|
|
|
weit = 1 + 5*torch.abs(F.avg_pool2d(targets, kernel_size=31, stride=1, padding=15) - targets) |
|
|
|
inter = ((pred * targets)* weit).sum(dim=(2, 3)) |
|
|
|
#print(inter.shape) |
|
|
|
union = ((pred + targets)* weit).sum(dim=(2, 3)) |
|
|
|
wiou = 1 - (inter + 1)/(union - inter+1) |
|
|
|
wiou = wiou.mean()*opt.weight_const |
|
|
|
|
|
|
|
return wiou |