
metrics.py

"""
The evaluation implementation refers to the following paper:
"Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation"
https://github.com/Yuqi-cuhk/Polyp-Seg
"""
import torch


def evaluate(pred, gt, th):
    """Binarize `pred` at threshold `th` and compute segmentation metrics against `gt`."""
    if isinstance(pred, (list, tuple)):
        pred = pred[0]

    pred_binary = (pred >= th).float()
    pred_binary_inverse = (pred_binary == 0).float()

    gt_binary = (gt >= th).float()
    gt_binary_inverse = (gt_binary == 0).float()

    TP = pred_binary.mul(gt_binary).sum()
    FP = pred_binary.mul(gt_binary_inverse).sum()
    TN = pred_binary_inverse.mul(gt_binary_inverse).sum()
    FN = pred_binary_inverse.mul(gt_binary).sum()

    if TP.item() == 0:
        # Avoid division by zero when there are no true positives.
        TP = torch.tensor(1.0, device=pred_binary.device)

    # Recall (sensitivity, true positive rate)
    Recall = TP / (TP + FN)
    # Specificity (true negative rate)
    Specificity = TN / (TN + FP)
    # Precision (positive predictive value)
    Precision = TP / (TP + FP)
    # F1 score (equivalent to Dice)
    F1 = 2 * Precision * Recall / (Precision + Recall)
    # F2 score (weights recall higher than precision)
    F2 = 5 * Precision * Recall / (4 * Precision + Recall)
    # Overall accuracy
    ACC_overall = (TP + TN) / (TP + FP + FN + TN)
    # IoU for the polyp (foreground) class
    IoU_poly = TP / (TP + FP + FN)
    # IoU for the background class
    IoU_bg = TN / (TN + FP + FN)
    # Mean IoU over the two classes
    IoU_mean = (IoU_poly + IoU_bg) / 2.0
    # Dice coefficient
    Dice = (2 * TP) / (2 * TP + FN + FP)

    return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, Dice


class Metrics(object):
    """Accumulates metric values over batches and reports their means."""

    def __init__(self, metrics_list):
        self.metrics = {}
        for metric in metrics_list:
            self.metrics[metric] = 0

    def update(self, **kwargs):
        for k, v in kwargs.items():
            assert k in self.metrics.keys(), "The key {} is not in metrics".format(k)
            if isinstance(v, torch.Tensor):
                v = v.item()
            self.metrics[k] += v

    def mean(self, total):
        mean_metrics = {}
        for k, v in self.metrics.items():
            mean_metrics[k] = v / total
        return mean_metrics
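
A minimal usage sketch, not part of the original file: it assumes `pred` and `gt` are same-shaped tensors with values in [0, 1], that the module is importable as `metrics`, and that the metric names passed to `Metrics` match the keyword names used in `update`. The tensor shapes and the single-batch loop are illustrative only.

# Hypothetical usage of evaluate() and Metrics, assuming `metrics.py` is on the import path.
import torch
from metrics import evaluate, Metrics

# Dummy data standing in for a model prediction and a binary ground-truth mask.
pred = torch.rand(1, 1, 352, 352)
gt = (torch.rand(1, 1, 352, 352) > 0.5).float()

tracker = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2',
                   'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean', 'Dice'])

recall, specificity, precision, f1, f2, acc, iou_poly, iou_bg, iou_mean, dice = \
    evaluate(pred, gt, th=0.5)

tracker.update(recall=recall, specificity=specificity, precision=precision,
               F1=f1, F2=f2, ACC_overall=acc, IoU_poly=iou_poly,
               IoU_bg=iou_bg, IoU_mean=iou_mean, Dice=dice)

# After looping over all batches, divide the accumulated sums by the batch count.
print(tracker.mean(total=1))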