blazenet.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal

from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['BlazeNet']


def hard_swish(x):
    return x * F.relu6(x + 3) / 6.
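# For illustration: hard_swish(paddle.to_tensor([-4., 0., 4.])) evaluates to
# [0., 0., 4.] -- a piecewise-linear (hard) approximation of the swish
# activation.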


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr, initializer=KaimingNormal()),
            bias_attr=False)
        if norm_type in ['bn', 'sync_bn']:
            self._batch_norm = nn.BatchNorm2D(out_channels)

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        elif self.act == 'leaky':
            x = F.leaky_relu(x)
        elif self.act == 'hard_swish':
            x = hard_swish(x)
        return x
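
# Usage sketch (shapes assume NCHW input, as elsewhere in PaddleDetection):
#
#   conv = ConvBNLayer(in_channels=3, out_channels=24, kernel_size=3,
#                      stride=2, padding=1)
#   y = conv(paddle.rand([1, 3, 128, 128]))  # -> [1, 24, 64, 64], ReLU applied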


class BlazeBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels1,
                 out_channels2,
                 double_channels=None,
                 stride=1,
                 use_5x5kernel=True,
                 act='relu',
                 name=None):
        super(BlazeBlock, self).__init__()
        assert stride in [1, 2]
        self.use_pool = stride != 1
        self.use_double_block = double_channels is not None
        self.conv_dw = []
        if use_5x5kernel:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=5,
                        stride=stride,
                        padding=2,
                        num_groups=out_channels1,
                        name=name + "1_dw")))
        else:
            # Two stacked 3x3 depthwise convs emulate the 5x5 receptive field.
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_1",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_1")))
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_2",
                    ConvBNLayer(
                        in_channels=out_channels1,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=stride,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_2")))
        self.act = act if self.use_double_block else None
        self.conv_pw = ConvBNLayer(
            in_channels=out_channels1,
            out_channels=out_channels2,
            kernel_size=1,
            stride=1,
            padding=0,
            act=self.act,
            name=name + "1_sep")
        if self.use_double_block:
            self.conv_dw2 = []
            if use_5x5kernel:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=5,
                            stride=1,
                            padding=2,
                            num_groups=out_channels2,
                            name=name + "2_dw")))
            else:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_1",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_1")))
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_2",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_2")))
            self.conv_pw2 = ConvBNLayer(
                in_channels=out_channels2,
                out_channels=double_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                name=name + "2_sep")
        # shortcut: match the main path's spatial size and channel count so
        # the residual addition in forward() is well-defined
        if self.use_pool:
            shortcut_channel = double_channels or out_channels2
            self._shortcut = []
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_pool',
                    nn.MaxPool2D(
                        kernel_size=stride, stride=stride, ceil_mode=True)))
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_conv',
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=shortcut_channel,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        name="shortcut" + name)))

    def forward(self, x):
        y = x
        for conv_dw_block in self.conv_dw:
            y = conv_dw_block(y)
        y = self.conv_pw(y)
        if self.use_double_block:
            for conv_dw2_block in self.conv_dw2:
                y = conv_dw2_block(y)
            y = self.conv_pw2(y)
        if self.use_pool:
            for shortcut in self._shortcut:
                x = shortcut(x)
        return F.relu(paddle.add(x, y))
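
# Usage sketch: a stride-2 "double" BlazeBlock downsamples 2x and expands
# channels to double_channels; the pooled, 1x1-projected shortcut keeps the
# residual addition shape-compatible:
#
#   block = BlazeBlock(48, 48, 24, double_channels=96, stride=2, name='demo')
#   y = block(paddle.rand([1, 48, 32, 32]))  # -> [1, 96, 16, 16]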


@register
@serializable
class BlazeNet(nn.Layer):
    """
    BlazeFace, see https://arxiv.org/abs/1907.05047

    Args:
        blaze_filters (list): number of filters for each BlazeBlock.
        double_blaze_filters (list): number of filters for each double BlazeBlock.
        use_5x5kernel (bool): whether the depthwise convs use a 5x5 kernel.
    """

    def __init__(
            self,
            blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]],
            double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                                  [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]],
            use_5x5kernel=True,
            act=None):
        super(BlazeNet, self).__init__()
        conv1_num_filters = blaze_filters[0][0]
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=conv1_num_filters,
            kernel_size=3,
            stride=2,
            padding=1,
            name="conv1")
        in_channels = conv1_num_filters
        self.blaze_block = []
        self._out_channels = []
        for k, v in enumerate(blaze_filters):
            assert len(v) in [2, 3], \
                "blaze_filters {} not in [2, 3]".format(v)
            if len(v) == 2:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            elif len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            stride=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='blaze_{}'.format(k))))
            in_channels = v[1]
        for k, v in enumerate(double_blaze_filters):
            assert len(v) in [3, 4], \
                "double_blaze_filters {} not in [3, 4]".format(v)
            if len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            elif len(v) == 4:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            stride=v[3],
                            use_5x5kernel=use_5x5kernel,
                            act=act,
                            name='double_blaze_{}'.format(k))))
            in_channels = v[2]
            self._out_channels.append(in_channels)

    def forward(self, inputs):
        outs = []
        y = self.conv1(inputs['image'])
        for block in self.blaze_block:
            y = block(y)
            outs.append(y)
        return [outs[-4], outs[-1]]

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[-4], self._out_channels[-1]]
        ]
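

# Minimal smoke test, a sketch rather than part of the original module. The
# relative import of ShapeSpec means this file only runs in package context,
# e.g. `python -m ppdet.modeling.backbones.blazenet` (assuming the usual
# PaddleDetection layout); expected shapes follow the default filter lists.
if __name__ == "__main__":
    model = BlazeNet()  # default BlazeFace configuration
    feats = model({'image': paddle.rand([1, 3, 640, 640])})
    # outs[-4] is the stride-8 map, outs[-1] the stride-16 map:
    for f in feats:
        print(f.shape)  # [1, 96, 80, 80] then [1, 96, 40, 40]
    print([s.channels for s in model.out_shape])  # [96, 96]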