# res2net.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numbers import Integral

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable

from ..shape_spec import ShapeSpec
from .resnet import ConvNormLayer

__all__ = ['Res2Net', 'Res2NetC5']

Res2Net_cfg = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 12, 48, 3]
}
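
# NOTE: Res2Net_cfg maps the network depth to the number of BottleNeck blocks
# built in stages res2-res5, e.g. depth 50 -> [3, 4, 6, 3].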


class BottleNeck(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 width,
                 scales=4,
                 variant='b',
                 groups=1,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(BottleNeck, self).__init__()

        self.shortcut = shortcut
        self.scales = scales
        self.stride = stride
        if not shortcut:
            if variant == 'd' and stride == 2:
                self.branch1 = nn.Sequential()
                self.branch1.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.branch1.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        lr=lr))
            else:
                self.branch1 = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    filter_size=1,
                    stride=stride,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=lr)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=width * scales,
            filter_size=1,
            stride=stride if variant == 'a' else 1,
            groups=1,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.branch2b = nn.LayerList([
            ConvNormLayer(
                ch_in=width,
                ch_out=width,
                filter_size=3,
                stride=1 if variant == 'a' else stride,
                groups=groups,
                act='relu',
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                lr=lr,
                dcn_v2=dcn_v2) for _ in range(self.scales - 1)
        ])

        self.branch2c = ConvNormLayer(
            ch_in=width * scales,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            groups=1,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)
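
    # branch2a expands the input to width * scales channels, which forward()
    # splits into `scales` groups: the first (scales - 1) groups each go
    # through a 3x3 conv in branch2b, and when stride == 1 each group after
    # the first also receives the previous group's output before its conv,
    # forming the hierarchical residual-like connections of Res2Net. The last
    # group is passed through unchanged (stride 1) or average-pooled
    # (stride 2); all groups are then concatenated and fused by the 1x1 conv
    # branch2c. Example: width=26, scales=4 -> branch2a outputs 104 channels,
    # split into four 26-channel groups.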
    def forward(self, inputs):
        out = self.branch2a(inputs)
        feature_split = paddle.split(out, self.scales, 1)
        out_split = []
        for i in range(self.scales - 1):
            if i == 0 or self.stride == 2:
                out_split.append(self.branch2b[i](feature_split[i]))
            else:
                out_split.append(self.branch2b[i](paddle.add(feature_split[i],
                                                             out_split[-1])))
        if self.stride == 1:
            out_split.append(feature_split[-1])
        else:
            out_split.append(F.avg_pool2d(feature_split[-1], 3, self.stride, 1))
        out = self.branch2c(paddle.concat(out_split, 1))

        if self.shortcut:
            short = inputs
        else:
            short = self.branch1(inputs)

        out = paddle.add(out, short)
        out = F.relu(out)

        return out


class Blocks(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 count,
                 stage_num,
                 width,
                 scales=4,
                 variant='b',
                 groups=1,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(Blocks, self).__init__()

        self.blocks = nn.Sequential()
        for i in range(count):
            self.blocks.add_sublayer(
                str(i),
                BottleNeck(
                    ch_in=ch_in if i == 0 else ch_out,
                    ch_out=ch_out,
                    stride=2 if i == 0 and stage_num != 2 else 1,
                    shortcut=False if i == 0 else True,
                    width=width * (2**(stage_num - 2)),
                    scales=scales,
                    variant=variant,
                    groups=groups,
                    lr=lr,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=dcn_v2))

    def forward(self, inputs):
        return self.blocks(inputs)
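
# NOTE: Blocks doubles the per-group width at each deeper stage via
# width * (2 ** (stage_num - 2)), so the default width=26 gives 26, 52, 104
# and 208 channels per split for stages res2-res5; the first block of every
# stage except res2 downsamples with stride 2.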


@register
@serializable
class Res2Net(nn.Layer):
    """
    Res2Net, see https://arxiv.org/abs/1904.01169

    Args:
        depth (int): Res2Net depth, should be 50, 101, 152, 200.
        width (int): Res2Net width
        scales (int): Res2Net scale
        variant (str): Res2Net variant, supports 'a', 'b', 'c', 'd' currently
        lr_mult_list (list): learning rate ratio of the res2net stages (2, 3, 4, 5);
            a lower learning rate ratio is needed for pretrained models obtained
            by distillation (default: [1.0, 1.0, 1.0, 1.0]).
        groups (int): the group number of the Conv Layer.
        norm_type (str): normalization type, 'bn' or 'sync_bn'
        norm_decay (float): weight decay for normalization layer weights
        freeze_norm (bool): freeze normalization layers
        freeze_at (int): freeze the backbone at which stage
        return_idx (list): index of stages whose feature maps are returned;
            index 0 stands for res2
        dcn_v2_stages (list): index of stages that use deformable conv v2
        num_stages (int): number of stages created
    """
    __shared__ = ['norm_type']

    def __init__(self,
                 depth=50,
                 width=26,
                 scales=4,
                 variant='b',
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                 groups=1,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 dcn_v2_stages=[-1],
                 num_stages=4):
        super(Res2Net, self).__init__()

        self._model_type = 'Res2Net' if groups == 1 else 'Res2NeXt'

        assert depth in [50, 101, 152, 200], \
            "depth {} not in [50, 101, 152, 200]".format(depth)
        assert variant in ['a', 'b', 'c', 'd'], "invalid Res2Net variant"
        assert num_stages >= 1 and num_stages <= 4

        self.depth = depth
        self.variant = variant
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.freeze_at = freeze_at
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert max(return_idx) < num_stages, \
            'the maximum return index must be smaller than num_stages, ' \
            'but received maximum return index {} and num_stages ' \
            '{}'.format(max(return_idx), num_stages)
        self.return_idx = return_idx
        self.num_stages = num_stages
        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))
        if isinstance(dcn_v2_stages, Integral):
            dcn_v2_stages = [dcn_v2_stages]
        assert max(dcn_v2_stages) < num_stages
        self.dcn_v2_stages = dcn_v2_stages

        block_nums = Res2Net_cfg[depth]

        # C1 stage
        if self.variant in ['c', 'd']:
            conv_def = [
                [3, 32, 3, 2, "conv1_1"],
                [32, 32, 3, 1, "conv1_2"],
                [32, 64, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, 64, 7, 2, "conv1"]]
        self.res1 = nn.Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.res1.add_sublayer(
                _name,
                ConvNormLayer(
                    ch_in=c_in,
                    ch_out=c_out,
                    filter_size=k,
                    stride=s,
                    groups=1,
                    act='relu',
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=1.0))

        self._in_channels = [64, 256, 512, 1024]
        self._out_channels = [256, 512, 1024, 2048]
        self._out_strides = [4, 8, 16, 32]

        # C2-C5 stages
        self.res_layers = []
        for i in range(num_stages):
            lr_mult = lr_mult_list[i]
            stage_num = i + 2
            self.res_layers.append(
                self.add_sublayer(
                    "res{}".format(stage_num),
                    Blocks(
                        self._in_channels[i],
                        self._out_channels[i],
                        count=block_nums[i],
                        stage_num=stage_num,
                        width=width,
                        scales=scales,
                        groups=groups,
                        lr=lr_mult,
                        norm_type=norm_type,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        dcn_v2=(i in self.dcn_v2_stages))))

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]

    def forward(self, inputs):
        x = inputs['image']
        res1 = self.res1(x)
        x = F.max_pool2d(res1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx == self.freeze_at:
                x.stop_gradient = True
            if idx in self.return_idx:
                outs.append(x)
        return outs
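
# NOTE: Res2Net.forward expects a dict containing an 'image' tensor (NCHW) and
# returns the feature maps of the stages listed in return_idx; their channel
# counts and strides are described by the out_shape property.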


@register
class Res2NetC5(nn.Layer):
    def __init__(self, depth=50, width=26, scales=4, variant='b'):
        super(Res2NetC5, self).__init__()
        feat_in, feat_out = [1024, 2048]
        self.res5 = Blocks(
            feat_in,
            feat_out,
            count=3,
            stage_num=5,
            width=width,
            scales=scales,
            variant=variant)
        self.feat_out = feat_out

    @property
    def out_shape(self):
        return [ShapeSpec(
            channels=self.feat_out,
            stride=32, )]

    def forward(self, roi_feat, stage=0):
        y = self.res5(roi_feat)
        return y
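

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the upstream file): it assumes
    # PaddlePaddle is installed and that this module is run from within the
    # ppdet package (the relative imports above require it, e.g. via
    # `python -m`). It builds a Res2Net-50 backbone and prints the shapes of
    # the returned feature maps for a dummy 640x640 image; the expected
    # output strides are 4, 8, 16 and 32.
    model = Res2Net(depth=50, return_idx=[0, 1, 2, 3])
    dummy = {'image': paddle.rand([1, 3, 640, 640])}
    feats = model(dummy)
    print([f.shape for f in feats])
    print([(s.channels, s.stride) for s in model.out_shape])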