Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yolo_fpn.py 2.4KB

Lines 1–84
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch
  5. import torch.nn as nn
  6. from .darknet import Darknet
  7. from .network_blocks import BaseConv
  class YOLOFPN(nn.Module):
      """
      YOLOFPN module. Darknet 53 is the default backbone of this model.

      A YOLOv3-style top-down neck on top of a ``Darknet`` backbone. The
      deepest backbone map is reduced, upsampled and fused with the next
      shallower map. The result is reduced, upsampled and fused again with
      the shallowest selected map.
      """

      def __init__(
          self,
          depth=53,
          in_features=("dark3", "dark4", "dark5"),
      ):
          """
          Args:
              depth (int): depth forwarded to ``Darknet``. Defaults to 53.
              in_features (Sequence[str]): keys of the three backbone outputs
                  to use, ordered shallow -> deep (they are unpacked as
                  ``x2, x1, x0`` in ``forward``). Copied into a new list.

          Raises:
              ValueError: if ``in_features`` does not name exactly three
                  feature maps.
          """
          super().__init__()
          # Copy into a fresh list. The default used to be a shared mutable
          # list, so mutating one model's ``in_features`` would silently
          # change the default seen by every instance created afterwards.
          in_features = list(in_features)
          if len(in_features) != 3:
              # forward() unpacks exactly three maps; fail early with a
              # clear message instead of an unpacking error at runtime.
              raise ValueError(
                  "in_features must name exactly 3 backbone outputs, "
                  f"got {len(in_features)}: {in_features!r}"
              )
          self.backbone = Darknet(depth)
          self.in_features = in_features

          # out 1: 1x1 conv 512 -> 256 on the deepest map. After upsampling
          # it is concatenated with the middle map. The embedding's input of
          # 512 + 256 channels implies the middle map has 512 channels.
          self.out1_cbl = self._make_cbl(512, 256, 1)
          self.out1 = self._make_embedding([256, 512], 512 + 256)

          # out 2: 1x1 conv 256 -> 128 on out1's result. After upsampling it
          # is concatenated with the shallowest map. The embedding's input of
          # 256 + 128 channels implies the shallowest map has 256 channels.
          self.out2_cbl = self._make_cbl(256, 128, 1)
          self.out2 = self._make_embedding([128, 256], 256 + 128)

          # Shared 2x nearest-neighbour upsampling used by both branches.
          self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

      def _make_cbl(self, _in, _out, ks):
          """Return a stride-1 ``BaseConv`` (``_in`` -> ``_out``, kernel
          ``ks``) with leaky-ReLU activation."""
          return BaseConv(_in, _out, ks, stride=1, act="lrelu")

      def _make_embedding(self, filters_list, in_filters):
          """
          Build the 5-layer conv stack used after each concatenation.

          The layers alternate 1x1 and 3x3 convolutions:
          in_filters -> f0 (1x1) -> f1 (3x3) -> f0 (1x1) -> f1 (3x3) -> f0 (1x1),
          where ``f0, f1 = filters_list[0], filters_list[1]``. The stack
          therefore outputs ``filters_list[0]`` channels.
          """
          return nn.Sequential(
              self._make_cbl(in_filters, filters_list[0], 1),
              self._make_cbl(filters_list[0], filters_list[1], 3),
              self._make_cbl(filters_list[1], filters_list[0], 1),
              self._make_cbl(filters_list[0], filters_list[1], 3),
              self._make_cbl(filters_list[1], filters_list[0], 1),
          )

      def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
          """
          Load a state dict from ``filename`` into the backbone only.

          The tensors are mapped to CPU on load. The neck layers defined in
          this class are left untouched.
          """
          # NOTE(security): torch.load unpickles the file, which can execute
          # arbitrary code. Only load checkpoints from trusted sources.
          with open(filename, "rb") as f:
              state_dict = torch.load(f, map_location="cpu")
          print("loading pretrained weights...")
          self.backbone.load_state_dict(state_dict)

      def forward(self, inputs):
          """
          Args:
              inputs (Tensor): input image.

          Returns:
              Tuple[Tensor]: FPN output features ``(out_dark3, out_dark4, x0)``,
                  ordered shallowest first. ``x0`` is the backbone's deepest
                  selected map, returned unchanged.
          """
          # backbone
          out_features = self.backbone(inputs)
          x2, x1, x0 = [out_features[f] for f in self.in_features]

          # yolo branch 1: reduce deepest map, upsample 2x, fuse with x1.
          # Concatenation along dim 1 presumably joins channels (NCHW);
          # TODO confirm against the backbone's output layout.
          x1_in = self.out1_cbl(x0)
          x1_in = self.upsample(x1_in)
          x1_in = torch.cat([x1_in, x1], 1)
          out_dark4 = self.out1(x1_in)

          # yolo branch 2: reduce branch-1 output, upsample 2x, fuse with x2.
          x2_in = self.out2_cbl(out_dark4)
          x2_in = self.upsample(x2_in)
          x2_in = torch.cat([x2_in, x2], 1)
          out_dark3 = self.out2(x2_in)

          return (out_dark3, out_dark4, x0)