Meta Byte Track

network_blocks.py 6.0KB

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import torch
import torch.nn as nn


class SiLU(nn.Module):
    """export-friendly version of nn.SiLU()"""

    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


def get_activation(name="silu", inplace=True):
    if name == "silu":
        module = nn.SiLU(inplace=inplace)
    elif name == "relu":
        module = nn.ReLU(inplace=inplace)
    elif name == "lrelu":
        module = nn.LeakyReLU(0.1, inplace=inplace)
    else:
        raise AttributeError("Unsupported act type: {}".format(name))
    return module
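
# Usage sketch (illustrative, not part of the original file):
#   act = get_activation("lrelu")  # -> nn.LeakyReLU(0.1, inplace=True)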


class BaseConv(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""

    def __init__(
        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
    ):
        super().__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        # Used after conv+bn fusion, when the BatchNorm has been folded
        # into the conv weights; the bn call is therefore skipped.
        return self.act(self.conv(x))
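
# Shape sketch (assumed numbers, for illustration only): with ksize=3 and
# stride=2, the (ksize - 1) // 2 padding halves the spatial size:
#   conv = BaseConv(3, 64, ksize=3, stride=2)
#   conv(torch.randn(1, 3, 224, 224)).shape  # -> torch.Size([1, 64, 112, 112])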


class DWConv(nn.Module):
    """Depthwise Conv + Conv"""

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        self.dconv = BaseConv(
            in_channels,
            in_channels,
            ksize=ksize,
            stride=stride,
            groups=in_channels,
            act=act,
        )
        self.pconv = BaseConv(
            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
        )

    def forward(self, x):
        x = self.dconv(x)
        return self.pconv(x)
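
# Shape sketch (assumed numbers): the depthwise conv keeps channels fixed
# (groups=in_channels), then the pointwise 1x1 conv mixes them:
#   dw = DWConv(32, 64, ksize=3)
#   dw(torch.randn(1, 32, 56, 56)).shape  # -> torch.Size([1, 64, 56, 56])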


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        Conv = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        if self.use_add:
            y = y + x
        return y
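
# Note (illustrative numbers): the residual add is enabled only when
# shortcut=True and in_channels == out_channels:
#   b = Bottleneck(64, 64)               # use_add is True
#   b(torch.randn(1, 64, 40, 40)).shape  # -> torch.Size([1, 64, 40, 40])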


class ResLayer(nn.Module):
    "Residual layer with `in_channels` inputs."

    def __init__(self, in_channels: int):
        super().__init__()
        mid_channels = in_channels // 2
        self.layer1 = BaseConv(
            in_channels, mid_channels, ksize=1, stride=1, act="lrelu"
        )
        self.layer2 = BaseConv(
            mid_channels, in_channels, ksize=3, stride=1, act="lrelu"
        )

    def forward(self, x):
        out = self.layer2(self.layer1(x))
        return x + out
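
# Sketch (assumed numbers): the 1x1 conv halves channels and the 3x3 conv
# restores them, so the residual add always matches shapes:
#   r = ResLayer(64)
#   r(torch.randn(1, 64, 52, 52)).shape  # -> torch.Size([1, 64, 52, 52])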


class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP"""

    def __init__(
        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
    ):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
        self.m = nn.ModuleList(
            [
                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
                for ks in kernel_sizes
            ]
        )
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
        x = self.conv2(x)
        return x
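
# Shape sketch (assumed numbers): each max pool uses stride 1 and ks // 2
# padding, so spatial size is preserved; concatenating the input with the
# three pooled maps gives hidden_channels * 4 channels before conv2:
#   spp = SPPBottleneck(512, 512)
#   spp(torch.randn(1, 512, 20, 20)).shape  # -> torch.Size([1, 512, 20, 20])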


class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
        """
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden_channels = int(out_channels * expansion)  # hidden channels
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        module_list = [
            Bottleneck(
                hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
            )
            for _ in range(n)
        ]
        self.m = nn.Sequential(*module_list)

    def forward(self, x):
        x_1 = self.conv1(x)
        x_2 = self.conv2(x)
        x_1 = self.m(x_1)
        x = torch.cat((x_1, x_2), dim=1)
        return self.conv3(x)
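
# Shape sketch (assumed numbers): both 1x1 branches carry hidden_channels
# (= out_channels * expansion); only the conv1 branch passes through the
# bottlenecks before concatenation and the conv3 fusion:
#   c3 = CSPLayer(128, 128, n=2)
#   c3(torch.randn(1, 128, 28, 28)).shape  # -> torch.Size([1, 128, 28, 28])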


class Focus(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        # shape of x: (b, c, h, w) -> y: (b, 4c, h/2, w/2)
        patch_top_left = x[..., ::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat(
            (
                patch_top_left,
                patch_bot_left,
                patch_top_right,
                patch_bot_right,
            ),
            dim=1,
        )
        return self.conv(x)
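

# Minimal smoke test, a sketch with assumed shapes and layer sizes (not part
# of the original file): Focus moves each 2x2 spatial neighborhood into
# channels, so 3 input channels become 12 before its conv.
if __name__ == "__main__":
    x = torch.randn(1, 3, 224, 224)
    stem = Focus(3, 32, ksize=3)  # (1, 3, 224, 224) -> (1, 32, 112, 112)
    csp = CSPLayer(32, 64, n=1)   # (1, 32, 112, 112) -> (1, 64, 112, 112)
    spp = SPPBottleneck(64, 64)   # spatial size preserved
    y = spp(csp(stem(x)))
    print(y.shape)                # expected: torch.Size([1, 64, 112, 112])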