resnext.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from ppdet.core.workspace import register, serializable
  18. from .resnet import ResNet
  19. __all__ = ['ResNeXt']
  20. @register
  21. @serializable
  22. class ResNeXt(ResNet):
  23. """
  24. ResNeXt, see https://arxiv.org/abs/1611.05431
  25. Args:
  26. depth (int): network depth, should be 50, 101, 152.
  27. groups (int): group convolution cardinality
  28. group_width (int): width of each group convolution
  29. freeze_at (int): freeze the backbone at which stage
  30. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  31. freeze_norm (bool): freeze normalization layers
  32. norm_decay (float): weight decay for normalization layer weights
  33. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  34. feature_maps (list): index of the stages whose feature maps are returned
  35. dcn_v2_stages (list): index of stages who select deformable conv v2
  36. """
  37. def __init__(self,
  38. depth=50,
  39. groups=64,
  40. group_width=4,
  41. freeze_at=2,
  42. norm_type='affine_channel',
  43. freeze_norm=True,
  44. norm_decay=True,
  45. variant='a',
  46. feature_maps=[2, 3, 4, 5],
  47. dcn_v2_stages=[],
  48. weight_prefix_name=''):
  49. assert depth in [50, 101, 152], "depth {} should be 50, 101 or 152"
  50. super(ResNeXt, self).__init__(depth, freeze_at, norm_type, freeze_norm,
  51. norm_decay, variant, feature_maps)
  52. self.depth_cfg = {
  53. 50: ([3, 4, 6, 3], self.bottleneck),
  54. 101: ([3, 4, 23, 3], self.bottleneck),
  55. 152: ([3, 8, 36, 3], self.bottleneck)
  56. }
  57. self.stage_filters = [256, 512, 1024, 2048]
  58. self.groups = groups
  59. self.group_width = group_width
  60. self._model_type = 'ResNeXt'
  61. self.dcn_v2_stages = dcn_v2_stages
  62. @register
  63. @serializable
  64. class ResNeXtC5(ResNeXt):
  65. __doc__ = ResNeXt.__doc__
  66. def __init__(self,
  67. depth=50,
  68. groups=64,
  69. group_width=4,
  70. freeze_at=2,
  71. norm_type='affine_channel',
  72. freeze_norm=True,
  73. norm_decay=True,
  74. variant='a',
  75. feature_maps=[5],
  76. weight_prefix_name=''):
  77. super(ResNeXtC5, self).__init__(depth, groups, group_width, freeze_at,
  78. norm_type, freeze_norm, norm_decay,
  79. variant, feature_maps)
  80. self.severed_head = True