efficientnet.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. import collections
  17. import math
  18. import re
  19. from paddle import fluid
  20. from paddle.fluid.regularizer import L2Decay
  21. from ppdet.core.workspace import register
  # Public surface of this module.
# Only the backbone class is exported; the helpers below are internal.
__all__ = [
    'EfficientNet',
]
  # Hyper-parameter record types.
# Network-wide settings: batch-norm hyper-parameters and width/depth scaling.
GlobalParams = collections.namedtuple('GlobalParams', (
    'batch_norm_momentum',
    'batch_norm_epsilon',
    'width_coefficient',
    'depth_coefficient',
    'depth_divisor',
))
# Per-stage settings decoded from a block string.
BlockArgs = collections.namedtuple('BlockArgs', (
    'kernel_size',
    'num_repeat',
    'input_filters',
    'output_filters',
    'expand_ratio',
    'stride',
    'se_ratio',
))

# Every field defaults to None so partially specified records can be built.
for _record_cls in (GlobalParams, BlockArgs):
    _record_cls.__new__.__defaults__ = (None, ) * len(_record_cls._fields)
del _record_cls
  # Block-string parsing.
def _decode_block_string(block_string):
    """Decode an EfficientNet block string into a ``BlockArgs`` record.

    A block string is a ``'_'``-separated list of options. Each option is a
    key followed by a value that starts with a digit, for example
    ``'r2_k3_s22_e6_i16_o24_se0.25'``:

    * ``r`` -- number of repeats
    * ``k`` -- kernel size
    * ``s`` -- stride, either one digit or two identical digits ('1', '22')
    * ``e`` -- expand ratio
    * ``i`` / ``o`` -- input / output filters
    * ``se`` -- optional squeeze-and-excite ratio

    Args:
        block_string (str): the encoded block description.

    Returns:
        BlockArgs: the decoded options; ``se_ratio`` is None when no ``se``
        option is present.

    Raises:
        TypeError: if ``block_string`` is not a ``str``.
        ValueError: if the stride option is missing or malformed.
        KeyError: if one of the ``k``, ``r``, ``i``, ``o`` or ``e`` options
            is missing.
    """
    # Explicit raises instead of ``assert`` so validation survives ``-O``.
    if not isinstance(block_string, str):
        raise TypeError('block_string must be a str, got {!r}'.format(
            type(block_string)))

    options = {}
    for op in block_string.split('_'):
        # 'se0.25' -> ['se', '0.25', '']: the key is everything before the
        # first digit and the value is the remainder of the token.
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value

    # The stride may be given once ('s1') or per spatial axis ('s22'). Only
    # equal strides are supported. Check presence first so a missing 's'
    # is reported as a validation error rather than a KeyError.
    stride = options.get('s')
    if stride is None or not (len(stride) == 1 or
                              (len(stride) == 2 and stride[0] == stride[1])):
        raise ValueError(
            'invalid or missing stride option in block string {!r}'.format(
                block_string))

    return BlockArgs(
        kernel_size=int(options['k']),
        num_repeat=int(options['r']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        expand_ratio=int(options['e']),
        se_ratio=float(options['se']) if 'se' in options else None,
        stride=int(stride[0]))
  # Model configuration.
def get_model_params(scale):
    """Return the stage definitions and global settings for a model variant.

    Args:
        scale (str): variant name, one of ``'b0'`` ... ``'b7'``.

    Returns:
        tuple: ``(block_args, global_params)``. ``block_args`` is a list of
        seven ``BlockArgs`` decoded from the unscaled (B0) stage layout.
        ``global_params`` is a ``GlobalParams`` carrying the width/depth
        coefficients of ``scale``, BN momentum 0.99, BN epsilon 1e-3 and a
        channel divisor of 8.

    Raises:
        KeyError: if ``scale`` is not a known variant.
    """
    # (width_coefficient, depth_coefficient) for each variant.
    coefficients = {
        'b0': (1.0, 1.0),
        'b1': (1.0, 1.1),
        'b2': (1.1, 1.2),
        'b3': (1.2, 1.4),
        'b4': (1.4, 1.8),
        'b5': (1.6, 2.2),
        'b6': (1.8, 2.6),
        'b7': (2.0, 3.1),
    }
    # Unscaled stage layout shared by every variant.
    stage_specs = (
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    )
    block_args = [_decode_block_string(spec) for spec in stage_specs]

    width, depth = coefficients[scale]
    global_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        width_coefficient=width,
        depth_coefficient=depth,
        depth_divisor=8)
    return block_args, global_params
  # Width scaling.
def round_filters(filters, global_params):
    """Scale a channel count by the width coefficient and round it.

    The scaled value is rounded to the nearest multiple of
    ``global_params.depth_divisor`` and is never smaller than the divisor
    itself. If that rounding dropped more than 10% of the scaled value, one
    more divisor is added.

    Args:
        filters (int): unscaled number of channels.
        global_params (GlobalParams): supplies ``width_coefficient`` and
            ``depth_divisor``.

    Returns:
        int: the scaled channel count. ``filters`` is returned unchanged when
        the width coefficient is falsy (None or 0).
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        return filters

    step = global_params.depth_divisor
    scaled = filters * width_mult
    rounded = max(step, int(scaled + step / 2) // step * step)
    # Do not let rounding shrink the layer by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += step
    return int(rounded)
  # Depth scaling.
def round_repeats(repeats, global_params):
    """Scale a stage's repeat count by the depth coefficient, rounding up.

    Args:
        repeats (int): unscaled number of block repeats.
        global_params (GlobalParams): supplies ``depth_coefficient``.

    Returns:
        int: ``ceil(depth_coefficient * repeats)``. ``repeats`` is returned
        unchanged when the coefficient is falsy (None or 0).
    """
    depth_mult = global_params.depth_coefficient
    if depth_mult:
        return int(math.ceil(depth_mult * repeats))
    return repeats
  # Convolution wrapper with explicit parameter names.
def conv2d(inputs,
           num_filters,
           filter_size,
           stride=1,
           padding='SAME',
           groups=1,
           use_bias=False,
           name='conv2d'):
    """Apply ``fluid.layers.conv2d`` with parameters named after ``name``.

    The kernel parameter is named ``<name>_weights``. With ``use_bias`` a
    bias named ``<name>_offset`` is created with ``L2Decay(0.)``; without it
    the convolution has no bias.

    NOTE(review): callers relying on the default ``name='conv2d'`` all
    request the same parameter name ``conv2d_weights``. Every call in this
    file passes an explicit name.

    Args:
        inputs: input feature map passed to the fluid layer.
        num_filters (int): number of output channels.
        filter_size (int): kernel size.
        stride (int): convolution stride.
        padding: padding forwarded to fluid (``'SAME'`` by default).
        groups (int): number of convolution groups.
        use_bias (bool): whether to add a bias term.
        name (str): prefix for the layer and its parameter names.

    Returns:
        The output of ``fluid.layers.conv2d``.
    """
    weight_attr = fluid.ParamAttr(name=name + '_weights')
    if use_bias:
        # Bias is excluded from weight decay.
        bias_attr = fluid.ParamAttr(
            name=name + '_offset', regularizer=L2Decay(0.))
    else:
        bias_attr = False
    return fluid.layers.conv2d(
        inputs,
        num_filters,
        filter_size,
        stride=stride,
        padding=padding,
        groups=groups,
        param_attr=weight_attr,
        bias_attr=bias_attr,
        name=name)
  # Batch-normalization wrapper with explicit parameter names.
def batch_norm(inputs, momentum, eps, name=None):
    """Apply ``fluid.layers.batch_norm`` with explicitly named parameters.

    When ``name`` is given, the scale, offset, moving mean and moving
    variance are named ``<name>_scale``, ``<name>_offset``, ``<name>_mean``
    and ``<name>_variance``. When ``name`` is None, those names are left as
    None so fluid chooses them. Scale and offset are excluded from weight
    decay via ``L2Decay(0.)``.

    Args:
        inputs: input feature map passed to the fluid layer.
        momentum (float): moving-average momentum forwarded to fluid.
        eps (float): epsilon forwarded to fluid.
        name (str|None): prefix for the layer and its parameter names.

    Returns:
        The output of ``fluid.layers.batch_norm``.
    """
    def _param_name(suffix):
        # ``None + suffix`` would raise TypeError for the default name=None;
        # a None name lets fluid generate one instead.
        return None if name is None else name + suffix

    param_attr = fluid.ParamAttr(
        name=_param_name('_scale'), regularizer=L2Decay(0.))
    bias_attr = fluid.ParamAttr(
        name=_param_name('_offset'), regularizer=L2Decay(0.))
    return fluid.layers.batch_norm(
        input=inputs,
        momentum=momentum,
        epsilon=eps,
        name=name,
        moving_mean_name=_param_name('_mean'),
        moving_variance_name=_param_name('_variance'),
        param_attr=param_attr,
        bias_attr=bias_attr)
  # Mobile inverted bottleneck (MBConv) block.
def mb_conv_block(inputs,
                  input_filters,
                  output_filters,
                  expand_ratio,
                  kernel_size,
                  stride,
                  momentum,
                  eps,
                  se_ratio=None,
                  name=None):
    """Build one mobile inverted bottleneck (MBConv) block.

    Layers, in order:

    1. 1x1 expansion conv + BN + swish to ``input_filters * expand_ratio``
       channels (skipped when ``expand_ratio == 1``);
    2. depthwise ``kernel_size`` conv with ``stride`` + BN + swish;
    3. squeeze-and-excite gating (only when ``se_ratio`` is not None);
    4. 1x1 projection conv + BN to ``output_filters`` channels, with no
       activation.

    The block input is added to the output when ``stride == 1`` and
    ``input_filters == output_filters``. No drop-connect is applied.

    Args:
        inputs: input feature map passed to the fluid layers.
        input_filters (int): channels of ``inputs``.
        output_filters (int): channels produced by the projection conv.
        expand_ratio (int): channel multiplier of the expansion conv.
        kernel_size (int): depthwise kernel size.
        stride (int): depthwise stride.
        momentum (float): batch-norm momentum.
        eps (float): batch-norm epsilon.
        se_ratio (float|None): squeeze ratio relative to ``input_filters``;
            None disables squeeze-and-excite.
        name (str|None): prefix for every parameter name in the block. When
            None, a unique prefix is generated with
            ``fluid.unique_name.generate``.

    Returns:
        The block output.
    """
    if name is None:
        # Every sub-layer name is built as ``name + suffix``, which raised
        # TypeError for the default None. A generated unique prefix also
        # keeps separate unnamed blocks from sharing parameters.
        name = fluid.unique_name.generate('mb_conv_block')

    feats = inputs
    num_filters = input_filters * expand_ratio
    if expand_ratio != 1:
        feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
        feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
        feats = fluid.layers.swish(feats)

    # groups == channels makes this a depthwise convolution.
    feats = conv2d(
        feats,
        num_filters,
        kernel_size,
        stride,
        groups=num_filters,
        name=name + '_depthwise_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
    feats = fluid.layers.swish(feats)

    if se_ratio is not None:
        # The squeeze width is derived from the block's input channels, not
        # the expanded ones.
        filter_squeezed = max(1, int(input_filters * se_ratio))
        squeezed = fluid.layers.pool2d(
            feats, pool_type='avg', global_pooling=True)
        squeezed = conv2d(
            squeezed,
            filter_squeezed,
            1,
            use_bias=True,
            name=name + '_se_reduce')
        squeezed = fluid.layers.swish(squeezed)
        squeezed = conv2d(
            squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
        feats = feats * fluid.layers.sigmoid(squeezed)

    feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn2')
    if stride == 1 and input_filters == output_filters:
        feats = fluid.layers.elementwise_add(feats, inputs)
    return feats
  # Backbone definition.
@register
class EfficientNet(object):
    """
    EfficientNet backbone, see https://arxiv.org/abs/1905.11946

    Calling the instance on an input tensor builds the network and returns
    the outputs of stages 2, 4 and 6 (0-based). Given the stride-2 stem and
    the stage strides in ``get_model_params``, these are downsampled by 8,
    16 and 32 relative to the input.

    Args:
        scale (str): compound scaling variant, 'b0' - 'b7'. Both the width
            coefficient (channels) and the depth coefficient (block repeats)
            of the variant are applied.
        use_se (bool): use the squeeze-and-excite module.
        norm_type (str): normalization type; 'bn' and 'sync_bn' are
            supported.
    """
    __shared__ = ['norm_type']

    def __init__(self, scale='b0', use_se=True, norm_type='bn'):
        assert scale in ['b' + str(i) for i in range(8)], \
            "valid scales are b0 - b7"
        assert norm_type in ['bn', 'sync_bn'], \
            "only 'bn' and 'sync_bn' are supported"
        super(EfficientNet, self).__init__()
        # NOTE(review): norm_type is validated and stored but never read in
        # this class; presumably sync BN is switched on elsewhere. Confirm.
        self.norm_type = norm_type
        self.scale = scale
        self.use_se = use_se

    def __call__(self, inputs):
        """Build the backbone on ``inputs``.

        Args:
            inputs: input image tensor passed to the stem convolution.

        Returns:
            list: the outputs of stages 2, 4 and 6, in that order.
        """
        blocks_args, global_params = get_model_params(self.scale)
        momentum = global_params.batch_norm_momentum
        eps = global_params.batch_norm_epsilon

        # Stem: stride-2 3x3 conv + BN + swish.
        num_filters = round_filters(32, global_params)
        feats = conv2d(
            inputs,
            num_filters=num_filters,
            filter_size=3,
            stride=2,
            name='_conv_stem')
        feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
        feats = fluid.layers.swish(feats)

        layer_count = 0
        feature_maps = []
        for block_arg in blocks_args:
            # Width scaling does not depend on the repeat index, so compute
            # it once per stage.
            stage_in = round_filters(block_arg.input_filters, global_params)
            stage_out = round_filters(block_arg.output_filters, global_params)
            se_ratio = block_arg.se_ratio if self.use_se else None
            # Depth scaling. The depth coefficient is 1.0 for b0, so b0 keeps
            # the same block count and parameter names. For b1-b7 it adds
            # blocks, which shifts the '_blocks.N.' names relative to a build
            # that ignored depth_coefficient.
            num_repeat = round_repeats(block_arg.num_repeat, global_params)
            for r in range(num_repeat):
                # Only the first repeat of a stage changes the channel count
                # and applies the stage stride.
                if r == 0:
                    input_filters, stride = stage_in, block_arg.stride
                else:
                    input_filters, stride = stage_out, 1
                feats = mb_conv_block(
                    feats,
                    input_filters,
                    stage_out,
                    block_arg.expand_ratio,
                    block_arg.kernel_size,
                    stride,
                    momentum,
                    eps,
                    se_ratio=se_ratio,
                    name='_blocks.{}.'.format(layer_count))
                layer_count += 1
            feature_maps.append(feats)
        return [feature_maps[i] for i in (2, 4, 6)]