Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

demo_utils.py 2.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import numpy as np
  5. import os
  6. __all__ = ["mkdir", "nms", "multiclass_nms", "demo_postprocess"]
  7. def mkdir(path):
  8. if not os.path.exists(path):
  9. os.makedirs(path)
  10. def nms(boxes, scores, nms_thr):
  11. """Single class NMS implemented in Numpy."""
  12. x1 = boxes[:, 0]
  13. y1 = boxes[:, 1]
  14. x2 = boxes[:, 2]
  15. y2 = boxes[:, 3]
  16. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  17. order = scores.argsort()[::-1]
  18. keep = []
  19. while order.size > 0:
  20. i = order[0]
  21. keep.append(i)
  22. xx1 = np.maximum(x1[i], x1[order[1:]])
  23. yy1 = np.maximum(y1[i], y1[order[1:]])
  24. xx2 = np.minimum(x2[i], x2[order[1:]])
  25. yy2 = np.minimum(y2[i], y2[order[1:]])
  26. w = np.maximum(0.0, xx2 - xx1 + 1)
  27. h = np.maximum(0.0, yy2 - yy1 + 1)
  28. inter = w * h
  29. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  30. inds = np.where(ovr <= nms_thr)[0]
  31. order = order[inds + 1]
  32. return keep
  33. def multiclass_nms(boxes, scores, nms_thr, score_thr):
  34. """Multiclass NMS implemented in Numpy"""
  35. final_dets = []
  36. num_classes = scores.shape[1]
  37. for cls_ind in range(num_classes):
  38. cls_scores = scores[:, cls_ind]
  39. valid_score_mask = cls_scores > score_thr
  40. if valid_score_mask.sum() == 0:
  41. continue
  42. else:
  43. valid_scores = cls_scores[valid_score_mask]
  44. valid_boxes = boxes[valid_score_mask]
  45. keep = nms(valid_boxes, valid_scores, nms_thr)
  46. if len(keep) > 0:
  47. cls_inds = np.ones((len(keep), 1)) * cls_ind
  48. dets = np.concatenate(
  49. [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
  50. )
  51. final_dets.append(dets)
  52. if len(final_dets) == 0:
  53. return None
  54. return np.concatenate(final_dets, 0)
  55. def demo_postprocess(outputs, img_size, p6=False):
  56. grids = []
  57. expanded_strides = []
  58. if not p6:
  59. strides = [8, 16, 32]
  60. else:
  61. strides = [8, 16, 32, 64]
  62. hsizes = [img_size[0] // stride for stride in strides]
  63. wsizes = [img_size[1] // stride for stride in strides]
  64. for hsize, wsize, stride in zip(hsizes, wsizes, strides):
  65. xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
  66. grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
  67. grids.append(grid)
  68. shape = grid.shape[:2]
  69. expanded_strides.append(np.full((*shape, 1), stride))
  70. grids = np.concatenate(grids, 1)
  71. expanded_strides = np.concatenate(expanded_strides, 1)
  72. outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
  73. outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
  74. return outputs