ghostnet.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from numbers import Integral

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, Linear
from paddle.nn.initializer import Uniform

from ppdet.core.workspace import register, serializable

from ..shape_spec import ShapeSpec
from .mobilenet_v3 import make_divisible, ConvBNLayer

__all__ = ['GhostNet']


class ExtraBlockDW(nn.Layer):
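    """Extra downsampling block appended after the backbone for SSD-style
    heads: a 1x1 pointwise conv, a stride-`stride` 3x3 depthwise conv, and a
    final 1x1 pointwise conv, each followed by BN and ReLU6.
    """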
    def __init__(self,
                 in_c,
                 ch_1,
                 ch_2,
                 stride,
                 lr_mult,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=None):
        super(ExtraBlockDW, self).__init__()
        self.pointwise_conv = ConvBNLayer(
            in_c=in_c,
            out_c=ch_1,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra1")
        self.depthwise_conv = ConvBNLayer(
            in_c=ch_1,
            out_c=ch_2,
            filter_size=3,
            stride=stride,
            padding=1,
            num_groups=int(ch_1),
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra2_dw")
        self.normal_conv = ConvBNLayer(
            in_c=ch_2,
            out_c=ch_2,
            filter_size=1,
            stride=1,
            padding=0,
            act='relu6',
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_extra2_sep")

    def forward(self, inputs):
        x = self.pointwise_conv(inputs)
        x = self.depthwise_conv(x)
        x = self.normal_conv(x)
        return x


class SEBlock(nn.Layer):
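    """Squeeze-and-Excitation block: global average pooling followed by two
    fully connected layers that produce per-channel gates, which rescale the
    input feature map.
    """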
    def __init__(self, num_channels, lr_mult, reduction_ratio=4, name=None):
        super(SEBlock, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        med_ch = num_channels // reduction_ratio
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                learning_rate=lr_mult, initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(learning_rate=lr_mult))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_channels,
            weight_attr=ParamAttr(
                learning_rate=lr_mult, initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(learning_rate=lr_mult))

    def forward(self, inputs):
        pool = self.pool2d_gap(inputs)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
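        # hard clipping to [0, 1] gates the channels; in effect a coarse
        # stand-in for the sigmoid of the original SE block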
        excitation = paddle.clip(x=excitation, min=0, max=1)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = paddle.multiply(inputs, excitation)
        return out


class GhostModule(nn.Layer):
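    """Ghost module from "GhostNet: More Features from Cheap Operations"
    (https://arxiv.org/abs/1911.11907): a dense primary conv produces
    ceil(output_channels / ratio) channels, a cheap depthwise conv derives
    the remaining channels from them, and the two are concatenated. E.g.
    with output_channels=64 and ratio=2, the primary conv yields 32 channels
    and the cheap operation another 32.
    """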
    def __init__(self,
                 in_channels,
                 output_channels,
                 kernel_size=1,
                 ratio=2,
                 dw_size=3,
                 stride=1,
                 relu=True,
                 lr_mult=1.,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 name=None):
        super(GhostModule, self).__init__()
        init_channels = int(math.ceil(output_channels / ratio))
        new_channels = int(init_channels * (ratio - 1))
        self.primary_conv = ConvBNLayer(
            in_c=in_channels,
            out_c=init_channels,
            filter_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            num_groups=1,
            act="relu" if relu else None,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_primary_conv")
        self.cheap_operation = ConvBNLayer(
            in_c=init_channels,
            out_c=new_channels,
            filter_size=dw_size,
            stride=1,
            padding=int((dw_size - 1) // 2),
            num_groups=init_channels,
            act="relu" if relu else None,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_cheap_operation")

    def forward(self, inputs):
        x = self.primary_conv(inputs)
        y = self.cheap_operation(x)
        out = paddle.concat([x, y], axis=1)
        return out


class GhostBottleneck(nn.Layer):
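    """GhostNet bottleneck: an expanding GhostModule, an optional stride-2
    depthwise conv, an optional SEBlock, and a projecting GhostModule, with
    a residual shortcut (projected when stride or channel count changes).
    """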
    def __init__(self,
                 in_channels,
                 hidden_dim,
                 output_channels,
                 kernel_size,
                 stride,
                 use_se,
                 lr_mult,
                 conv_decay=0.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=False,
                 return_list=False,
                 name=None):
        super(GhostBottleneck, self).__init__()
        self._stride = stride
        self._use_se = use_se
        self._num_channels = in_channels
        self._output_channels = output_channels
        self.return_list = return_list
        self.ghost_module_1 = GhostModule(
            in_channels=in_channels,
            output_channels=hidden_dim,
            kernel_size=1,
            stride=1,
            relu=True,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_ghost_module_1")
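        # stride-2 blocks downsample spatially with a depthwise conv placed
        # between the two ghost modules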
        if stride == 2:
            self.depthwise_conv = ConvBNLayer(
                in_c=hidden_dim,
                out_c=hidden_dim,
                filter_size=kernel_size,
                stride=stride,
                padding=int((kernel_size - 1) // 2),
                num_groups=hidden_dim,
                act=None,
                lr_mult=lr_mult,
                conv_decay=conv_decay,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                name=name +
                "_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
        if use_se:
            self.se_block = SEBlock(hidden_dim, lr_mult, name=name + "_se")
        self.ghost_module_2 = GhostModule(
            in_channels=hidden_dim,
            output_channels=output_channels,
            kernel_size=1,
            relu=False,
            lr_mult=lr_mult,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name + "_ghost_module_2")
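        # when the identity shortcut does not fit (stride != 1 or a channel
        # change), project the input with a depthwise conv plus a 1x1 conv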
        if stride != 1 or in_channels != output_channels:
            self.shortcut_depthwise = ConvBNLayer(
                in_c=in_channels,
                out_c=in_channels,
                filter_size=kernel_size,
                stride=stride,
                padding=int((kernel_size - 1) // 2),
                num_groups=in_channels,
                act=None,
                lr_mult=lr_mult,
                conv_decay=conv_decay,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                name=name +
                "_shortcut_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
            self.shortcut_conv = ConvBNLayer(
                in_c=in_channels,
                out_c=output_channels,
                filter_size=1,
                stride=1,
                padding=0,
                num_groups=1,
                act=None,
                lr_mult=lr_mult,
                conv_decay=conv_decay,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                name=name + "_shortcut_conv")

    def forward(self, inputs):
        y = self.ghost_module_1(inputs)
        x = y
        if self._stride == 2:
            x = self.depthwise_conv(x)
        if self._use_se:
            x = self.se_block(x)
        x = self.ghost_module_2(x)
        if self._stride == 1 and self._num_channels == self._output_channels:
            shortcut = inputs
        else:
            shortcut = self.shortcut_depthwise(inputs)
            shortcut = self.shortcut_conv(shortcut)
        x = paddle.add(x=x, y=shortcut)
        if self.return_list:
            return [y, x]
        else:
            return x


@register
@serializable
class GhostNet(nn.Layer):
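    """GhostNet backbone.

    Args:
        scale (float): width multiplier applied to all channel counts.
        feature_maps (list[int]): indices of the stages to output; conv1 is
            stage 1, so the first bottleneck is stage 2.
        with_extra_blocks (bool): whether to append SSD-style extra blocks.
        extra_block_filters (list[list[int]]): (mid, out) channel pairs for
            each ExtraBlockDW.
        lr_mult_list (list[float]): learning-rate multipliers, one per group
            of three bottlenecks.
        conv_decay (float): weight decay for convolution weights.
        norm_type (str): normalization type, 'bn' or 'sync_bn'.
        norm_decay (float): weight decay for normalization parameters.
        freeze_norm (bool): whether to freeze the normalization layers.
    """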
    __shared__ = ['norm_type']

    def __init__(
            self,
            scale=1.3,
            feature_maps=[6, 12, 15],
            with_extra_blocks=False,
            extra_block_filters=[[256, 512], [128, 256], [128, 256],
                                 [64, 128]],
            lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
            conv_decay=0.,
            norm_type='bn',
            norm_decay=0.0,
            freeze_norm=False):
        super(GhostNet, self).__init__()
        if isinstance(feature_maps, Integral):
            feature_maps = [feature_maps]
        if norm_type == 'sync_bn' and freeze_norm:
            raise ValueError(
                "The norm_type should not be sync_bn when freeze_norm is True")
        self.feature_maps = feature_maps
        self.with_extra_blocks = with_extra_blocks
        self.extra_block_filters = extra_block_filters
        inplanes = 16
        self.cfgs = [
            # kernel size (k), expansion size (t), output channels (c), use SE, stride (s)
            [3, 16, 16, 0, 1],
            [3, 48, 24, 0, 2],
            [3, 72, 24, 0, 1],
            [5, 72, 40, 1, 2],
            [5, 120, 40, 1, 1],
            [3, 240, 80, 0, 2],
            [3, 200, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 480, 112, 1, 1],
            [3, 672, 112, 1, 1],
            [5, 672, 160, 1, 2],  # SSDLite output
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1]
        ]
        self.scale = scale
        conv1_out_ch = int(make_divisible(inplanes * self.scale, 4))
        self.conv1 = ConvBNLayer(
            in_c=3,
            out_c=conv1_out_ch,
            filter_size=3,
            stride=2,
            padding=1,
            num_groups=1,
            act="relu",
            lr_mult=1.,
            conv_decay=conv_decay,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="conv1")

        # build inverted residual blocks
        self._out_channels = []
        self.ghost_bottleneck_list = []
        idx = 0
        inplanes = conv1_out_ch
        for k, exp_size, c, use_se, s in self.cfgs:
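            # bottlenecks are grouped in threes; each group takes one entry
            # from lr_mult_list (clamped to its last entry)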
            lr_idx = min(idx // 3, len(lr_mult_list) - 1)
            lr_mult = lr_mult_list[lr_idx]
            # for SSD/SSDLite, first head input is after ResidualUnit expand_conv
            return_list = self.with_extra_blocks and idx + 2 in self.feature_maps
            ghost_bottleneck = self.add_sublayer(
                "_ghostbottleneck_" + str(idx),
                sublayer=GhostBottleneck(
                    in_channels=inplanes,
                    hidden_dim=int(make_divisible(exp_size * self.scale, 4)),
                    output_channels=int(make_divisible(c * self.scale, 4)),
                    kernel_size=k,
                    stride=s,
                    use_se=use_se,
                    lr_mult=lr_mult,
                    conv_decay=conv_decay,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    return_list=return_list,
                    name="_ghostbottleneck_" + str(idx)))
            self.ghost_bottleneck_list.append(ghost_bottleneck)
            inplanes = int(make_divisible(c * self.scale, 4))
            idx += 1
            self._update_out_channels(
                int(make_divisible(exp_size * self.scale, 4))
                if return_list else inplanes, idx + 1, feature_maps)

        if self.with_extra_blocks:
            self.extra_block_list = []
            extra_out_c = int(make_divisible(self.scale * self.cfgs[-1][1], 4))
            lr_idx = min(idx // 3, len(lr_mult_list) - 1)
            lr_mult = lr_mult_list[lr_idx]

            conv_extra = self.add_sublayer(
                "conv" + str(idx + 2),
                sublayer=ConvBNLayer(
                    in_c=inplanes,
                    out_c=extra_out_c,
                    filter_size=1,
                    stride=1,
                    padding=0,
                    num_groups=1,
                    act="relu6",
                    lr_mult=lr_mult,
                    conv_decay=conv_decay,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    name="conv" + str(idx + 2)))
            self.extra_block_list.append(conv_extra)
            idx += 1
            self._update_out_channels(extra_out_c, idx + 1, feature_maps)

            for j, block_filter in enumerate(self.extra_block_filters):
                in_c = extra_out_c if j == 0 else self.extra_block_filters[j - 1][1]
                conv_extra = self.add_sublayer(
                    "conv" + str(idx + 2),
                    sublayer=ExtraBlockDW(
                        in_c,
                        block_filter[0],
                        block_filter[1],
                        stride=2,
                        lr_mult=lr_mult,
                        conv_decay=conv_decay,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name='conv' + str(idx + 2)))
                self.extra_block_list.append(conv_extra)
                idx += 1
                self._update_out_channels(block_filter[1], idx + 1,
                                          feature_maps)

    def _update_out_channels(self, channel, feature_idx, feature_maps):
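        """Record a stage's channel count if its index is a requested feature map."""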
        if feature_idx in feature_maps:
            self._out_channels.append(channel)

    def forward(self, inputs):
        x = self.conv1(inputs['image'])
        outs = []
        for idx, ghost_bottleneck in enumerate(self.ghost_bottleneck_list):
            x = ghost_bottleneck(x)
            if idx + 2 in self.feature_maps:
                if isinstance(x, list):
                    outs.append(x[0])
                    x = x[1]
                else:
                    outs.append(x)
        if not self.with_extra_blocks:
            return outs

        for i, block in enumerate(self.extra_block_list):
            idx = i + len(self.ghost_bottleneck_list)
            x = block(x)
            if idx + 2 in self.feature_maps:
                outs.append(x)
        return outs

    @property
    def out_shape(self):
        return [ShapeSpec(channels=c) for c in self._out_channels]
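

# A minimal usage sketch (illustrative, not part of the upstream file): build
# the backbone directly and feed it a dummy image; the channel dims of the
# returned features match `out_shape`.
if __name__ == '__main__':
    model = GhostNet(scale=1.3, feature_maps=[6, 12, 15])
    feats = model({'image': paddle.rand([1, 3, 320, 320])})
    for spec, feat in zip(model.out_shape, feats):
        print(feat.shape, spec.channels)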
  437. return [ShapeSpec(channels=c) for c in self._out_channels]