resnet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from numbers import Integral
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. from ppdet.core.workspace import register, serializable
  20. from paddle.regularizer import L2Decay
  21. from paddle.nn.initializer import Uniform
  22. from paddle import ParamAttr
  23. from paddle.nn.initializer import Constant
  24. from paddle.vision.ops import DeformConv2D
  25. from .name_adapter import NameAdapter
  26. from ..shape_spec import ShapeSpec
  27. __all__ = ['ResNet', 'Res5Head', 'Blocks', 'BasicBlock', 'BottleNeck']
  28. ResNet_cfg = {
  29. 18: [2, 2, 2, 2],
  30. 34: [3, 4, 6, 3],
  31. 50: [3, 4, 6, 3],
  32. 101: [3, 4, 23, 3],
  33. 152: [3, 8, 36, 3],
  34. }
  35. class ConvNormLayer(nn.Layer):
  36. def __init__(self,
  37. ch_in,
  38. ch_out,
  39. filter_size,
  40. stride,
  41. groups=1,
  42. act=None,
  43. norm_type='bn',
  44. norm_decay=0.,
  45. freeze_norm=True,
  46. lr=1.0,
  47. dcn_v2=False):
  48. super(ConvNormLayer, self).__init__()
  49. assert norm_type in ['bn', 'sync_bn']
  50. self.norm_type = norm_type
  51. self.act = act
  52. self.dcn_v2 = dcn_v2
  53. if not self.dcn_v2:
  54. self.conv = nn.Conv2D(
  55. in_channels=ch_in,
  56. out_channels=ch_out,
  57. kernel_size=filter_size,
  58. stride=stride,
  59. padding=(filter_size - 1) // 2,
  60. groups=groups,
  61. weight_attr=ParamAttr(learning_rate=lr),
  62. bias_attr=False)
  63. else:
  64. self.offset_channel = 2 * filter_size**2
  65. self.mask_channel = filter_size**2
  66. self.conv_offset = nn.Conv2D(
  67. in_channels=ch_in,
  68. out_channels=3 * filter_size**2,
  69. kernel_size=filter_size,
  70. stride=stride,
  71. padding=(filter_size - 1) // 2,
  72. weight_attr=ParamAttr(initializer=Constant(0.)),
  73. bias_attr=ParamAttr(initializer=Constant(0.)))
  74. self.conv = DeformConv2D(
  75. in_channels=ch_in,
  76. out_channels=ch_out,
  77. kernel_size=filter_size,
  78. stride=stride,
  79. padding=(filter_size - 1) // 2,
  80. dilation=1,
  81. groups=groups,
  82. weight_attr=ParamAttr(learning_rate=lr),
  83. bias_attr=False)
  84. norm_lr = 0. if freeze_norm else lr
  85. param_attr = ParamAttr(
  86. learning_rate=norm_lr,
  87. regularizer=L2Decay(norm_decay),
  88. trainable=False if freeze_norm else True)
  89. bias_attr = ParamAttr(
  90. learning_rate=norm_lr,
  91. regularizer=L2Decay(norm_decay),
  92. trainable=False if freeze_norm else True)
  93. global_stats = True if freeze_norm else None
  94. if norm_type in ['sync_bn', 'bn']:
  95. self.norm = nn.BatchNorm2D(
  96. ch_out,
  97. weight_attr=param_attr,
  98. bias_attr=bias_attr,
  99. use_global_stats=global_stats)
  100. norm_params = self.norm.parameters()
  101. if freeze_norm:
  102. for param in norm_params:
  103. param.stop_gradient = True
  104. def forward(self, inputs):
  105. if not self.dcn_v2:
  106. out = self.conv(inputs)
  107. else:
  108. offset_mask = self.conv_offset(inputs)
  109. offset, mask = paddle.split(
  110. offset_mask,
  111. num_or_sections=[self.offset_channel, self.mask_channel],
  112. axis=1)
  113. mask = F.sigmoid(mask)
  114. out = self.conv(inputs, offset, mask=mask)
  115. if self.norm_type in ['bn', 'sync_bn']:
  116. out = self.norm(out)
  117. if self.act:
  118. out = getattr(F, self.act)(out)
  119. return out
  120. class SELayer(nn.Layer):
  121. def __init__(self, ch, reduction_ratio=16):
  122. super(SELayer, self).__init__()
  123. self.pool = nn.AdaptiveAvgPool2D(1)
  124. stdv = 1.0 / math.sqrt(ch)
  125. c_ = ch // reduction_ratio
  126. self.squeeze = nn.Linear(
  127. ch,
  128. c_,
  129. weight_attr=paddle.ParamAttr(initializer=Uniform(-stdv, stdv)),
  130. bias_attr=True)
  131. stdv = 1.0 / math.sqrt(c_)
  132. self.extract = nn.Linear(
  133. c_,
  134. ch,
  135. weight_attr=paddle.ParamAttr(initializer=Uniform(-stdv, stdv)),
  136. bias_attr=True)
  137. def forward(self, inputs):
  138. out = self.pool(inputs)
  139. out = paddle.squeeze(out, axis=[2, 3])
  140. out = self.squeeze(out)
  141. out = F.relu(out)
  142. out = self.extract(out)
  143. out = F.sigmoid(out)
  144. out = paddle.unsqueeze(out, axis=[2, 3])
  145. scale = out * inputs
  146. return scale
  147. class BasicBlock(nn.Layer):
  148. expansion = 1
  149. def __init__(self,
  150. ch_in,
  151. ch_out,
  152. stride,
  153. shortcut,
  154. variant='b',
  155. groups=1,
  156. base_width=64,
  157. lr=1.0,
  158. norm_type='bn',
  159. norm_decay=0.,
  160. freeze_norm=True,
  161. dcn_v2=False,
  162. std_senet=False):
  163. super(BasicBlock, self).__init__()
  164. assert groups == 1 and base_width == 64, 'BasicBlock only supports groups=1 and base_width=64'
  165. self.shortcut = shortcut
  166. if not shortcut:
  167. if variant == 'd' and stride == 2:
  168. self.short = nn.Sequential()
  169. self.short.add_sublayer(
  170. 'pool',
  171. nn.AvgPool2D(
  172. kernel_size=2, stride=2, padding=0, ceil_mode=True))
  173. self.short.add_sublayer(
  174. 'conv',
  175. ConvNormLayer(
  176. ch_in=ch_in,
  177. ch_out=ch_out,
  178. filter_size=1,
  179. stride=1,
  180. norm_type=norm_type,
  181. norm_decay=norm_decay,
  182. freeze_norm=freeze_norm,
  183. lr=lr))
  184. else:
  185. self.short = ConvNormLayer(
  186. ch_in=ch_in,
  187. ch_out=ch_out,
  188. filter_size=1,
  189. stride=stride,
  190. norm_type=norm_type,
  191. norm_decay=norm_decay,
  192. freeze_norm=freeze_norm,
  193. lr=lr)
  194. self.branch2a = ConvNormLayer(
  195. ch_in=ch_in,
  196. ch_out=ch_out,
  197. filter_size=3,
  198. stride=stride,
  199. act='relu',
  200. norm_type=norm_type,
  201. norm_decay=norm_decay,
  202. freeze_norm=freeze_norm,
  203. lr=lr)
  204. self.branch2b = ConvNormLayer(
  205. ch_in=ch_out,
  206. ch_out=ch_out,
  207. filter_size=3,
  208. stride=1,
  209. act=None,
  210. norm_type=norm_type,
  211. norm_decay=norm_decay,
  212. freeze_norm=freeze_norm,
  213. lr=lr,
  214. dcn_v2=dcn_v2)
  215. self.std_senet = std_senet
  216. if self.std_senet:
  217. self.se = SELayer(ch_out)
  218. def forward(self, inputs):
  219. out = self.branch2a(inputs)
  220. out = self.branch2b(out)
  221. if self.std_senet:
  222. out = self.se(out)
  223. if self.shortcut:
  224. short = inputs
  225. else:
  226. short = self.short(inputs)
  227. out = paddle.add(x=out, y=short)
  228. out = F.relu(out)
  229. return out
  230. class BottleNeck(nn.Layer):
  231. expansion = 4
  232. def __init__(self,
  233. ch_in,
  234. ch_out,
  235. stride,
  236. shortcut,
  237. variant='b',
  238. groups=1,
  239. base_width=4,
  240. lr=1.0,
  241. norm_type='bn',
  242. norm_decay=0.,
  243. freeze_norm=True,
  244. dcn_v2=False,
  245. std_senet=False):
  246. super(BottleNeck, self).__init__()
  247. if variant == 'a':
  248. stride1, stride2 = stride, 1
  249. else:
  250. stride1, stride2 = 1, stride
  251. # ResNeXt
  252. width = int(ch_out * (base_width / 64.)) * groups
  253. self.shortcut = shortcut
  254. if not shortcut:
  255. if variant == 'd' and stride == 2:
  256. self.short = nn.Sequential()
  257. self.short.add_sublayer(
  258. 'pool',
  259. nn.AvgPool2D(
  260. kernel_size=2, stride=2, padding=0, ceil_mode=True))
  261. self.short.add_sublayer(
  262. 'conv',
  263. ConvNormLayer(
  264. ch_in=ch_in,
  265. ch_out=ch_out * self.expansion,
  266. filter_size=1,
  267. stride=1,
  268. norm_type=norm_type,
  269. norm_decay=norm_decay,
  270. freeze_norm=freeze_norm,
  271. lr=lr))
  272. else:
  273. self.short = ConvNormLayer(
  274. ch_in=ch_in,
  275. ch_out=ch_out * self.expansion,
  276. filter_size=1,
  277. stride=stride,
  278. norm_type=norm_type,
  279. norm_decay=norm_decay,
  280. freeze_norm=freeze_norm,
  281. lr=lr)
  282. self.branch2a = ConvNormLayer(
  283. ch_in=ch_in,
  284. ch_out=width,
  285. filter_size=1,
  286. stride=stride1,
  287. groups=1,
  288. act='relu',
  289. norm_type=norm_type,
  290. norm_decay=norm_decay,
  291. freeze_norm=freeze_norm,
  292. lr=lr)
  293. self.branch2b = ConvNormLayer(
  294. ch_in=width,
  295. ch_out=width,
  296. filter_size=3,
  297. stride=stride2,
  298. groups=groups,
  299. act='relu',
  300. norm_type=norm_type,
  301. norm_decay=norm_decay,
  302. freeze_norm=freeze_norm,
  303. lr=lr,
  304. dcn_v2=dcn_v2)
  305. self.branch2c = ConvNormLayer(
  306. ch_in=width,
  307. ch_out=ch_out * self.expansion,
  308. filter_size=1,
  309. stride=1,
  310. groups=1,
  311. norm_type=norm_type,
  312. norm_decay=norm_decay,
  313. freeze_norm=freeze_norm,
  314. lr=lr)
  315. self.std_senet = std_senet
  316. if self.std_senet:
  317. self.se = SELayer(ch_out * self.expansion)
  318. def forward(self, inputs):
  319. out = self.branch2a(inputs)
  320. out = self.branch2b(out)
  321. out = self.branch2c(out)
  322. if self.std_senet:
  323. out = self.se(out)
  324. if self.shortcut:
  325. short = inputs
  326. else:
  327. short = self.short(inputs)
  328. out = paddle.add(x=out, y=short)
  329. out = F.relu(out)
  330. return out
  331. class Blocks(nn.Layer):
  332. def __init__(self,
  333. block,
  334. ch_in,
  335. ch_out,
  336. count,
  337. name_adapter,
  338. stage_num,
  339. variant='b',
  340. groups=1,
  341. base_width=64,
  342. lr=1.0,
  343. norm_type='bn',
  344. norm_decay=0.,
  345. freeze_norm=True,
  346. dcn_v2=False,
  347. std_senet=False):
  348. super(Blocks, self).__init__()
  349. self.blocks = []
  350. for i in range(count):
  351. conv_name = name_adapter.fix_layer_warp_name(stage_num, count, i)
  352. layer = self.add_sublayer(
  353. conv_name,
  354. block(
  355. ch_in=ch_in,
  356. ch_out=ch_out,
  357. stride=2 if i == 0 and stage_num != 2 else 1,
  358. shortcut=False if i == 0 else True,
  359. variant=variant,
  360. groups=groups,
  361. base_width=base_width,
  362. lr=lr,
  363. norm_type=norm_type,
  364. norm_decay=norm_decay,
  365. freeze_norm=freeze_norm,
  366. dcn_v2=dcn_v2,
  367. std_senet=std_senet))
  368. self.blocks.append(layer)
  369. if i == 0:
  370. ch_in = ch_out * block.expansion
  371. def forward(self, inputs):
  372. block_out = inputs
  373. for block in self.blocks:
  374. block_out = block(block_out)
  375. return block_out
  376. @register
  377. @serializable
  378. class ResNet(nn.Layer):
  379. __shared__ = ['norm_type']
  380. def __init__(self,
  381. depth=50,
  382. ch_in=64,
  383. variant='b',
  384. lr_mult_list=[1.0, 1.0, 1.0, 1.0],
  385. groups=1,
  386. base_width=64,
  387. norm_type='bn',
  388. norm_decay=0,
  389. freeze_norm=True,
  390. freeze_at=0,
  391. return_idx=[0, 1, 2, 3],
  392. dcn_v2_stages=[-1],
  393. num_stages=4,
  394. std_senet=False):
  395. """
  396. Residual Network, see https://arxiv.org/abs/1512.03385
  397. Args:
  398. depth (int): ResNet depth, should be 18, 34, 50, 101, 152.
  399. ch_in (int): output channel of first stage, default 64
  400. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  401. lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
  402. lower learning rate ratio is need for pretrained model
  403. got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
  404. groups (int): group convolution cardinality
  405. base_width (int): base width of each group convolution
  406. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  407. norm_decay (float): weight decay for normalization layer weights
  408. freeze_norm (bool): freeze normalization layers
  409. freeze_at (int): freeze the backbone at which stage
  410. return_idx (list): index of the stages whose feature maps are returned
  411. dcn_v2_stages (list): index of stages who select deformable conv v2
  412. num_stages (int): total num of stages
  413. std_senet (bool): whether use senet, default True
  414. """
  415. super(ResNet, self).__init__()
  416. self._model_type = 'ResNet' if groups == 1 else 'ResNeXt'
  417. assert num_stages >= 1 and num_stages <= 4
  418. self.depth = depth
  419. self.variant = variant
  420. self.groups = groups
  421. self.base_width = base_width
  422. self.norm_type = norm_type
  423. self.norm_decay = norm_decay
  424. self.freeze_norm = freeze_norm
  425. self.freeze_at = freeze_at
  426. if isinstance(return_idx, Integral):
  427. return_idx = [return_idx]
  428. assert max(return_idx) < num_stages, \
  429. 'the maximum return index must smaller than num_stages, ' \
  430. 'but received maximum return index is {} and num_stages ' \
  431. 'is {}'.format(max(return_idx), num_stages)
  432. self.return_idx = return_idx
  433. self.num_stages = num_stages
  434. assert len(lr_mult_list) == 4, \
  435. "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))
  436. if isinstance(dcn_v2_stages, Integral):
  437. dcn_v2_stages = [dcn_v2_stages]
  438. assert max(dcn_v2_stages) < num_stages
  439. if isinstance(dcn_v2_stages, Integral):
  440. dcn_v2_stages = [dcn_v2_stages]
  441. assert max(dcn_v2_stages) < num_stages
  442. self.dcn_v2_stages = dcn_v2_stages
  443. block_nums = ResNet_cfg[depth]
  444. na = NameAdapter(self)
  445. conv1_name = na.fix_c1_stage_name()
  446. if variant in ['c', 'd']:
  447. conv_def = [
  448. [3, ch_in // 2, 3, 2, "conv1_1"],
  449. [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
  450. [ch_in // 2, ch_in, 3, 1, "conv1_3"],
  451. ]
  452. else:
  453. conv_def = [[3, ch_in, 7, 2, conv1_name]]
  454. self.conv1 = nn.Sequential()
  455. for (c_in, c_out, k, s, _name) in conv_def:
  456. self.conv1.add_sublayer(
  457. _name,
  458. ConvNormLayer(
  459. ch_in=c_in,
  460. ch_out=c_out,
  461. filter_size=k,
  462. stride=s,
  463. groups=1,
  464. act='relu',
  465. norm_type=norm_type,
  466. norm_decay=norm_decay,
  467. freeze_norm=freeze_norm,
  468. lr=1.0))
  469. self.ch_in = ch_in
  470. ch_out_list = [64, 128, 256, 512]
  471. block = BottleNeck if depth >= 50 else BasicBlock
  472. self._out_channels = [block.expansion * v for v in ch_out_list]
  473. self._out_strides = [4, 8, 16, 32]
  474. self.res_layers = []
  475. for i in range(num_stages):
  476. lr_mult = lr_mult_list[i]
  477. stage_num = i + 2
  478. res_name = "res{}".format(stage_num)
  479. res_layer = self.add_sublayer(
  480. res_name,
  481. Blocks(
  482. block,
  483. self.ch_in,
  484. ch_out_list[i],
  485. count=block_nums[i],
  486. name_adapter=na,
  487. stage_num=stage_num,
  488. variant=variant,
  489. groups=groups,
  490. base_width=base_width,
  491. lr=lr_mult,
  492. norm_type=norm_type,
  493. norm_decay=norm_decay,
  494. freeze_norm=freeze_norm,
  495. dcn_v2=(i in self.dcn_v2_stages),
  496. std_senet=std_senet))
  497. self.res_layers.append(res_layer)
  498. self.ch_in = self._out_channels[i]
  499. if freeze_at >= 0:
  500. self._freeze_parameters(self.conv1)
  501. for i in range(min(freeze_at + 1, num_stages)):
  502. self._freeze_parameters(self.res_layers[i])
  503. def _freeze_parameters(self, m):
  504. for p in m.parameters():
  505. p.stop_gradient = True
  506. @property
  507. def out_shape(self):
  508. return [
  509. ShapeSpec(
  510. channels=self._out_channels[i], stride=self._out_strides[i])
  511. for i in self.return_idx
  512. ]
  513. def forward(self, inputs):
  514. x = inputs['image']
  515. conv1 = self.conv1(x)
  516. x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
  517. outs = []
  518. for idx, stage in enumerate(self.res_layers):
  519. x = stage(x)
  520. if idx in self.return_idx:
  521. outs.append(x)
  522. return outs
  523. @register
  524. class Res5Head(nn.Layer):
  525. def __init__(self, depth=50):
  526. super(Res5Head, self).__init__()
  527. feat_in, feat_out = [1024, 512]
  528. if depth < 50:
  529. feat_in = 256
  530. na = NameAdapter(self)
  531. block = BottleNeck if depth >= 50 else BasicBlock
  532. self.res5 = Blocks(
  533. block, feat_in, feat_out, count=3, name_adapter=na, stage_num=5)
  534. self.feat_out = feat_out if depth < 50 else feat_out * 4
  535. @property
  536. def out_shape(self):
  537. return [ShapeSpec(
  538. channels=self.feat_out,
  539. stride=16, )]
  540. def forward(self, roi_feat, stage=0):
  541. y = self.res5(roi_feat)
  542. return y