|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- #!/usr/bin/env python
- # -*- encoding: utf-8 -*-
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-
- import torch.nn as nn
-
- from .yolo_head import YOLOXHead
- from .yolo_pafpn import YOLOPAFPN
-
-
class YOLOX(nn.Module):
    """
    YOLOX detector: a feature-pyramid backbone followed by a detection head.

    In training mode ``forward`` returns a dict of loss terms computed by the
    head; in eval mode it returns the head's output for the backbone features.

    Args:
        backbone: Module mapping an input batch to FPN features.
            Defaults to ``YOLOPAFPN()``.
        head: Module consuming the backbone features. Defaults to
            ``YOLOXHead(80)`` (the argument is presumably the number of
            classes, 80 as in COCO -- confirm against ``YOLOXHead``).
    """

    def __init__(self, backbone=None, head=None):
        super().__init__()
        # Default modules are only constructed when the caller does not
        # supply its own, so custom backbones/heads can be injected.
        if backbone is None:
            backbone = YOLOPAFPN()
        if head is None:
            head = YOLOXHead(80)

        self.backbone = backbone
        self.head = head

    def forward(self, x, targets=None):
        """
        Run the backbone and the head on ``x``.

        Args:
            x: Input batch. It is fed to the backbone and, in training mode,
                also passed to the head as its third argument.
            targets: Ground-truth targets forwarded to the head. Required
                when the module is in training mode; ignored otherwise.

        Returns:
            Training mode: a dict with keys ``total_loss``, ``iou_loss``,
            ``l1_loss``, ``conf_loss``, ``cls_loss`` and ``num_fg``, built
            from the six values returned by ``head(fpn_outs, targets, x)``.
            Eval mode: ``head(fpn_outs)`` returned unchanged.

        Raises:
            ValueError: if called in training mode with ``targets=None``.
        """
        # FPN features; the original note says these are the
        # [dark3, dark4, dark5] outputs -- TODO confirm against the backbone.
        fpn_outs = self.backbone(x)

        if not self.training:
            return self.head(fpn_outs)

        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if targets is None:
            raise ValueError("targets must be provided when YOLOX is in training mode")

        # The head returns its values in this positional order.
        loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
            fpn_outs, targets, x
        )
        return {
            "total_loss": loss,
            "iou_loss": iou_loss,
            "l1_loss": l1_loss,
            "conf_loss": conf_loss,
            "cls_loss": cls_loss,
            "num_fg": num_fg,
        }
|