resnet.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant

from ppdet.core.workspace import register, serializable
from numbers import Integral

from .nonlocal_helper import add_space_nonlocal
from .gc_block import add_gc_block
from .name_adapter import NameAdapter

__all__ = ['ResNet', 'ResNetC5']


@register
@serializable
class ResNet(object):
    """
    Residual Network, see https://arxiv.org/abs/1512.03385

    Args:
        depth (int): ResNet depth, should be 18, 34, 50, 101, 152 or 200.
        freeze_at (int): stage up to which the backbone is frozen.
        norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'.
        freeze_norm (bool): whether to freeze normalization layers.
        norm_decay (float): weight decay for normalization layer weights.
        variant (str): ResNet variant, supports 'a', 'b', 'c' and 'd' currently.
        feature_maps (list): indices of stages whose feature maps are returned.
        dcn_v2_stages (list): indices of stages that use deformable conv v2.
        nonlocal_stages (list): indices of stages that use non-local blocks.
        gcb_stages (list): indices of stages that use GC blocks.
        gcb_params (dict): GC block config; includes ratio (default 1.0/16),
            pooling_type (default "att") and fusion_types
            (default ['channel_add']).
        lr_mult_list (list): learning rate multipliers for the four ResNet
            stages (2, 3, 4, 5). A lower multiplier is needed for pretrained
            models obtained via distillation (default [1.0, 1.0, 1.0, 1.0]).
    """
    __shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']

    def __init__(self,
                 depth=50,
                 freeze_at=2,
                 norm_type='affine_channel',
                 freeze_norm=True,
                 norm_decay=0.,
                 variant='b',
                 feature_maps=[2, 3, 4, 5],
                 dcn_v2_stages=[],
                 weight_prefix_name='',
                 nonlocal_stages=[],
                 gcb_stages=[],
                 gcb_params=dict(),
                 lr_mult_list=[1., 1., 1., 1.]):
        super(ResNet, self).__init__()

        if isinstance(feature_maps, Integral):
            feature_maps = [feature_maps]

        assert depth in [18, 34, 50, 101, 152, 200], \
            "depth {} not in [18, 34, 50, 101, 152, 200]".format(depth)
        assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
        assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
        assert len(feature_maps) > 0, "need one or more feature maps"
        assert norm_type in ['bn', 'sync_bn', 'affine_channel']
        assert not (len(nonlocal_stages) > 0 and depth < 50), \
            "non-local is not supported for resnet18 or resnet34"
        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))

        self.depth = depth
        self.freeze_at = freeze_at
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.variant = variant
        self._model_type = 'ResNet'
        self.feature_maps = feature_maps
        self.dcn_v2_stages = dcn_v2_stages
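        # block counts per stage (res2..res5) and the residual block type
        # used for each supported depth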
        self.depth_cfg = {
            18: ([2, 2, 2, 2], self.basicblock),
            34: ([3, 4, 6, 3], self.basicblock),
            50: ([3, 4, 6, 3], self.bottleneck),
            101: ([3, 4, 23, 3], self.bottleneck),
            152: ([3, 8, 36, 3], self.bottleneck),
            200: ([3, 12, 48, 3], self.bottleneck),
        }
        self.stage_filters = [64, 128, 256, 512]
        self._c1_out_chan_num = 64
        self.na = NameAdapter(self)
        self.prefix_name = weight_prefix_name

        self.nonlocal_stages = nonlocal_stages
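        # in stage 4, a non-local block is inserted after every `mod`-th
        # residual block for the given depth; other non-local stages use a
        # period of 2 (see layer_warp)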
        self.nonlocal_mod_cfg = {
            50: 2,
            101: 5,
            152: 8,
            200: 12,
        }

        self.gcb_stages = gcb_stages
        self.gcb_params = gcb_params
        self.lr_mult_list = lr_mult_list
        # var denoting curr stage
        self.stage_num = -1

    def _conv_offset(self,
                     input,
                     filter_size,
                     stride,
                     padding,
                     act=None,
                     name=None):
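        # DCNv2 needs 2 * k * k offset channels (an x/y pair per sampling
        # point) plus k * k mask channels, i.e. 3 * k * k outputs per position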
        out_channel = filter_size * filter_size * 3
        out = fluid.layers.conv2d(
            input,
            num_filters=out_channel,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            param_attr=ParamAttr(
                initializer=Constant(0.0), name=name + ".w_0"),
            bias_attr=ParamAttr(
                initializer=Constant(0.0), name=name + ".b_0"),
            act=act,
            name=name)
        return out

    def _conv_norm(self,
                   input,
                   num_filters,
                   filter_size,
                   stride=1,
                   groups=1,
                   act=None,
                   name=None,
                   dcn_v2=False):
        _name = self.prefix_name + name if self.prefix_name != '' else name
        # per-stage learning-rate multiplier: stages 2..5 map to indices 0..3;
        # a finer lr is needed for distilled pretrained models (default 1.0)
        mult_idx = max(self.stage_num - 2, 0)
        mult_idx = min(mult_idx, 3)
        lr_mult = self.lr_mult_list[mult_idx]
        if not dcn_v2:
            conv = fluid.layers.conv2d(
                input=input,
                num_filters=num_filters,
                filter_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                act=None,
                param_attr=ParamAttr(
                    name=_name + "_weights", learning_rate=lr_mult),
                bias_attr=False,
                name=_name + '.conv2d.output.1')
        else:
            # deformable conv v2: predict offsets and a modulation mask,
            # then apply the deformable convolution
            offset_mask = self._conv_offset(
                input=input,
                filter_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                act=None,
                name=_name + "_conv_offset")
            offset_channel = filter_size**2 * 2
            mask_channel = filter_size**2
            offset, mask = fluid.layers.split(
                input=offset_mask,
                num_or_sections=[offset_channel, mask_channel],
                dim=1)
            mask = fluid.layers.sigmoid(mask)
            conv = fluid.layers.deformable_conv(
                input=input,
                offset=offset,
                mask=mask,
                num_filters=num_filters,
                filter_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                deformable_groups=1,
                im2col_step=1,
                param_attr=ParamAttr(
                    name=_name + "_weights", learning_rate=lr_mult),
                bias_attr=False,
                name=_name + ".conv2d.output.1")

        bn_name = self.na.fix_conv_norm_name(name)
        bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name

        norm_lr = 0. if self.freeze_norm else lr_mult
        norm_decay = self.norm_decay
        pattr = ParamAttr(
            name=bn_name + '_scale',
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay))
        battr = ParamAttr(
            name=bn_name + '_offset',
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay))

        if self.norm_type in ['bn', 'sync_bn']:
            global_stats = True if self.freeze_norm else False
            out = fluid.layers.batch_norm(
                input=conv,
                act=act,
                name=bn_name + '.output.1',
                param_attr=pattr,
                bias_attr=battr,
                moving_mean_name=bn_name + '_mean',
                moving_variance_name=bn_name + '_variance',
                use_global_stats=global_stats)
            scale = fluid.framework._get_var(pattr.name)
            bias = fluid.framework._get_var(battr.name)
        elif self.norm_type == 'affine_channel':
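            # per-channel scale/shift in place of batch norm; typically paired
            # with pretrained weights where the BN statistics have already
            # been folded into the scale and bias parameters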
            scale = fluid.layers.create_parameter(
                shape=[conv.shape[1]],
                dtype=conv.dtype,
                attr=pattr,
                default_initializer=fluid.initializer.Constant(1.))
            bias = fluid.layers.create_parameter(
                shape=[conv.shape[1]],
                dtype=conv.dtype,
                attr=battr,
                default_initializer=fluid.initializer.Constant(0.))
            out = fluid.layers.affine_channel(
                x=conv, scale=scale, bias=bias, act=act)

        if self.freeze_norm:
            scale.stop_gradient = True
            bias.stop_gradient = True
        return out

    def _shortcut(self, input, ch_out, stride, is_first, name):
        max_pooling_in_short_cut = self.variant == 'd'
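        # despite the variable name, variant 'd' uses a stride-2 *average*
        # pool in the shortcut (the ResNet-D trick), so the strided 1x1 conv
        # no longer discards three quarters of its input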
        ch_in = input.shape[1]
        # the naming rule is the same as in the pretrained weights
        name = self.na.fix_shortcut_name(name)
        std_senet = getattr(self, 'std_senet', False)
        if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
            if std_senet:
                if is_first:
                    return self._conv_norm(input, ch_out, 1, stride, name=name)
                else:
                    return self._conv_norm(input, ch_out, 3, stride, name=name)
            if max_pooling_in_short_cut and not is_first:
                input = fluid.layers.pool2d(
                    input=input,
                    pool_size=2,
                    pool_stride=2,
                    pool_padding=0,
                    ceil_mode=True,
                    pool_type='avg')
                return self._conv_norm(input, ch_out, 1, 1, name=name)
            return self._conv_norm(input, ch_out, 1, stride, name=name)
        else:
            return input

    def bottleneck(self,
                   input,
                   num_filters,
                   stride,
                   is_first,
                   name,
                   dcn_v2=False,
                   gcb=False,
                   gcb_name=None):
        if self.variant == 'a':
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride
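        # variant 'a' strides in the first 1x1 conv; variants 'b'/'c'/'d'
        # move the stride to the 3x3 conv (the ResNet-B trick)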
        # ResNeXt
        groups = getattr(self, 'groups', 1)
        group_width = getattr(self, 'group_width', -1)
        if groups == 1:
            expand = 4
        elif (groups * group_width) == 256:
            expand = 1
        else:  # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
            num_filters = num_filters // 2
            expand = 2

        conv_name1, conv_name2, conv_name3, \
            shortcut_name = self.na.fix_bottleneck_name(name)
        std_senet = getattr(self, 'std_senet', False)
        if std_senet:
            conv_def = [
                [int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1],
                [num_filters, 3, stride2, 'relu', groups, conv_name2],
                [num_filters * expand, 1, 1, None, 1, conv_name3]
            ]
        else:
            conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
                        [num_filters, 3, stride2, 'relu', groups, conv_name2],
                        [num_filters * expand, 1, 1, None, 1, conv_name3]]
        residual = input
        for i, (c, k, s, act, g, _name) in enumerate(conv_def):
            residual = self._conv_norm(
                input=residual,
                num_filters=c,
                filter_size=k,
                stride=s,
                act=act,
                groups=g,
                name=_name,
                dcn_v2=(i == 1 and dcn_v2))
        short = self._shortcut(
            input,
            num_filters * expand,
            stride,
            is_first=is_first,
            name=shortcut_name)
        # Squeeze-and-Excitation
        if callable(getattr(self, '_squeeze_excitation', None)):
            residual = self._squeeze_excitation(
                input=residual, num_channels=num_filters, name='fc' + name)
        if gcb:
            residual = add_gc_block(residual, name=gcb_name, **self.gcb_params)
        return fluid.layers.elementwise_add(
            x=short, y=residual, act='relu', name=name + ".add.output.5")

    def basicblock(self,
                   input,
                   num_filters,
                   stride,
                   is_first,
                   name,
                   dcn_v2=False,
                   gcb=False,
                   gcb_name=None):
        assert dcn_v2 is False, "Not implemented yet."
        assert gcb is False, "Not implemented yet."
        conv0 = self._conv_norm(
            input=input,
            num_filters=num_filters,
            filter_size=3,
            act='relu',
            stride=stride,
            name=name + "_branch2a")
        conv1 = self._conv_norm(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            act=None,
            name=name + "_branch2b")
        short = self._shortcut(
            input, num_filters, stride, is_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')

    def layer_warp(self, input, stage_num):
        """
        Args:
            input (Variable): input variable.
            stage_num (int): the stage number, should be 2, 3, 4, 5

        Returns:
            The last variable of the requested stage.
        """
        assert stage_num in [2, 3, 4, 5]
        self.stage_num = stage_num

        stages, block_func = self.depth_cfg[self.depth]
        count = stages[stage_num - 2]
        ch_out = self.stage_filters[stage_num - 2]
        is_first = stage_num == 2
        dcn_v2 = stage_num in self.dcn_v2_stages
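        # sentinel period: 1000 effectively disables non-local insertion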
        nonlocal_mod = 1000
        if stage_num in self.nonlocal_stages:
            nonlocal_mod = self.nonlocal_mod_cfg[
                self.depth] if stage_num == 4 else 2

        # Make the layer name and parameter name consistent
        # with ImageNet pre-trained model
        conv = input
        for i in range(count):
            conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
            if self.depth < 50:
                is_first = (i == 0 and stage_num == 2)
            gcb = stage_num in self.gcb_stages
            gcb_name = "gcb_res{}_b{}".format(stage_num, i)
            conv = block_func(
                input=conv,
                num_filters=ch_out,
                stride=2 if i == 0 and stage_num != 2 else 1,
                is_first=is_first,
                name=conv_name,
                dcn_v2=dcn_v2,
                gcb=gcb,
                gcb_name=gcb_name)

            # add a non-local block after every nonlocal_mod-th residual block
            dim_in = conv.shape[1]
            nonlocal_name = "nonlocal_conv{}".format(stage_num)
            if i % nonlocal_mod == nonlocal_mod - 1:
                conv = add_space_nonlocal(conv, dim_in, dim_in,
                                          nonlocal_name + '_{}'.format(i),
                                          int(dim_in / 2))
        return conv

    def c1_stage(self, input):
        out_chan = self._c1_out_chan_num
        conv1_name = self.na.fix_c1_stage_name()
        if self.variant in ['c', 'd']:
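            # variants 'c' and 'd' replace the 7x7 stem conv with three
            # 3x3 convs (the ResNet-C trick)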
            conv_def = [
                [out_chan // 2, 3, 2, "conv1_1"],
                [out_chan // 2, 3, 1, "conv1_2"],
                [out_chan, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[out_chan, 7, 2, conv1_name]]

        for (c, k, s, _name) in conv_def:
            input = self._conv_norm(
                input=input,
                num_filters=c,
                filter_size=k,
                stride=s,
                act='relu',
                name=_name)

        output = fluid.layers.pool2d(
            input=input,
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')
        return output

    def __call__(self, input):
        assert isinstance(input, Variable)
        assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
            "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)

        res_endpoints = []

        res = input
        feature_maps = self.feature_maps
        severed_head = getattr(self, 'severed_head', False)
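        # ResNetC5 sets severed_head=True so the C1 stem is skipped, e.g.
        # when the input is already a feature map rather than an image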
        if not severed_head:
            res = self.c1_stage(res)
            feature_maps = range(2, max(self.feature_maps) + 1)
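        # build every stage up to the deepest requested one; only stages
        # listed in self.feature_maps are recorded as endpoints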
        for i in feature_maps:
            res = self.layer_warp(res, i)
            if i in self.feature_maps:
                res_endpoints.append(res)
            if self.freeze_at >= i:
                res.stop_gradient = True

        return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
                            for idx, feat in enumerate(res_endpoints)])


@register
@serializable
class ResNetC5(ResNet):
    __doc__ = ResNet.__doc__

    def __init__(self,
                 depth=50,
                 freeze_at=2,
                 norm_type='affine_channel',
                 freeze_norm=True,
                 norm_decay=0.,
                 variant='b',
                 feature_maps=[5],
                 weight_prefix_name=''):
        super(ResNetC5, self).__init__(depth, freeze_at, norm_type, freeze_norm,
                                       norm_decay, variant, feature_maps)
        self.severed_head = True
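

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumes a Paddle 1.x
# static-graph environment; the input name and shape below are illustrative
# only. The backbone returns an OrderedDict mapping 'res{k}_sum' to the
# stage-k feature map Variable.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    image = fluid.layers.data(
        name='image', shape=[3, 224, 224], dtype='float32')
    backbone = ResNet(
        depth=50, norm_type='bn', freeze_at=2, feature_maps=[2, 3, 4, 5])
    body_feats = backbone(image)  # OrderedDict: res2_sum .. res5_sum
    for feat_name, feat in body_feats.items():
        print(feat_name, feat.shape)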