blazenet.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr

from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register

__all__ = ['BlazeNet']


@register
class BlazeNet(object):
    """
    BlazeFace, see https://arxiv.org/abs/1907.05047

    Args:
        blaze_filters (list): number of filters for each blaze block
        double_blaze_filters (list): number of filters for each double_blaze block
        with_extra_blocks (bool): whether or not extra blocks should be added
        lite_edition (bool): whether or not this is the BlazeFace-lite edition
        use_5x5kernel (bool): whether to use a 5x5 kernel in the depth-wise convs
    """

    def __init__(
            self,
            blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]],
            double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                                  [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]],
            with_extra_blocks=True,
            lite_edition=False,
            use_5x5kernel=True):
        super(BlazeNet, self).__init__()
        self.blaze_filters = blaze_filters
        self.double_blaze_filters = double_blaze_filters
        self.with_extra_blocks = with_extra_blocks
        self.lite_edition = lite_edition
        self.use_5x5kernel = use_5x5kernel
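
    # __call__ builds the network on `input`. The standard edition stacks the
    # configured blaze and double-blaze blocks and returns the last one or two
    # feature maps (depending on with_extra_blocks); the lite edition builds a
    # fixed Blaze_lite stack and returns (conv7, conv13).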
    def __call__(self, input):
        if not self.lite_edition:
            conv1_num_filters = self.blaze_filters[0][0]
            conv = self._conv_norm(
                input=input,
                num_filters=conv1_num_filters,
                filter_size=3,
                stride=2,
                padding=1,
                act='relu',
                name="conv1")

            for k, v in enumerate(self.blaze_filters):
                assert len(v) in [2, 3], \
                    "blaze_filters {} not in [2, 3]".format(len(v))
                if len(v) == 2:
                    conv = self.BlazeBlock(
                        conv,
                        v[0],
                        v[1],
                        use_5x5kernel=self.use_5x5kernel,
                        name='blaze_{}'.format(k))
                elif len(v) == 3:
                    conv = self.BlazeBlock(
                        conv,
                        v[0],
                        v[1],
                        stride=v[2],
                        use_5x5kernel=self.use_5x5kernel,
                        name='blaze_{}'.format(k))

            layers = []
            for k, v in enumerate(self.double_blaze_filters):
                assert len(v) in [3, 4], \
                    "double_blaze_filters {} not in [3, 4]".format(len(v))
                if len(v) == 3:
                    conv = self.BlazeBlock(
                        conv,
                        v[0],
                        v[1],
                        double_channels=v[2],
                        use_5x5kernel=self.use_5x5kernel,
                        name='double_blaze_{}'.format(k))
                elif len(v) == 4:
                    layers.append(conv)
                    conv = self.BlazeBlock(
                        conv,
                        v[0],
                        v[1],
                        double_channels=v[2],
                        stride=v[3],
                        use_5x5kernel=self.use_5x5kernel,
                        name='double_blaze_{}'.format(k))
            layers.append(conv)

            if not self.with_extra_blocks:
                return layers[-1]
            return layers[-2], layers[-1]
        else:
            conv1 = self._conv_norm(
                input=input,
                num_filters=24,
                filter_size=5,
                stride=2,
                padding=2,
                act='relu',
                name="conv1")
            conv2 = self.Blaze_lite(conv1, 24, 24, 1, 'conv2')
            conv3 = self.Blaze_lite(conv2, 24, 28, 1, 'conv3')
            conv4 = self.Blaze_lite(conv3, 28, 32, 2, 'conv4')
            conv5 = self.Blaze_lite(conv4, 32, 36, 1, 'conv5')
            conv6 = self.Blaze_lite(conv5, 36, 42, 1, 'conv6')
            conv7 = self.Blaze_lite(conv6, 42, 48, 2, 'conv7')
            in_ch = 48
            for i in range(5):
                conv7 = self.Blaze_lite(conv7, in_ch, in_ch + 8, 1,
                                        'conv{}'.format(8 + i))
                in_ch += 8
            assert in_ch == 88
            conv13 = self.Blaze_lite(conv7, 88, 96, 2, 'conv13')
            for i in range(4):
                conv13 = self.Blaze_lite(conv13, 96, 96, 1,
                                         'conv{}'.format(14 + i))
            return conv7, conv13
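
    # BlazeBlock: the depthwise-separable unit from the BlazeFace paper. The
    # main branch is a depthwise conv (a single 5x5, or two stacked 3x3s when
    # use_5x5kernel is False) followed by a 1x1 pointwise conv; when
    # double_channels is set, a second depthwise/pointwise pair is appended
    # ("double BlazeBlock"). With stride 2, the input is max-pooled and
    # channel-padded by a 1x1 conv so the residual add still matches.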
    def BlazeBlock(self,
                   input,
                   in_channels,
                   out_channels,
                   double_channels=None,
                   stride=1,
                   use_5x5kernel=True,
                   name=None):
        assert stride in [1, 2]
        use_pool = stride != 1
        use_double_block = double_channels is not None
        act = 'relu' if use_double_block else None
        mixed_precision_enabled = mixed_precision_global_state() is not None

        if use_5x5kernel:
            conv_dw = self._conv_norm(
                input=input,
                filter_size=5,
                num_filters=in_channels,
                stride=stride,
                padding=2,
                num_groups=in_channels,
                use_cudnn=mixed_precision_enabled,
                name=name + "1_dw")
        else:
            conv_dw_1 = self._conv_norm(
                input=input,
                filter_size=3,
                num_filters=in_channels,
                stride=1,
                padding=1,
                num_groups=in_channels,
                use_cudnn=mixed_precision_enabled,
                name=name + "1_dw_1")
            conv_dw = self._conv_norm(
                input=conv_dw_1,
                filter_size=3,
                num_filters=in_channels,
                stride=stride,
                padding=1,
                num_groups=in_channels,
                use_cudnn=mixed_precision_enabled,
                name=name + "1_dw_2")

        conv_pw = self._conv_norm(
            input=conv_dw,
            filter_size=1,
            num_filters=out_channels,
            stride=1,
            padding=0,
            act=act,
            name=name + "1_sep")

        if use_double_block:
            if use_5x5kernel:
                conv_dw = self._conv_norm(
                    input=conv_pw,
                    filter_size=5,
                    num_filters=out_channels,
                    stride=1,
                    padding=2,
                    use_cudnn=mixed_precision_enabled,
                    name=name + "2_dw")
            else:
                conv_dw_1 = self._conv_norm(
                    input=conv_pw,
                    filter_size=3,
                    num_filters=out_channels,
                    stride=1,
                    padding=1,
                    num_groups=out_channels,
                    use_cudnn=mixed_precision_enabled,
                    name=name + "2_dw_1")
                conv_dw = self._conv_norm(
                    input=conv_dw_1,
                    filter_size=3,
                    num_filters=out_channels,
                    stride=1,
                    padding=1,
                    num_groups=out_channels,
                    use_cudnn=mixed_precision_enabled,
                    name=name + "2_dw_2")

            conv_pw = self._conv_norm(
                input=conv_dw,
                filter_size=1,
                num_filters=double_channels,
                stride=1,
                padding=0,
                name=name + "2_sep")

        # shortcut
        if use_pool:
            shortcut_channel = double_channels or out_channels
            shortcut_pool = self._pooling_block(input, stride, stride)
            channel_pad = self._conv_norm(
                input=shortcut_pool,
                filter_size=1,
                num_filters=shortcut_channel,
                stride=1,
                padding=0,
                name="shortcut" + name)
            return fluid.layers.elementwise_add(
                x=channel_pad, y=conv_pw, act='relu')
        return fluid.layers.elementwise_add(x=input, y=conv_pw, act='relu')
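
    # Blaze_lite: the lite-edition block, a single 3x3 depthwise conv plus a
    # 1x1 pointwise conv. The shortcut is max-pooled when stride is 2 and
    # channel-padded by a 1x1 conv when the channel count changes.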
    def Blaze_lite(self, input, in_channels, out_channels, stride=1,
                   name=None):
        assert stride in [1, 2]
        use_pool = stride != 1
        use_pad = in_channels != out_channels

        conv_dw = self._conv_norm(
            input=input,
            filter_size=3,
            num_filters=in_channels,
            stride=stride,
            padding=1,
            num_groups=in_channels,
            name=name + "_dw")
        conv_pw = self._conv_norm(
            input=conv_dw,
            filter_size=1,
            num_filters=out_channels,
            stride=1,
            padding=0,
            name=name + "_sep")

        if use_pool:
            shortcut_pool = self._pooling_block(input, stride, stride)
        if use_pad:
            conv_pad = shortcut_pool if use_pool else input
            channel_pad = self._conv_norm(
                input=conv_pad,
                filter_size=1,
                num_filters=out_channels,
                stride=1,
                padding=0,
                name="shortcut" + name)
            return fluid.layers.elementwise_add(
                x=channel_pad, y=conv_pw, act='relu')
        return fluid.layers.elementwise_add(x=input, y=conv_pw, act='relu')
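
    # _conv_norm: conv2d (bias-free, MSRA-initialized weights with a 0.1
    # learning-rate multiplier) followed by batch norm, which applies the
    # optional activation.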
    def _conv_norm(
            self,
            input,
            filter_size,
            num_filters,
            stride,
            padding,
            num_groups=1,
            act='relu',  # applied by batch_norm below; None disables it
            use_cudnn=True,
            name=None):
        parameter_attr = ParamAttr(
            learning_rate=0.1,
            initializer=fluid.initializer.MSRA(),
            name=name + "_weights")
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            act=None,
            use_cudnn=use_cudnn,
            param_attr=parameter_attr,
            bias_attr=False)
        return fluid.layers.batch_norm(input=conv, act=act)

    def _pooling_block(self,
                       conv,
                       pool_size,
                       pool_stride,
                       pool_padding=0,
                       ceil_mode=True):
        pool = fluid.layers.pool2d(
            input=conv,
            pool_size=pool_size,
            pool_type='max',
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            ceil_mode=ceil_mode)
        return pool
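

if __name__ == '__main__':
    # Minimal usage sketch (not part of the upstream file): build the backbone
    # inside a Paddle fluid static-graph program. The input name and the
    # 3x640x640 shape below are illustrative assumptions, not values taken
    # from this file.
    image = fluid.layers.data(
        name='image', shape=[3, 640, 640], dtype='float32')
    backbone = BlazeNet()  # defaults: with_extra_blocks=True, lite_edition=False
    feat1, feat2 = backbone(image)  # two feature maps for the detection heads
    print(feat1.shape, feat2.shape)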