Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yolox.py 1.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch.nn as nn
  5. from .yolo_head import YOLOXHead
  6. from .yolo_pafpn import YOLOPAFPN
  7. class YOLOX(nn.Module):
  8. """
  9. YOLOX model module. The module list is defined by create_yolov3_modules function.
  10. The network returns loss values from three YOLO layers during training
  11. and detection results during test.
  12. """
  13. def __init__(self, backbone=None, head=None):
  14. super().__init__()
  15. if backbone is None:
  16. backbone = YOLOPAFPN()
  17. if head is None:
  18. head = YOLOXHead(80)
  19. self.backbone = backbone
  20. self.head = head
  21. def forward(self, x, targets=None):
  22. # fpn output content features of [dark3, dark4, dark5]
  23. fpn_outs = self.backbone(x)
  24. # print('fpn_outs', fpn_outs.keys())
  25. if self.training:
  26. assert targets is not None
  27. loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
  28. fpn_outs, targets, x
  29. )
  30. outputs = {
  31. "total_loss": loss,
  32. "iou_loss": iou_loss,
  33. "l1_loss": l1_loss,
  34. "conf_loss": conf_loss,
  35. "cls_loss": cls_loss,
  36. "num_fg": num_fg,
  37. }
  38. else:
  39. outputs = self.head(fpn_outs)
  40. return outputs