Meta Byte Track
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

yolo_head.py 24KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. from loguru import logger
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from yolox.utils import bboxes_iou
  9. import math
  10. from .losses import IOUloss
  11. from .network_blocks import BaseConv, DWConv
  12. class YOLOXHead(nn.Module):
  13. def __init__(
  14. self,
  15. num_classes,
  16. width=1.0,
  17. strides=[8, 16, 32],
  18. in_channels=[256, 512, 1024],
  19. act="silu",
  20. depthwise=False,
  21. ):
  22. """
  23. Args:
  24. act (str): activation type of conv. Defalut value: "silu".
  25. depthwise (bool): wheather apply depthwise conv in conv branch. Defalut value: False.
  26. """
  27. super().__init__()
  28. self.n_anchors = 1
  29. self.num_classes = num_classes
  30. self.decode_in_inference = True # for deploy, set to False
  31. self.cls_convs = nn.ModuleList()
  32. self.reg_convs = nn.ModuleList()
  33. self.cls_preds = nn.ModuleList()
  34. self.reg_preds = nn.ModuleList()
  35. self.obj_preds = nn.ModuleList()
  36. self.stems = nn.ModuleList()
  37. Conv = DWConv if depthwise else BaseConv
  38. for i in range(len(in_channels)):
  39. self.stems.append(
  40. BaseConv(
  41. in_channels=int(in_channels[i] * width),
  42. out_channels=int(256 * width),
  43. ksize=1,
  44. stride=1,
  45. act=act,
  46. )
  47. )
  48. self.cls_convs.append(
  49. nn.Sequential(
  50. *[
  51. Conv(
  52. in_channels=int(256 * width),
  53. out_channels=int(256 * width),
  54. ksize=3,
  55. stride=1,
  56. act=act,
  57. ),
  58. Conv(
  59. in_channels=int(256 * width),
  60. out_channels=int(256 * width),
  61. ksize=3,
  62. stride=1,
  63. act=act,
  64. ),
  65. ]
  66. )
  67. )
  68. self.reg_convs.append(
  69. nn.Sequential(
  70. *[
  71. Conv(
  72. in_channels=int(256 * width),
  73. out_channels=int(256 * width),
  74. ksize=3,
  75. stride=1,
  76. act=act,
  77. ),
  78. Conv(
  79. in_channels=int(256 * width),
  80. out_channels=int(256 * width),
  81. ksize=3,
  82. stride=1,
  83. act=act,
  84. ),
  85. ]
  86. )
  87. )
  88. self.cls_preds.append(
  89. nn.Conv2d(
  90. in_channels=int(256 * width),
  91. out_channels=self.n_anchors * self.num_classes,
  92. kernel_size=1,
  93. stride=1,
  94. padding=0,
  95. )
  96. )
  97. self.reg_preds.append(
  98. nn.Conv2d(
  99. in_channels=int(256 * width),
  100. out_channels=4,
  101. kernel_size=1,
  102. stride=1,
  103. padding=0,
  104. )
  105. )
  106. self.obj_preds.append(
  107. nn.Conv2d(
  108. in_channels=int(256 * width),
  109. out_channels=self.n_anchors * 1,
  110. kernel_size=1,
  111. stride=1,
  112. padding=0,
  113. )
  114. )
  115. self.use_l1 = False
  116. self.l1_loss = nn.L1Loss(reduction="none")
  117. self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
  118. self.iou_loss = IOUloss(reduction="none")
  119. self.strides = strides
  120. self.grids = [torch.zeros(1)] * len(in_channels)
  121. self.expanded_strides = [None] * len(in_channels)
  122. def initialize_biases(self, prior_prob):
  123. for conv in self.cls_preds:
  124. b = conv.bias.view(self.n_anchors, -1)
  125. b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
  126. conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  127. for conv in self.obj_preds:
  128. b = conv.bias.view(self.n_anchors, -1)
  129. b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
  130. conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  131. def forward(self, xin, labels=None, imgs=None):
  132. outputs = []
  133. origin_preds = []
  134. x_shifts = []
  135. y_shifts = []
  136. expanded_strides = []
  137. for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
  138. zip(self.cls_convs, self.reg_convs, self.strides, xin)
  139. ):
  140. x = self.stems[k](x)
  141. cls_x = x
  142. reg_x = x
  143. cls_feat = cls_conv(cls_x)
  144. cls_output = self.cls_preds[k](cls_feat)
  145. reg_feat = reg_conv(reg_x)
  146. reg_output = self.reg_preds[k](reg_feat)
  147. obj_output = self.obj_preds[k](reg_feat)
  148. if self.training:
  149. output = torch.cat([reg_output, obj_output, cls_output], 1)
  150. output, grid = self.get_output_and_grid(
  151. output, k, stride_this_level, xin[0].type()
  152. )
  153. x_shifts.append(grid[:, :, 0])
  154. y_shifts.append(grid[:, :, 1])
  155. expanded_strides.append(
  156. torch.zeros(1, grid.shape[1])
  157. .fill_(stride_this_level)
  158. .type_as(xin[0])
  159. )
  160. if self.use_l1:
  161. batch_size = reg_output.shape[0]
  162. hsize, wsize = reg_output.shape[-2:]
  163. reg_output = reg_output.view(
  164. batch_size, self.n_anchors, 4, hsize, wsize
  165. )
  166. reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
  167. batch_size, -1, 4
  168. )
  169. origin_preds.append(reg_output.clone())
  170. else:
  171. output = torch.cat(
  172. [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
  173. )
  174. # TODO: have to see output shape to know whats going on in head
  175. outputs.append(output)
  176. if self.training:
  177. # logger.info("labels.shape:{}".format(labels.shape))
  178. # logger.info("torch.cat(outputs, 1).shape:{}".format(torch.cat(outputs, 1).shape))
  179. return self.get_losses(
  180. imgs,
  181. x_shifts,
  182. y_shifts,
  183. expanded_strides,
  184. labels,
  185. torch.cat(outputs, 1),
  186. origin_preds,
  187. dtype=xin[0].dtype,
  188. )
  189. else:
  190. self.hw = [x.shape[-2:] for x in outputs]
  191. # [batch, n_anchors_all, 85]
  192. outputs = torch.cat(
  193. [x.flatten(start_dim=2) for x in outputs], dim=2
  194. ).permute(0, 2, 1)
  195. if self.decode_in_inference:
  196. return self.decode_outputs(outputs, dtype=xin[0].type())
  197. else:
  198. return outputs
  199. def get_output_and_grid(self, output, k, stride, dtype):
  200. grid = self.grids[k]
  201. batch_size = output.shape[0]
  202. n_ch = 5 + self.num_classes
  203. hsize, wsize = output.shape[-2:]
  204. if grid.shape[2:4] != output.shape[2:4]:
  205. yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
  206. grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
  207. self.grids[k] = grid
  208. output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
  209. output = output.permute(0, 1, 3, 4, 2).reshape(
  210. batch_size, self.n_anchors * hsize * wsize, -1
  211. )
  212. grid = grid.view(1, -1, 2)
  213. output[..., :2] = (output[..., :2] + grid) * stride
  214. output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
  215. return output, grid
  216. def decode_outputs(self, outputs, dtype):
  217. grids = []
  218. strides = []
  219. for (hsize, wsize), stride in zip(self.hw, self.strides):
  220. yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
  221. grid = torch.stack((xv, yv), 2).view(1, -1, 2)
  222. grids.append(grid)
  223. shape = grid.shape[:2]
  224. strides.append(torch.full((*shape, 1), stride))
  225. grids = torch.cat(grids, dim=1).type(dtype)
  226. strides = torch.cat(strides, dim=1).type(dtype)
  227. outputs[..., :2] = (outputs[..., :2] + grids) * strides
  228. outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
  229. return outputs
  230. def get_losses(
  231. self,
  232. imgs,
  233. x_shifts,
  234. y_shifts,
  235. expanded_strides,
  236. labels,
  237. outputs,
  238. origin_preds,
  239. dtype,
  240. ):
  241. bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
  242. obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
  243. cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
  244. # calculate targets
  245. mixup = labels.shape[2] > 5
  246. if mixup:
  247. label_cut = labels[..., :5]
  248. else:
  249. label_cut = labels
  250. nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1) # number of objects
  251. total_num_anchors = outputs.shape[1]
  252. x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all]
  253. y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all]
  254. expanded_strides = torch.cat(expanded_strides, 1)
  255. if self.use_l1:
  256. origin_preds = torch.cat(origin_preds, 1)
  257. cls_targets = []
  258. reg_targets = []
  259. l1_targets = []
  260. obj_targets = []
  261. fg_masks = []
  262. num_fg = 0.0
  263. num_gts = 0.0
  264. for batch_idx in range(outputs.shape[0]):
  265. num_gt = int(nlabel[batch_idx])
  266. num_gts += num_gt
  267. if num_gt == 0:
  268. cls_target = outputs.new_zeros((0, self.num_classes))
  269. reg_target = outputs.new_zeros((0, 4))
  270. l1_target = outputs.new_zeros((0, 4))
  271. obj_target = outputs.new_zeros((total_num_anchors, 1))
  272. fg_mask = outputs.new_zeros(total_num_anchors).bool()
  273. else:
  274. gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]
  275. gt_classes = labels[batch_idx, :num_gt, 0]
  276. bboxes_preds_per_image = bbox_preds[batch_idx]
  277. try:
  278. (
  279. gt_matched_classes,
  280. fg_mask,
  281. pred_ious_this_matching,
  282. matched_gt_inds,
  283. num_fg_img,
  284. ) = self.get_assignments( # noqa
  285. batch_idx,
  286. num_gt,
  287. total_num_anchors,
  288. gt_bboxes_per_image,
  289. gt_classes,
  290. bboxes_preds_per_image,
  291. expanded_strides,
  292. x_shifts,
  293. y_shifts,
  294. cls_preds,
  295. bbox_preds,
  296. obj_preds,
  297. labels,
  298. imgs,
  299. )
  300. except RuntimeError as e:
  301. logger.info(
  302. "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
  303. CPU mode is applied in this batch. If you want to avoid this issue, \
  304. try to reduce the batch size or image size. " + str(e)
  305. )
  306. print("OOM RuntimeError is raised due to the huge memory cost during label assignment. \
  307. CPU mode is applied in this batch. If you want to avoid this issue, \
  308. try to reduce the batch size or image size. " + str(e))
  309. torch.cuda.empty_cache()
  310. (
  311. gt_matched_classes,
  312. fg_mask,
  313. pred_ious_this_matching,
  314. matched_gt_inds,
  315. num_fg_img,
  316. ) = self.get_assignments( # noqa
  317. batch_idx,
  318. num_gt,
  319. total_num_anchors,
  320. gt_bboxes_per_image,
  321. gt_classes,
  322. bboxes_preds_per_image,
  323. expanded_strides,
  324. x_shifts,
  325. y_shifts,
  326. cls_preds,
  327. bbox_preds,
  328. obj_preds,
  329. labels,
  330. imgs,
  331. "cpu",
  332. )
  333. torch.cuda.empty_cache()
  334. num_fg += num_fg_img
  335. cls_target = F.one_hot(
  336. gt_matched_classes.to(torch.int64), self.num_classes
  337. ) * pred_ious_this_matching.unsqueeze(-1)
  338. obj_target = fg_mask.unsqueeze(-1)
  339. reg_target = gt_bboxes_per_image[matched_gt_inds]
  340. if self.use_l1:
  341. l1_target = self.get_l1_target(
  342. outputs.new_zeros((num_fg_img, 4)),
  343. gt_bboxes_per_image[matched_gt_inds],
  344. expanded_strides[0][fg_mask],
  345. x_shifts=x_shifts[0][fg_mask],
  346. y_shifts=y_shifts[0][fg_mask],
  347. )
  348. cls_targets.append(cls_target)
  349. reg_targets.append(reg_target)
  350. obj_targets.append(obj_target.to(dtype))
  351. fg_masks.append(fg_mask)
  352. if self.use_l1:
  353. l1_targets.append(l1_target)
  354. cls_targets = torch.cat(cls_targets, 0)
  355. reg_targets = torch.cat(reg_targets, 0)
  356. obj_targets = torch.cat(obj_targets, 0)
  357. fg_masks = torch.cat(fg_masks, 0)
  358. if self.use_l1:
  359. l1_targets = torch.cat(l1_targets, 0)
  360. # TODO: check loss parts shapes
  361. num_fg = max(num_fg, 1)
  362. loss_iou = (
  363. self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
  364. ).sum() / num_fg
  365. loss_obj = (
  366. self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
  367. ).sum() / num_fg
  368. loss_cls = (
  369. self.bcewithlog_loss(
  370. cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
  371. )
  372. ).sum() / num_fg
  373. if self.use_l1:
  374. loss_l1 = (
  375. self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
  376. ).sum() / num_fg
  377. else:
  378. loss_l1 = 0.0
  379. reg_weight = 5.0
  380. loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
  381. return (
  382. loss,
  383. reg_weight * loss_iou,
  384. loss_obj,
  385. loss_cls,
  386. loss_l1,
  387. num_fg / max(num_gts, 1),
  388. )
  389. def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
  390. l1_target[:, 0] = gt[:, 0] / stride - x_shifts
  391. l1_target[:, 1] = gt[:, 1] / stride - y_shifts
  392. l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
  393. l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
  394. return l1_target
  395. @torch.no_grad()
  396. def get_assignments(
  397. self,
  398. batch_idx,
  399. num_gt,
  400. total_num_anchors,
  401. gt_bboxes_per_image,
  402. gt_classes,
  403. bboxes_preds_per_image,
  404. expanded_strides,
  405. x_shifts,
  406. y_shifts,
  407. cls_preds,
  408. bbox_preds,
  409. obj_preds,
  410. labels,
  411. imgs,
  412. mode="gpu",
  413. ):
  414. if mode == "cpu":
  415. print("------------CPU Mode for This Batch-------------")
  416. gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
  417. bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
  418. gt_classes = gt_classes.cpu().float()
  419. expanded_strides = expanded_strides.cpu().float()
  420. x_shifts = x_shifts.cpu()
  421. y_shifts = y_shifts.cpu()
  422. img_size = imgs.shape[2:]
  423. fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
  424. gt_bboxes_per_image,
  425. expanded_strides,
  426. x_shifts,
  427. y_shifts,
  428. total_num_anchors,
  429. num_gt,
  430. img_size
  431. )
  432. bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
  433. cls_preds_ = cls_preds[batch_idx][fg_mask]
  434. obj_preds_ = obj_preds[batch_idx][fg_mask]
  435. num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
  436. if mode == "cpu":
  437. gt_bboxes_per_image = gt_bboxes_per_image.cpu()
  438. bboxes_preds_per_image = bboxes_preds_per_image.cpu()
  439. pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
  440. gt_cls_per_image = (
  441. F.one_hot(gt_classes.to(torch.int64), self.num_classes)
  442. .float()
  443. .unsqueeze(1)
  444. .repeat(1, num_in_boxes_anchor, 1)
  445. )
  446. pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
  447. if mode == "cpu":
  448. cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()
  449. with torch.cuda.amp.autocast(enabled=False):
  450. cls_preds_ = (
  451. cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
  452. * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
  453. )
  454. pair_wise_cls_loss = F.binary_cross_entropy(
  455. cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
  456. ).sum(-1)
  457. del cls_preds_
  458. cost = (
  459. pair_wise_cls_loss
  460. + 3.0 * pair_wise_ious_loss
  461. + 100000.0 * (~is_in_boxes_and_center)
  462. )
  463. (
  464. num_fg,
  465. gt_matched_classes,
  466. pred_ious_this_matching,
  467. matched_gt_inds,
  468. ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
  469. del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
  470. if mode == "cpu":
  471. gt_matched_classes = gt_matched_classes.cuda()
  472. fg_mask = fg_mask.cuda()
  473. pred_ious_this_matching = pred_ious_this_matching.cuda()
  474. matched_gt_inds = matched_gt_inds.cuda()
  475. return (
  476. gt_matched_classes,
  477. fg_mask,
  478. pred_ious_this_matching,
  479. matched_gt_inds,
  480. num_fg,
  481. )
  482. def get_in_boxes_info(
  483. self,
  484. gt_bboxes_per_image,
  485. expanded_strides,
  486. x_shifts,
  487. y_shifts,
  488. total_num_anchors,
  489. num_gt,
  490. img_size
  491. ):
  492. expanded_strides_per_image = expanded_strides[0]
  493. x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
  494. y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
  495. x_centers_per_image = (
  496. (x_shifts_per_image + 0.5 * expanded_strides_per_image)
  497. .unsqueeze(0)
  498. .repeat(num_gt, 1)
  499. ) # [n_anchor] -> [n_gt, n_anchor]
  500. y_centers_per_image = (
  501. (y_shifts_per_image + 0.5 * expanded_strides_per_image)
  502. .unsqueeze(0)
  503. .repeat(num_gt, 1)
  504. )
  505. gt_bboxes_per_image_l = (
  506. (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
  507. .unsqueeze(1)
  508. .repeat(1, total_num_anchors)
  509. )
  510. gt_bboxes_per_image_r = (
  511. (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
  512. .unsqueeze(1)
  513. .repeat(1, total_num_anchors)
  514. )
  515. gt_bboxes_per_image_t = (
  516. (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
  517. .unsqueeze(1)
  518. .repeat(1, total_num_anchors)
  519. )
  520. gt_bboxes_per_image_b = (
  521. (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
  522. .unsqueeze(1)
  523. .repeat(1, total_num_anchors)
  524. )
  525. b_l = x_centers_per_image - gt_bboxes_per_image_l
  526. b_r = gt_bboxes_per_image_r - x_centers_per_image
  527. b_t = y_centers_per_image - gt_bboxes_per_image_t
  528. b_b = gt_bboxes_per_image_b - y_centers_per_image
  529. bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
  530. is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
  531. is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
  532. # in fixed center
  533. center_radius = 2.5
  534. # clip center inside image
  535. gt_bboxes_per_image_clip = gt_bboxes_per_image[:, 0:2].clone()
  536. gt_bboxes_per_image_clip[:, 0] = torch.clamp(gt_bboxes_per_image_clip[:, 0], min=0, max=img_size[1])
  537. gt_bboxes_per_image_clip[:, 1] = torch.clamp(gt_bboxes_per_image_clip[:, 1], min=0, max=img_size[0])
  538. gt_bboxes_per_image_l = (gt_bboxes_per_image_clip[:, 0]).unsqueeze(1).repeat(
  539. 1, total_num_anchors
  540. ) - center_radius * expanded_strides_per_image.unsqueeze(0)
  541. gt_bboxes_per_image_r = (gt_bboxes_per_image_clip[:, 0]).unsqueeze(1).repeat(
  542. 1, total_num_anchors
  543. ) + center_radius * expanded_strides_per_image.unsqueeze(0)
  544. gt_bboxes_per_image_t = (gt_bboxes_per_image_clip[:, 1]).unsqueeze(1).repeat(
  545. 1, total_num_anchors
  546. ) - center_radius * expanded_strides_per_image.unsqueeze(0)
  547. gt_bboxes_per_image_b = (gt_bboxes_per_image_clip[:, 1]).unsqueeze(1).repeat(
  548. 1, total_num_anchors
  549. ) + center_radius * expanded_strides_per_image.unsqueeze(0)
  550. c_l = x_centers_per_image - gt_bboxes_per_image_l
  551. c_r = gt_bboxes_per_image_r - x_centers_per_image
  552. c_t = y_centers_per_image - gt_bboxes_per_image_t
  553. c_b = gt_bboxes_per_image_b - y_centers_per_image
  554. center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
  555. is_in_centers = center_deltas.min(dim=-1).values > 0.0
  556. is_in_centers_all = is_in_centers.sum(dim=0) > 0
  557. # in boxes and in centers
  558. is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
  559. is_in_boxes_and_center = (
  560. is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
  561. )
  562. del gt_bboxes_per_image_clip
  563. return is_in_boxes_anchor, is_in_boxes_and_center
  564. def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
  565. # Dynamic K
  566. # ---------------------------------------------------------------
  567. matching_matrix = torch.zeros_like(cost)
  568. ious_in_boxes_matrix = pair_wise_ious
  569. n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
  570. topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
  571. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
  572. for gt_idx in range(num_gt):
  573. _, pos_idx = torch.topk(
  574. cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
  575. )
  576. matching_matrix[gt_idx][pos_idx] = 1.0
  577. del topk_ious, dynamic_ks, pos_idx
  578. anchor_matching_gt = matching_matrix.sum(0)
  579. if (anchor_matching_gt > 1).sum() > 0:
  580. cost_min, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
  581. matching_matrix[:, anchor_matching_gt > 1] *= 0.0
  582. matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
  583. fg_mask_inboxes = matching_matrix.sum(0) > 0.0
  584. num_fg = fg_mask_inboxes.sum().item()
  585. fg_mask[fg_mask.clone()] = fg_mask_inboxes
  586. matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
  587. gt_matched_classes = gt_classes[matched_gt_inds]
  588. pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
  589. fg_mask_inboxes
  590. ]
  591. return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds