# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict
from numbers import Integral

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant

from ppdet.core.workspace import register, serializable

from .name_adapter import NameAdapter
from .nonlocal_helper import add_space_nonlocal

__all__ = ['CBResNet']


@register
@serializable
class CBResNet(object):
    """
    CBNet: Composite Backbone Network, see https://arxiv.org/abs/1909.03625

    Args:
        depth (int): ResNet depth, should be 18, 34, 50, 101, 152 or 200.
        freeze_at (int): freeze the backbone at which stage
        norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
        freeze_norm (bool): freeze normalization layers
        norm_decay (float): weight decay for normalization layer weights
        variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
        feature_maps (list): index of stages whose feature maps are returned
        dcn_v2_stages (list): index of stages that use deformable conv v2
        nonlocal_stages (list): index of stages that use non-local blocks
        repeat_num (int): number of times the backbone is repeated
        lr_mult_list (list): learning rate multipliers for the four
            residual stages

    Attention:
        1. ResNet is used as the base backbone here.
        2. All pretrained parameters are copied from the corresponding
           ResNet weights, but each repeated backbone stores them under
           level-suffixed names to avoid name collisions.
    """
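
    # Illustrative ppdet YAML snippet (an assumption, not from this file);
    # the field names mirror the __init__ arguments below:
    #
    #   CBResNet:
    #     depth: 50
    #     freeze_at: 2
    #     norm_type: bn
    #     variant: b
    #     feature_maps: [2, 3, 4, 5]
    #     repeat_num: 2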

    def __init__(self,
                 depth=50,
                 freeze_at=2,
                 norm_type='bn',
                 freeze_norm=True,
                 norm_decay=0.,
                 variant='b',
                 feature_maps=[2, 3, 4, 5],
                 dcn_v2_stages=[],
                 nonlocal_stages=[],
                 repeat_num=2,
                 lr_mult_list=[1., 1., 1., 1.]):
        super(CBResNet, self).__init__()

        if isinstance(feature_maps, Integral):
            feature_maps = [feature_maps]

        assert depth in [18, 34, 50, 101, 152, 200], \
            "depth {} not in [18, 34, 50, 101, 152, 200]".format(depth)
        assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
        assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
        assert len(feature_maps) > 0, "need one or more feature maps"
        assert norm_type in ['bn', 'sync_bn', 'affine_channel']
        assert not (len(nonlocal_stages) > 0 and depth < 50), \
            "non-local is not supported for resnet18 or resnet34"
        # _conv_norm indexes lr_mult_list with stage indices 0..3
        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4, one multiplier per residual stage"

        self.depth = depth
        self.dcn_v2_stages = dcn_v2_stages
        self.freeze_at = freeze_at
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.variant = variant
        self._model_type = 'ResNet'
        self.feature_maps = feature_maps
        self.repeat_num = repeat_num
        self.curr_level = 0
        self.depth_cfg = {
            18: ([2, 2, 2, 2], self.basicblock),
            34: ([3, 4, 6, 3], self.basicblock),
            50: ([3, 4, 6, 3], self.bottleneck),
            101: ([3, 4, 23, 3], self.bottleneck),
            152: ([3, 8, 36, 3], self.bottleneck),
            200: ([3, 12, 48, 3], self.bottleneck),
        }
        self.nonlocal_stages = nonlocal_stages
        self.nonlocal_mod_cfg = {
            50: 2,
            101: 5,
            152: 8,
            200: 12,
        }
        self.lr_mult_list = lr_mult_list
        self.stage_num = -1

        self.stage_filters = [64, 128, 256, 512]
        self._c1_out_chan_num = 64
        self.na = NameAdapter(self)

    def _conv_offset(self,
                     input,
                     filter_size,
                     stride,
                     padding,
                     act=None,
                     name=None):
        # DCNv2 predicts 3 * k * k channels per position:
        # 2 * k * k for the (x, y) offsets plus k * k for the modulation mask
        out_channel = filter_size * filter_size * 3
        out = fluid.layers.conv2d(
            input,
            num_filters=out_channel,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            param_attr=ParamAttr(
                initializer=Constant(0.0), name=name + ".w_0"),
            bias_attr=ParamAttr(
                initializer=Constant(0.0), name=name + ".b_0"),
            act=act,
            name=name)
        return out

    def _conv_norm(self,
                   input,
                   num_filters,
                   filter_size,
                   stride=1,
                   groups=1,
                   act=None,
                   name=None,
                   dcn=False):
        # need a finer lr for distilled models, default multiplier is 1.0;
        # clamp stage_num (2..5, or -1 for the C1 stem) to index range 0..3
        mult_idx = max(self.stage_num - 2, 0)
        mult_idx = min(mult_idx, 3)
        lr_mult = self.lr_mult_list[mult_idx]

        if not dcn:
            conv = fluid.layers.conv2d(
                input=input,
                num_filters=num_filters,
                filter_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                act=None,
                param_attr=ParamAttr(
                    name=name + "_weights_" + str(self.curr_level),
                    learning_rate=lr_mult),
                bias_attr=False)
        else:
            offset_mask = self._conv_offset(
                input=input,
                filter_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                act=None,
                name=name + "_conv_offset_" + str(self.curr_level))
            # split the 3 * k * k prediction into offsets and modulation mask
            offset_channel = filter_size**2 * 2
            mask_channel = filter_size**2
            offset, mask = fluid.layers.split(
                input=offset_mask,
                num_or_sections=[offset_channel, mask_channel],
                dim=1)
            mask = fluid.layers.sigmoid(mask)
            conv = fluid.layers.deformable_conv(
                input=input,
                offset=offset,
                mask=mask,
                num_filters=num_filters,
                filter_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                deformable_groups=1,
                im2col_step=1,
                param_attr=ParamAttr(
                    name=name + "_weights_" + str(self.curr_level),
                    learning_rate=lr_mult),
                bias_attr=False)

        bn_name = self.na.fix_conv_norm_name(name)

        norm_lr = 0. if self.freeze_norm else lr_mult
        norm_decay = self.norm_decay
        pattr = ParamAttr(
            name=bn_name + '_scale_' + str(self.curr_level),
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay))
        battr = ParamAttr(
            name=bn_name + '_offset_' + str(self.curr_level),
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay))

        if self.norm_type in ['bn', 'sync_bn']:
            global_stats = True if self.freeze_norm else False
            out = fluid.layers.batch_norm(
                input=conv,
                act=act,
                name=bn_name + '.output.1_' + str(self.curr_level),
                param_attr=pattr,
                bias_attr=battr,
                moving_mean_name=bn_name + '_mean_' + str(self.curr_level),
                moving_variance_name=bn_name + '_variance_' +
                str(self.curr_level),
                use_global_stats=global_stats)
            scale = fluid.framework._get_var(pattr.name)
            bias = fluid.framework._get_var(battr.name)
        elif self.norm_type == 'affine_channel':
            raise NotImplementedError(
                "affine_channel norm is deprecated for CBResNet")
        if self.freeze_norm:
            scale.stop_gradient = True
            bias.stop_gradient = True
        return out
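
    # Naming example (illustrative): for the second backbone (curr_level == 1)
    # a conv layer named "res2a_branch2a" stores its weights as
    # "res2a_branch2a_weights_1" and its batch-norm scale/offset as
    # "bn2a_branch2a_scale_1"/"bn2a_branch2a_offset_1"; the level suffix keeps
    # the repeated backbones from colliding while still mapping onto the same
    # pretrained ResNet parameters.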

    def _shortcut(self, input, ch_out, stride, is_first, name):
        max_pooling_in_short_cut = self.variant == 'd'
        ch_in = input.shape[1]
        # keep the naming rule consistent with the pretrained weights
        name = self.na.fix_shortcut_name(name)
        if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
            if max_pooling_in_short_cut and not is_first:
                # ResNet-D: despite the variable name, average-pool first so
                # the 1x1 conv does not discard 3/4 of the activations
                input = fluid.layers.pool2d(
                    input=input,
                    pool_size=2,
                    pool_stride=2,
                    pool_padding=0,
                    ceil_mode=True,
                    pool_type='avg')
                return self._conv_norm(input, ch_out, 1, 1, name=name)
            return self._conv_norm(input, ch_out, 1, stride, name=name)
        else:
            return input

    def bottleneck(self, input, num_filters, stride, is_first, name,
                   dcn=False):
        if self.variant == 'a':
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride

        # ResNeXt
        groups = getattr(self, 'groups', 1)
        group_width = getattr(self, 'group_width', -1)
        if groups == 1:
            expand = 4
        elif (groups * group_width) == 256:
            expand = 1
        else:  # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
            num_filters = num_filters // 2
            expand = 2

        conv_name1, conv_name2, conv_name3, \
            shortcut_name = self.na.fix_bottleneck_name(name)
        conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
                    [num_filters, 3, stride2, 'relu', groups, conv_name2],
                    [num_filters * expand, 1, 1, None, 1, conv_name3]]

        residual = input
        for i, (c, k, s, act, g, _name) in enumerate(conv_def):
            residual = self._conv_norm(
                input=residual,
                num_filters=c,
                filter_size=k,
                stride=s,
                act=act,
                groups=g,
                name=_name,
                dcn=(i == 1 and dcn))
        short = self._shortcut(
            input,
            num_filters * expand,
            stride,
            is_first=is_first,
            name=shortcut_name)
        # Squeeze-and-Excitation
        if callable(getattr(self, '_squeeze_excitation', None)):
            residual = self._squeeze_excitation(
                input=residual, num_channels=num_filters, name='fc' + name)
        return fluid.layers.elementwise_add(x=short, y=residual, act='relu')

    def basicblock(self, input, num_filters, stride, is_first, name,
                   dcn=False):
        assert dcn is False, "Not implemented yet."
        conv0 = self._conv_norm(
            input=input,
            num_filters=num_filters,
            filter_size=3,
            act='relu',
            stride=stride,
            name=name + "_branch2a")
        conv1 = self._conv_norm(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            act=None,
            name=name + "_branch2b")
        short = self._shortcut(
            input, num_filters, stride, is_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')

    def layer_warp(self, input, stage_num):
        """
        Args:
            input (Variable): input variable.
            stage_num (int): the stage number, should be 2, 3, 4, 5

        Returns:
            The last variable of the stage_num-th stage.
        """
        assert stage_num in [2, 3, 4, 5]
        self.stage_num = stage_num

        stages, block_func = self.depth_cfg[self.depth]
        count = stages[stage_num - 2]
        ch_out = self.stage_filters[stage_num - 2]
        is_first = stage_num == 2
        dcn = stage_num in self.dcn_v2_stages

        # insert a non-local block every `nonlocal_mod` residual blocks;
        # 1000 effectively disables non-local for this stage
        nonlocal_mod = 1000
        if stage_num in self.nonlocal_stages:
            nonlocal_mod = self.nonlocal_mod_cfg[
                self.depth] if stage_num == 4 else 2

        # Make the layer name and parameter name consistent
        # with the ImageNet pre-trained model
        conv = input
        for i in range(count):
            conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
            if self.depth < 50:
                is_first = True if i == 0 and stage_num == 2 else False
            conv = block_func(
                input=conv,
                num_filters=ch_out,
                stride=2 if i == 0 and stage_num != 2 else 1,
                is_first=is_first,
                name=conv_name,
                dcn=dcn)

            # add non-local block
            dim_in = conv.shape[1]
            nonlocal_name = "nonlocal_conv{}_lvl{}".format(stage_num,
                                                           self.curr_level)
            if i % nonlocal_mod == nonlocal_mod - 1:
                conv = add_space_nonlocal(conv, dim_in, dim_in,
                                          nonlocal_name + '_{}'.format(i),
                                          int(dim_in / 2))
        return conv

    def c1_stage(self, input):
        out_chan = self._c1_out_chan_num

        conv1_name = self.na.fix_c1_stage_name()

        if self.variant in ['c', 'd']:
            # ResNet-C/D: replace the 7x7 stem conv with three 3x3 convs
            conv1_1_name = "conv1_1"
            conv1_2_name = "conv1_2"
            conv1_3_name = "conv1_3"
            conv_def = [
                [out_chan // 2, 3, 2, conv1_1_name],
                [out_chan // 2, 3, 1, conv1_2_name],
                [out_chan, 3, 1, conv1_3_name],
            ]
        else:
            conv_def = [[out_chan, 7, 2, conv1_name]]

        for (c, k, s, _name) in conv_def:
            input = self._conv_norm(
                input=input,
                num_filters=c,
                filter_size=k,
                stride=s,
                act='relu',
                name=_name)

        output = fluid.layers.pool2d(
            input=input,
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')
        return output

    def connect(self, left, right, name):
        # CBNet composite connection: project the assisting backbone's
        # feature (`left`) to the lead backbone's channel count, upsample
        # it to the lead feature's (`right`) spatial size, then fuse by
        # elementwise addition.
        ch_right = right.shape[1]
        conv = self._conv_norm(
            left,
            num_filters=ch_right,
            filter_size=1,
            stride=1,
            act="relu",
            name=name + "_connect")

        shape = fluid.layers.shape(right)
        shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4])
        out_shape = fluid.layers.cast(shape_hw, dtype='int32')
        out_shape.stop_gradient = True
        # out_shape takes precedence over scale, so `conv` is resized to
        # exactly match `right`'s height and width
        conv = fluid.layers.resize_nearest(conv, scale=2., out_shape=out_shape)

        output = fluid.layers.elementwise_add(x=right, y=conv)
        return output
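
    # Shape walk-through of one composite connection (illustrative numbers):
    # with an 800x800 input and depth=50, the previous backbone's res3 output
    # `left` is [N, 512, 100, 100] while the current backbone's stage-2
    # output `right` is [N, 256, 200, 200]; the 1x1 conv projects `left` to
    # 256 channels and resize_nearest upsamples it to 200x200 before the add.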

    def __call__(self, input):
        assert isinstance(input, Variable)
        assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
            "feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)

        # first (assisting) backbone
        res_endpoints = []
        self.curr_level = 0
        res = self.c1_stage(input)
        feature_maps = range(2, max(self.feature_maps) + 1)
        for i in feature_maps:
            res = self.layer_warp(res, i)
            if i in self.feature_maps:
                res_endpoints.append(res)

        # remaining backbones: each stage input is fused with the previous
        # backbone's stage output through a composite connection
        for num in range(1, self.repeat_num):
            self.stage_num = -1
            self.curr_level = num
            res = self.c1_stage(input)
            for i in range(len(res_endpoints)):
                res = self.connect(res_endpoints[i], res,
                                   "test_c" + str(i + 1))
                res = self.layer_warp(res, i + 2)
                res_endpoints[i] = res
                if self.freeze_at >= i + 2:
                    res.stop_gradient = True

        return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
                            for idx, feat in enumerate(res_endpoints)])
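

# Minimal usage sketch (illustrative, not part of the original file): assumes
# a Paddle 1.x static-graph environment with ppdet on the path, so that
# NameAdapter and add_space_nonlocal resolve.
if __name__ == '__main__':
    image = fluid.data(
        name='image', shape=[None, 3, 800, 800], dtype='float32')
    backbone = CBResNet(depth=50, repeat_num=2, feature_maps=[2, 3, 4, 5])
    body_feats = backbone(image)  # OrderedDict: res2_sum ... res5_sum
    for feat_name, feat in body_feats.items():
        print(feat_name, feat.shape)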