123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- import torch.nn as nn
- import torch.nn.functional as F
- import torch
- from torch.autograd import Variable
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from opt import opt
-
- """BCE loss"""
-
-
class BCELoss(nn.Module):
    """Binary cross-entropy loss over per-sample flattened predictions.

    Args:
        weight: optional rescaling weight forwarded to ``nn.BCELoss``.
        size_average: kept for backward compatibility with existing callers;
            ``True`` averages the loss over elements, ``False`` sums it.
    """

    def __init__(self, weight=None, size_average=True):
        super(BCELoss, self).__init__()
        # `size_average` is deprecated in torch.nn.BCELoss; translate it to
        # the modern `reduction` argument so no deprecation warning is raised.
        reduction = 'mean' if size_average else 'sum'
        self.bceloss = nn.BCELoss(weight=weight, reduction=reduction)

    def forward(self, pred, target):
        """Compute BCE between pred and target.

        NOTE(review): `pred` must already be a probability in [0, 1]
        (e.g. after a sigmoid) — nn.BCELoss does not apply one itself.
        """
        size = pred.size(0)
        pred_flat = pred.view(size, -1)
        target_flat = target.view(size, -1)

        return self.bceloss(pred_flat, target_flat)
-
-
- """Dice loss"""
-
-
class DiceLoss(nn.Module):
    """Soft Dice loss, averaged over the batch.

    NOTE(review): no sigmoid is applied here — `pred` is assumed to already
    hold probabilities, unlike IoULoss below which sigmoids its input.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, target):
        # Smoothing constant keeps the ratio finite for empty masks.
        eps = 1
        batch = pred.size(0)

        p = pred.view(batch, -1)
        g = target.view(batch, -1)

        # Per-sample soft Dice score: 2|P∩G| / (|P| + |G|), smoothed.
        overlap = (p * g).sum(1)
        per_sample = (2 * overlap + eps) / (p.sum(1) + g.sum(1) + eps)

        # One minus the batch-mean Dice score.
        return 1 - per_sample.sum() / batch
-
-
- """BCE + DICE Loss + IoULoss"""
-
-
class BceDiceLoss(nn.Module):
    """Combined segmentation loss: focal + dice + IoU, summed unweighted.

    The name is historical — plain BCE was replaced by focal loss, and the
    constructor arguments are kept only for signature compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super(BceDiceLoss, self).__init__()
        #self.bce = BCELoss(weight, size_average)
        self.fl = FocalLoss()
        self.dice = DiceLoss()
        self.iou = IoULoss()

    def forward(self, pred, target):
        # Use the combination of the three losses.
        focal_term = self.fl(pred, target)
        dice_term = self.dice(pred, target)
        iou_term = self.iou(pred, target)

        return focal_term + dice_term + iou_term
-
-
- """ Deep Supervision Loss"""
-
-
def DeepSupervisionLoss(pred, gt):
    """Deep-supervision loss over multi-scale side outputs.

    Args:
        pred: sequence of predictions ordered from full resolution downward;
            each successive output is assumed to be at half the spatial
            resolution of the previous one — TODO confirm against the model.
        gt: ground-truth mask at the resolution of pred[0]; must be a 4-D
            tensor so F.interpolate can downsample it bilinearly.

    Returns:
        Sum of BceDiceLoss over every side output, with gt halved in
        resolution for each successive output.
    """
    criterion = BceDiceLoss()

    # Generalized from the original hard-coded unpack of exactly 5 outputs:
    # accumulate the loss over however many side outputs the model produces.
    total = criterion(pred[0], gt)
    for side_output in pred[1:]:
        gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
        total = total + criterion(side_output, gt)

    return total
-
-
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Binary focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Args:
        weight: optional rescaling weight for the underlying BCE; acts as the
            alpha class-balancing term.
        gamma: focusing parameter; larger values down-weight easy examples.
        reduction: 'mean', 'sum' or 'none'.
    """

    def __init__(self, weight=None, gamma=4, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # alpha parameter to balance class weights

    def forward(self, input, target):
        # BUG FIX: the original computed BCE with self.reduction ('mean'),
        # collapsing it to a scalar *before* deriving pt, so the focal
        # modulation (1 - pt)**gamma was applied to the mean loss instead of
        # per element. Compute everything element-wise, then reduce.
        ce_loss = F.binary_cross_entropy(input, target, reduction='none',
                                         weight=self.weight)
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal
-
-
class IoULoss(nn.Module):
    """Soft Jaccard (IoU) loss on sigmoid-activated logits.

    Constructor arguments are unused; kept for signature compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Remove the sigmoid if the model already ends in an equivalent
        # activation layer.
        probs = torch.sigmoid(inputs).view(-1)
        labels = targets.view(-1)

        # Soft true-positive mass; the union is the inclusive total of both
        # masks minus their overlap.
        tp = (probs * labels).sum()
        union = (probs + labels).sum() - tp

        # Smoothed IoU; `smooth` avoids 0/0 on empty masks.
        jaccard = (tp + smooth) / (union + smooth)

        return 1 - jaccard
-
-
-
-
- # Code from PraNet
-
class IoULossW(nn.Module):
    """Weighted IoU loss adapted from PraNet.

    Pixels whose 31x31 neighbourhood mean differs from their own label
    (i.e. pixels near mask boundaries) receive up to 6x weight. The final
    value is scaled by the project-level constant `opt.weight_const`.
    Expects 4-D (N, C, H, W) tensors; `smooth` is accepted for signature
    compatibility but unused (the literal 1 is used, as in the original).
    """

    def __init__(self, weight=None, size_average=True):
        super(IoULossW, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Remove the sigmoid if the model already outputs probabilities.
        probs = torch.sigmoid(inputs)

        # Boundary-emphasis map: |local mean - target| is large near edges.
        local_mean = F.avg_pool2d(targets, kernel_size=31, stride=1, padding=15)
        weit = 1 + 5 * torch.abs(local_mean - targets)

        inter = (probs * targets * weit).sum(dim=(2, 3))
        union = ((probs + targets) * weit).sum(dim=(2, 3))
        wiou = 1 - (inter + 1) / (union - inter + 1)

        return wiou.mean() * opt.weight_const
|