| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- import torch
- from torch import nn
- import numpy as np
-
-
-
- class loss_fn(torch.nn.Module):
- def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
- super(loss_fn, self).__init__()
- self.alpha = alpha
- self.gamma = gamma
- self.epsilon = epsilon
-
-
- def focal_tversky(self, y_pred, y_true, gamma=0.75):
- pt_1 = self.tversky_loss(y_pred, y_true)
- return torch.pow((1 - pt_1), gamma)
- def dice_loss(self, probs, gt, eps=1):
-
- intersection = (probs * gt).sum(dim=(-2,-1))
- dice_coeff = (2.0 * intersection + eps) / (probs.sum(dim=(-2,-1)) + gt.sum(dim=(-2,-1)) + eps)
- loss = 1 - dice_coeff.mean()
- return loss
-
- def focal_loss(self, probs, gt, gamma=4):
- probs = probs.reshape(-1, 1)
- gt = gt.reshape(-1, 1)
- probs = torch.cat((1 - probs, probs), dim=1)
-
- pt = probs.gather(1, gt.long())
- modulating_factor = (1 - pt) ** gamma
- # modulating_factor = (3**(10*((1-pt)-0.5)))*(1 - pt) ** gamma
- modulating_factor[pt>0.55] = 0.1*modulating_factor[pt>0.55]
-
- focal_loss = -modulating_factor * torch.log(pt + 1e-12)
-
- # Compute the mean focal loss
- loss = focal_loss.mean()
- return loss # Store as a Python number to save memory
-
- def forward(self, probs, target):
-
- self.gamma=8
- dice_loss = self.dice_loss(probs, target)
- # tversky_loss = self.tversky_loss(logits, target)
-
- # Focal Loss
- focal_loss = self.focal_loss(probs, target,self.gamma)
- alpha = 20.0
- # Combined Loss
- combined_loss = alpha * focal_loss + dice_loss
- return combined_loss
-
- def img_enhance(img2, coef=0.2):
- img_mean = np.mean(img2)
- img_max = np.max(img2)
- val = (img_max - img_mean) * coef + img_mean
- img2[img2 < img_mean * 0.7] = img_mean * 0.7
- img2[img2 > val] = val
- return img2
-
-
-
- def dice_coefficient(logits, gt):
- eps=1
- binary_mask = logits>0
- # raise ValueError( binary_mask.shape,gt.shape)
- intersection = (binary_mask * gt).sum(dim=(-2,-1))
- dice_scores = (2.0 * intersection + eps) / (binary_mask.sum(dim=(-2,-1)) + gt.sum(dim=(-2,-1)) + eps)
- # raise ValueError(intersection.shape , binary_mask.shape,gt.shape)
-
- return dice_scores.mean()
-
- def calculate_accuracy(pred, target):
- correct = (pred == target).sum().item()
- total = target.numel()
- return correct / total
-
- def calculate_sensitivity(pred, target):
- smooth = 1
- # Also known as recall
- true_positive = ((pred == 1) & (target == 1)).sum().item()
- false_negative = ((pred == 0) & (target == 1)).sum().item()
- return (true_positive + smooth) / ((true_positive + false_negative) + smooth)
-
- def calculate_specificity(pred, target):
- smooth = 1
- true_negative = ((pred == 0) & (target == 0)).sum().item()
- false_positive = ((pred == 1) & (target == 0)).sum().item()
- return (true_negative + smooth) / ((true_negative + false_positive ) + smooth)
|