Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

yolo_pafpn.py 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch
  5. import torch.nn as nn
  6. from .darknet import CSPDarknet
  7. from .network_blocks import BaseConv, CSPLayer, DWConv
  8. class YOLOPAFPN(nn.Module):
  9. """
  10. YOLOv3 model. Darknet 53 is the default backbone of this model.
  11. """
  12. def __init__(
  13. self,
  14. depth=1.0,
  15. width=1.0,
  16. in_features=("dark3", "dark4", "dark5"),
  17. in_channels=[256, 512, 1024],
  18. depthwise=False,
  19. act="silu",
  20. ):
  21. super().__init__()
  22. self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
  23. self.in_features = in_features
  24. self.in_channels = in_channels
  25. Conv = DWConv if depthwise else BaseConv
  26. self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
  27. self.lateral_conv0 = BaseConv(
  28. int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
  29. )
  30. self.C3_p4 = CSPLayer(
  31. int(2 * in_channels[1] * width),
  32. int(in_channels[1] * width),
  33. round(3 * depth),
  34. False,
  35. depthwise=depthwise,
  36. act=act,
  37. ) # cat
  38. self.reduce_conv1 = BaseConv(
  39. int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
  40. )
  41. self.C3_p3 = CSPLayer(
  42. int(2 * in_channels[0] * width),
  43. int(in_channels[0] * width),
  44. round(3 * depth),
  45. False,
  46. depthwise=depthwise,
  47. act=act,
  48. )
  49. # bottom-up conv
  50. self.bu_conv2 = Conv(
  51. int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
  52. )
  53. self.C3_n3 = CSPLayer(
  54. int(2 * in_channels[0] * width),
  55. int(in_channels[1] * width),
  56. round(3 * depth),
  57. False,
  58. depthwise=depthwise,
  59. act=act,
  60. )
  61. # bottom-up conv
  62. self.bu_conv1 = Conv(
  63. int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
  64. )
  65. self.C3_n4 = CSPLayer(
  66. int(2 * in_channels[1] * width),
  67. int(in_channels[2] * width),
  68. round(3 * depth),
  69. False,
  70. depthwise=depthwise,
  71. act=act,
  72. )
  73. def forward(self, input):
  74. """
  75. Args:
  76. inputs: input images.
  77. Returns:
  78. Tuple[Tensor]: FPN feature.
  79. """
  80. # backbone
  81. out_features = self.backbone(input)
  82. features = [out_features[f] for f in self.in_features]
  83. [x2, x1, x0] = features
  84. fpn_out0 = self.lateral_conv0(x0) # 1024->512/32
  85. f_out0 = self.upsample(fpn_out0) # 512/16
  86. f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16
  87. f_out0 = self.C3_p4(f_out0) # 1024->512/16
  88. fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16
  89. f_out1 = self.upsample(fpn_out1) # 256/8
  90. f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8
  91. pan_out2 = self.C3_p3(f_out1) # 512->256/8
  92. p_out1 = self.bu_conv2(pan_out2) # 256->256/16
  93. p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16
  94. pan_out1 = self.C3_n3(p_out1) # 512->512/16
  95. p_out0 = self.bu_conv1(pan_out1) # 512->512/32
  96. p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32
  97. pan_out0 = self.C3_n4(p_out0) # 1024->1024/32
  98. outputs = (pan_out2, pan_out1, pan_out0)
  99. return outputs