csp_darknet.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle import ParamAttr
  18. from paddle.regularizer import L2Decay
  19. from ppdet.core.workspace import register, serializable
  20. from ppdet.modeling.initializer import conv_init_
  21. from ..shape_spec import ShapeSpec
  22. __all__ = [
  23. 'CSPDarkNet', 'BaseConv', 'DWConv', 'BottleNeck', 'SPPLayer', 'SPPFLayer'
  24. ]
  25. class BaseConv(nn.Layer):
  26. def __init__(self,
  27. in_channels,
  28. out_channels,
  29. ksize,
  30. stride,
  31. groups=1,
  32. bias=False,
  33. act="silu"):
  34. super(BaseConv, self).__init__()
  35. self.conv = nn.Conv2D(
  36. in_channels,
  37. out_channels,
  38. kernel_size=ksize,
  39. stride=stride,
  40. padding=(ksize - 1) // 2,
  41. groups=groups,
  42. bias_attr=bias)
  43. self.bn = nn.BatchNorm2D(
  44. out_channels,
  45. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  46. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  47. self._init_weights()
  48. def _init_weights(self):
  49. conv_init_(self.conv)
  50. def forward(self, x):
  51. # use 'x * F.sigmoid(x)' replace 'silu'
  52. x = self.bn(self.conv(x))
  53. y = x * F.sigmoid(x)
  54. return y
  55. class DWConv(nn.Layer):
  56. """Depthwise Conv"""
  57. def __init__(self,
  58. in_channels,
  59. out_channels,
  60. ksize,
  61. stride=1,
  62. bias=False,
  63. act="silu"):
  64. super(DWConv, self).__init__()
  65. self.dw_conv = BaseConv(
  66. in_channels,
  67. in_channels,
  68. ksize=ksize,
  69. stride=stride,
  70. groups=in_channels,
  71. bias=bias,
  72. act=act)
  73. self.pw_conv = BaseConv(
  74. in_channels,
  75. out_channels,
  76. ksize=1,
  77. stride=1,
  78. groups=1,
  79. bias=bias,
  80. act=act)
  81. def forward(self, x):
  82. return self.pw_conv(self.dw_conv(x))
  83. class Focus(nn.Layer):
  84. """Focus width and height information into channel space, used in YOLOX."""
  85. def __init__(self,
  86. in_channels,
  87. out_channels,
  88. ksize=3,
  89. stride=1,
  90. bias=False,
  91. act="silu"):
  92. super(Focus, self).__init__()
  93. self.conv = BaseConv(
  94. in_channels * 4,
  95. out_channels,
  96. ksize=ksize,
  97. stride=stride,
  98. bias=bias,
  99. act=act)
  100. def forward(self, inputs):
  101. # inputs [bs, C, H, W] -> outputs [bs, 4C, W/2, H/2]
  102. top_left = inputs[:, :, 0::2, 0::2]
  103. top_right = inputs[:, :, 0::2, 1::2]
  104. bottom_left = inputs[:, :, 1::2, 0::2]
  105. bottom_right = inputs[:, :, 1::2, 1::2]
  106. outputs = paddle.concat(
  107. [top_left, bottom_left, top_right, bottom_right], 1)
  108. return self.conv(outputs)
  109. class BottleNeck(nn.Layer):
  110. def __init__(self,
  111. in_channels,
  112. out_channels,
  113. shortcut=True,
  114. expansion=0.5,
  115. depthwise=False,
  116. bias=False,
  117. act="silu"):
  118. super(BottleNeck, self).__init__()
  119. hidden_channels = int(out_channels * expansion)
  120. Conv = DWConv if depthwise else BaseConv
  121. self.conv1 = BaseConv(
  122. in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
  123. self.conv2 = Conv(
  124. hidden_channels,
  125. out_channels,
  126. ksize=3,
  127. stride=1,
  128. bias=bias,
  129. act=act)
  130. self.add_shortcut = shortcut and in_channels == out_channels
  131. def forward(self, x):
  132. y = self.conv2(self.conv1(x))
  133. if self.add_shortcut:
  134. y = y + x
  135. return y
  136. class SPPLayer(nn.Layer):
  137. """Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX"""
  138. def __init__(self,
  139. in_channels,
  140. out_channels,
  141. kernel_sizes=(5, 9, 13),
  142. bias=False,
  143. act="silu"):
  144. super(SPPLayer, self).__init__()
  145. hidden_channels = in_channels // 2
  146. self.conv1 = BaseConv(
  147. in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
  148. self.maxpoolings = nn.LayerList([
  149. nn.MaxPool2D(
  150. kernel_size=ks, stride=1, padding=ks // 2)
  151. for ks in kernel_sizes
  152. ])
  153. conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
  154. self.conv2 = BaseConv(
  155. conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
  156. def forward(self, x):
  157. x = self.conv1(x)
  158. x = paddle.concat([x] + [mp(x) for mp in self.maxpoolings], axis=1)
  159. x = self.conv2(x)
  160. return x
  161. class SPPFLayer(nn.Layer):
  162. """ Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher,
  163. equivalent to SPP(k=(5, 9, 13))
  164. """
  165. def __init__(self,
  166. in_channels,
  167. out_channels,
  168. ksize=5,
  169. bias=False,
  170. act='silu'):
  171. super(SPPFLayer, self).__init__()
  172. hidden_channels = in_channels // 2
  173. self.conv1 = BaseConv(
  174. in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
  175. self.maxpooling = nn.MaxPool2D(
  176. kernel_size=ksize, stride=1, padding=ksize // 2)
  177. conv2_channels = hidden_channels * 4
  178. self.conv2 = BaseConv(
  179. conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
  180. def forward(self, x):
  181. x = self.conv1(x)
  182. y1 = self.maxpooling(x)
  183. y2 = self.maxpooling(y1)
  184. y3 = self.maxpooling(y2)
  185. concats = paddle.concat([x, y1, y2, y3], axis=1)
  186. out = self.conv2(concats)
  187. return out
  188. class CSPLayer(nn.Layer):
  189. """CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5"""
  190. def __init__(self,
  191. in_channels,
  192. out_channels,
  193. num_blocks=1,
  194. shortcut=True,
  195. expansion=0.5,
  196. depthwise=False,
  197. bias=False,
  198. act="silu"):
  199. super(CSPLayer, self).__init__()
  200. hidden_channels = int(out_channels * expansion)
  201. self.conv1 = BaseConv(
  202. in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
  203. self.conv2 = BaseConv(
  204. in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
  205. self.bottlenecks = nn.Sequential(* [
  206. BottleNeck(
  207. hidden_channels,
  208. hidden_channels,
  209. shortcut=shortcut,
  210. expansion=1.0,
  211. depthwise=depthwise,
  212. bias=bias,
  213. act=act) for _ in range(num_blocks)
  214. ])
  215. self.conv3 = BaseConv(
  216. hidden_channels * 2,
  217. out_channels,
  218. ksize=1,
  219. stride=1,
  220. bias=bias,
  221. act=act)
  222. def forward(self, x):
  223. x_1 = self.conv1(x)
  224. x_1 = self.bottlenecks(x_1)
  225. x_2 = self.conv2(x)
  226. x = paddle.concat([x_1, x_2], axis=1)
  227. x = self.conv3(x)
  228. return x
  229. @register
  230. @serializable
  231. class CSPDarkNet(nn.Layer):
  232. """
  233. CSPDarkNet backbone.
  234. Args:
  235. arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X,
  236. and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5.
  237. depth_mult (float): Depth multiplier, multiply number of channels in
  238. each layer, default as 1.0.
  239. width_mult (float): Width multiplier, multiply number of blocks in
  240. CSPLayer, default as 1.0.
  241. depthwise (bool): Whether to use depth-wise conv layer.
  242. act (str): Activation function type, default as 'silu'.
  243. return_idx (list): Index of stages whose feature maps are returned.
  244. """
  245. __shared__ = ['depth_mult', 'width_mult', 'act', 'trt']
  246. # in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
  247. # 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
  248. arch_settings = {
  249. 'X': [[64, 128, 3, True, False], [128, 256, 9, True, False],
  250. [256, 512, 9, True, False], [512, 1024, 3, False, True]],
  251. 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
  252. [256, 512, 9, True, False], [512, 1024, 3, True, True]],
  253. 'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
  254. [256, 512, 9, True, False], [512, 768, 3, True, False],
  255. [768, 1024, 3, True, True]],
  256. }
  257. def __init__(self,
  258. arch='X',
  259. depth_mult=1.0,
  260. width_mult=1.0,
  261. depthwise=False,
  262. act='silu',
  263. trt=False,
  264. return_idx=[2, 3, 4]):
  265. super(CSPDarkNet, self).__init__()
  266. self.arch = arch
  267. self.return_idx = return_idx
  268. Conv = DWConv if depthwise else BaseConv
  269. arch_setting = self.arch_settings[arch]
  270. base_channels = int(arch_setting[0][0] * width_mult)
  271. # Note: differences between the latest YOLOv5 and the original YOLOX
  272. # 1. self.stem, use SPPF(in YOLOv5) or SPP(in YOLOX)
  273. # 2. use SPPF(in YOLOv5) or SPP(in YOLOX)
  274. # 3. put SPPF before(YOLOv5) or SPP after(YOLOX) the last cspdark block's CSPLayer
  275. # 4. whether SPPF(SPP)'CSPLayer add shortcut, True in YOLOv5, False in YOLOX
  276. if arch in ['P5', 'P6']:
  277. # in the latest YOLOv5, use Conv stem, and SPPF (fast, only single spp kernal size)
  278. self.stem = Conv(
  279. 3, base_channels, ksize=6, stride=2, bias=False, act=act)
  280. spp_kernal_sizes = 5
  281. elif arch in ['X']:
  282. # in the original YOLOX, use Focus stem, and SPP (three spp kernal sizes)
  283. self.stem = Focus(
  284. 3, base_channels, ksize=3, stride=1, bias=False, act=act)
  285. spp_kernal_sizes = (5, 9, 13)
  286. else:
  287. raise AttributeError("Unsupported arch type: {}".format(arch))
  288. _out_channels = [base_channels]
  289. layers_num = 1
  290. self.csp_dark_blocks = []
  291. for i, (in_channels, out_channels, num_blocks, shortcut,
  292. use_spp) in enumerate(arch_setting):
  293. in_channels = int(in_channels * width_mult)
  294. out_channels = int(out_channels * width_mult)
  295. _out_channels.append(out_channels)
  296. num_blocks = max(round(num_blocks * depth_mult), 1)
  297. stage = []
  298. conv_layer = self.add_sublayer(
  299. 'layers{}.stage{}.conv_layer'.format(layers_num, i + 1),
  300. Conv(
  301. in_channels, out_channels, 3, 2, bias=False, act=act))
  302. stage.append(conv_layer)
  303. layers_num += 1
  304. if use_spp and arch in ['X']:
  305. # in YOLOX use SPPLayer
  306. spp_layer = self.add_sublayer(
  307. 'layers{}.stage{}.spp_layer'.format(layers_num, i + 1),
  308. SPPLayer(
  309. out_channels,
  310. out_channels,
  311. kernel_sizes=spp_kernal_sizes,
  312. bias=False,
  313. act=act))
  314. stage.append(spp_layer)
  315. layers_num += 1
  316. csp_layer = self.add_sublayer(
  317. 'layers{}.stage{}.csp_layer'.format(layers_num, i + 1),
  318. CSPLayer(
  319. out_channels,
  320. out_channels,
  321. num_blocks=num_blocks,
  322. shortcut=shortcut,
  323. depthwise=depthwise,
  324. bias=False,
  325. act=act))
  326. stage.append(csp_layer)
  327. layers_num += 1
  328. if use_spp and arch in ['P5', 'P6']:
  329. # in latest YOLOv5 use SPPFLayer instead of SPPLayer
  330. sppf_layer = self.add_sublayer(
  331. 'layers{}.stage{}.sppf_layer'.format(layers_num, i + 1),
  332. SPPFLayer(
  333. out_channels,
  334. out_channels,
  335. ksize=5,
  336. bias=False,
  337. act=act))
  338. stage.append(sppf_layer)
  339. layers_num += 1
  340. self.csp_dark_blocks.append(nn.Sequential(*stage))
  341. self._out_channels = [_out_channels[i] for i in self.return_idx]
  342. self.strides = [[2, 4, 8, 16, 32, 64][i] for i in self.return_idx]
  343. def forward(self, inputs):
  344. x = inputs['image']
  345. outputs = []
  346. x = self.stem(x)
  347. for i, layer in enumerate(self.csp_dark_blocks):
  348. x = layer(x)
  349. if i + 1 in self.return_idx:
  350. outputs.append(x)
  351. return outputs
  352. @property
  353. def out_shape(self):
  354. return [
  355. ShapeSpec(
  356. channels=c, stride=s)
  357. for c, s in zip(self._out_channels, self.strides)
  358. ]