roi_extractor.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from ppdet.core.workspace import register
  16. from ppdet.modeling import ops
  17. def _to_list(v):
  18. if not isinstance(v, (list, tuple)):
  19. return [v]
  20. return v
  21. @register
  22. class RoIAlign(object):
  23. """
  24. RoI Align module
  25. For more details, please refer to the document of roi_align in
  26. in https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/vision/ops.py
  27. Args:
  28. resolution (int): The output size, default 14
  29. spatial_scale (float): Multiplicative spatial scale factor to translate
  30. ROI coords from their input scale to the scale used when pooling.
  31. default 0.0625
  32. sampling_ratio (int): The number of sampling points in the interpolation
  33. grid, default 0
  34. canconical_level (int): The referring level of FPN layer with
  35. specified level. default 4
  36. canonical_size (int): The referring scale of FPN layer with
  37. specified scale. default 224
  38. start_level (int): The start level of FPN layer to extract RoI feature,
  39. default 0
  40. end_level (int): The end level of FPN layer to extract RoI feature,
  41. default 3
  42. aligned (bool): Whether to add offset to rois' coord in roi_align.
  43. default false
  44. """
  45. def __init__(self,
  46. resolution=14,
  47. spatial_scale=0.0625,
  48. sampling_ratio=0,
  49. canconical_level=4,
  50. canonical_size=224,
  51. start_level=0,
  52. end_level=3,
  53. aligned=False):
  54. super(RoIAlign, self).__init__()
  55. self.resolution = resolution
  56. self.spatial_scale = _to_list(spatial_scale)
  57. self.sampling_ratio = sampling_ratio
  58. self.canconical_level = canconical_level
  59. self.canonical_size = canonical_size
  60. self.start_level = start_level
  61. self.end_level = end_level
  62. self.aligned = aligned
  63. @classmethod
  64. def from_config(cls, cfg, input_shape):
  65. return {'spatial_scale': [1. / i.stride for i in input_shape]}
  66. def __call__(self, feats, roi, rois_num):
  67. roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
  68. if len(feats) == 1:
  69. rois_feat = paddle.vision.ops.roi_align(
  70. x=feats[self.start_level],
  71. boxes=roi,
  72. boxes_num=rois_num,
  73. output_size=self.resolution,
  74. spatial_scale=self.spatial_scale[0],
  75. aligned=self.aligned)
  76. else:
  77. offset = 2
  78. k_min = self.start_level + offset
  79. k_max = self.end_level + offset
  80. rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals(
  81. roi,
  82. k_min,
  83. k_max,
  84. self.canconical_level,
  85. self.canonical_size,
  86. rois_num=rois_num)
  87. rois_feat_list = []
  88. for lvl in range(self.start_level, self.end_level + 1):
  89. roi_feat = paddle.vision.ops.roi_align(
  90. x=feats[lvl],
  91. boxes=rois_dist[lvl],
  92. boxes_num=rois_num_dist[lvl],
  93. output_size=self.resolution,
  94. spatial_scale=self.spatial_scale[lvl],
  95. sampling_ratio=self.sampling_ratio,
  96. aligned=self.aligned)
  97. rois_feat_list.append(roi_feat)
  98. rois_feat_shuffle = paddle.concat(rois_feat_list)
  99. rois_feat = paddle.gather(rois_feat_shuffle, restore_index)
  100. return rois_feat