You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

loss.py 2.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import torch
  2. from torch import nn
  3. import numpy as np
  4. class loss_fn(torch.nn.Module):
  5. def __init__(self, alpha=0.7, gamma=2.0, epsilon=1e-5):
  6. super(loss_fn, self).__init__()
  7. self.alpha = alpha
  8. self.gamma = gamma
  9. self.epsilon = epsilon
  10. def focal_tversky(self, y_pred, y_true, gamma=0.75):
  11. pt_1 = self.tversky_loss(y_pred, y_true)
  12. return torch.pow((1 - pt_1), gamma)
  13. def dice_loss(self, probs, gt, eps=1):
  14. intersection = (probs * gt).sum(dim=(-2,-1))
  15. dice_coeff = (2.0 * intersection + eps) / (probs.sum(dim=(-2,-1)) + gt.sum(dim=(-2,-1)) + eps)
  16. loss = 1 - dice_coeff.mean()
  17. return loss
  18. def focal_loss(self, probs, gt, gamma=4):
  19. probs = probs.reshape(-1, 1)
  20. gt = gt.reshape(-1, 1)
  21. probs = torch.cat((1 - probs, probs), dim=1)
  22. pt = probs.gather(1, gt.long())
  23. modulating_factor = (1 - pt) ** gamma
  24. # modulating_factor = (3**(10*((1-pt)-0.5)))*(1 - pt) ** gamma
  25. modulating_factor[pt>0.55] = 0.1*modulating_factor[pt>0.55]
  26. focal_loss = -modulating_factor * torch.log(pt + 1e-12)
  27. # Compute the mean focal loss
  28. loss = focal_loss.mean()
  29. return loss # Store as a Python number to save memory
  30. def forward(self, probs, target):
  31. self.gamma=8
  32. dice_loss = self.dice_loss(probs, target)
  33. # tversky_loss = self.tversky_loss(logits, target)
  34. # Focal Loss
  35. focal_loss = self.focal_loss(probs, target,self.gamma)
  36. alpha = 20.0
  37. # Combined Loss
  38. combined_loss = alpha * focal_loss + dice_loss
  39. return combined_loss
  40. def img_enhance(img2, coef=0.2):
  41. img_mean = np.mean(img2)
  42. img_max = np.max(img2)
  43. val = (img_max - img_mean) * coef + img_mean
  44. img2[img2 < img_mean * 0.7] = img_mean * 0.7
  45. img2[img2 > val] = val
  46. return img2
  47. def dice_coefficient(logits, gt):
  48. eps=1
  49. binary_mask = logits>0
  50. # raise ValueError( binary_mask.shape,gt.shape)
  51. intersection = (binary_mask * gt).sum(dim=(-2,-1))
  52. dice_scores = (2.0 * intersection + eps) / (binary_mask.sum(dim=(-2,-1)) + gt.sum(dim=(-2,-1)) + eps)
  53. # raise ValueError(intersection.shape , binary_mask.shape,gt.shape)
  54. return dice_scores.mean()
  55. def calculate_accuracy(pred, target):
  56. correct = (pred == target).sum().item()
  57. total = target.numel()
  58. return correct / total
  59. def calculate_sensitivity(pred, target):
  60. smooth = 1
  61. # Also known as recall
  62. true_positive = ((pred == 1) & (target == 1)).sum().item()
  63. false_negative = ((pred == 0) & (target == 1)).sum().item()
  64. return (true_positive + smooth) / ((true_positive + false_negative) + smooth)
  65. def calculate_specificity(pred, target):
  66. smooth = 1
  67. true_negative = ((pred == 0) & (target == 0)).sum().item()
  68. false_positive = ((pred == 1) & (target == 0)).sum().item()
  69. return (true_negative + smooth) / ((true_negative + false_positive ) + smooth)