# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn

from ppdet.core.workspace import register, serializable
from .resnet import ResNet, Blocks, BasicBlock, BottleNeck
from .name_adapter import NameAdapter
from ..shape_spec import ShapeSpec
  17. __all__ = ['SENet', 'SERes5Head']
  18. @register
  19. @serializable
  20. class SENet(ResNet):
  21. __shared__ = ['norm_type']
  22. def __init__(self,
  23. depth=50,
  24. variant='b',
  25. lr_mult_list=[1.0, 1.0, 1.0, 1.0],
  26. groups=1,
  27. base_width=64,
  28. norm_type='bn',
  29. norm_decay=0,
  30. freeze_norm=True,
  31. freeze_at=0,
  32. return_idx=[0, 1, 2, 3],
  33. dcn_v2_stages=[-1],
  34. std_senet=True,
  35. num_stages=4):
  36. """
  37. Squeeze-and-Excitation Networks, see https://arxiv.org/abs/1709.01507
  38. Args:
  39. depth (int): SENet depth, should be 50, 101, 152
  40. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  41. lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
  42. lower learning rate ratio is need for pretrained model
  43. got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
  44. groups (int): group convolution cardinality
  45. base_width (int): base width of each group convolution
  46. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  47. norm_decay (float): weight decay for normalization layer weights
  48. freeze_norm (bool): freeze normalization layers
  49. freeze_at (int): freeze the backbone at which stage
  50. return_idx (list): index of the stages whose feature maps are returned
  51. dcn_v2_stages (list): index of stages who select deformable conv v2
  52. std_senet (bool): whether use senet, default True
  53. num_stages (int): total num of stages
  54. """
  55. super(SENet, self).__init__(
  56. depth=depth,
  57. variant=variant,
  58. lr_mult_list=lr_mult_list,
  59. ch_in=128,
  60. groups=groups,
  61. base_width=base_width,
  62. norm_type=norm_type,
  63. norm_decay=norm_decay,
  64. freeze_norm=freeze_norm,
  65. freeze_at=freeze_at,
  66. return_idx=return_idx,
  67. dcn_v2_stages=dcn_v2_stages,
  68. std_senet=std_senet,
  69. num_stages=num_stages)
  70. @register
  71. class SERes5Head(nn.Layer):
  72. def __init__(self,
  73. depth=50,
  74. variant='b',
  75. lr_mult=1.0,
  76. groups=1,
  77. base_width=64,
  78. norm_type='bn',
  79. norm_decay=0,
  80. dcn_v2=False,
  81. freeze_norm=False,
  82. std_senet=True):
  83. """
  84. SERes5Head layer
  85. Args:
  86. depth (int): SENet depth, should be 50, 101, 152
  87. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  88. lr_mult (list): learning rate ratio of SERes5Head, default as 1.0.
  89. groups (int): group convolution cardinality
  90. base_width (int): base width of each group convolution
  91. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  92. norm_decay (float): weight decay for normalization layer weights
  93. dcn_v2_stages (list): index of stages who select deformable conv v2
  94. std_senet (bool): whether use senet, default True
  95. """
  96. super(SERes5Head, self).__init__()
  97. ch_out = 512
  98. ch_in = 256 if depth < 50 else 1024
  99. na = NameAdapter(self)
  100. block = BottleNeck if depth >= 50 else BasicBlock
  101. self.res5 = Blocks(
  102. block,
  103. ch_in,
  104. ch_out,
  105. count=3,
  106. name_adapter=na,
  107. stage_num=5,
  108. variant=variant,
  109. groups=groups,
  110. base_width=base_width,
  111. lr=lr_mult,
  112. norm_type=norm_type,
  113. norm_decay=norm_decay,
  114. freeze_norm=freeze_norm,
  115. dcn_v2=dcn_v2,
  116. std_senet=std_senet)
  117. self.ch_out = ch_out * block.expansion
  118. @property
  119. def out_shape(self):
  120. return [ShapeSpec(
  121. channels=self.ch_out,
  122. stride=16, )]
  123. def forward(self, roi_feat):
  124. y = self.res5(roi_feat)
  125. return y