atss_assigner.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from ppdet.core.workspace import register
  22. from ..ops import iou_similarity
  23. from ..bbox_utils import iou_similarity as batch_iou_similarity
  24. from ..bbox_utils import bbox_center
  25. from .utils import (check_points_inside_bboxes, compute_max_iou_anchor,
  26. compute_max_iou_gt)
# Names exported when this module is star-imported.
__all__ = ['ATSSAssigner']
  28. @register
  29. class ATSSAssigner(nn.Layer):
  30. """Bridging the Gap Between Anchor-based and Anchor-free Detection
  31. via Adaptive Training Sample Selection
  32. """
  33. __shared__ = ['num_classes']
  34. def __init__(self,
  35. topk=9,
  36. num_classes=80,
  37. force_gt_matching=False,
  38. eps=1e-9):
  39. super(ATSSAssigner, self).__init__()
  40. self.topk = topk
  41. self.num_classes = num_classes
  42. self.force_gt_matching = force_gt_matching
  43. self.eps = eps
  44. def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
  45. pad_gt_mask):
  46. pad_gt_mask = pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool)
  47. gt2anchor_distances_list = paddle.split(
  48. gt2anchor_distances, num_anchors_list, axis=-1)
  49. num_anchors_index = np.cumsum(num_anchors_list).tolist()
  50. num_anchors_index = [0, ] + num_anchors_index[:-1]
  51. is_in_topk_list = []
  52. topk_idxs_list = []
  53. for distances, anchors_index in zip(gt2anchor_distances_list,
  54. num_anchors_index):
  55. num_anchors = distances.shape[-1]
  56. topk_metrics, topk_idxs = paddle.topk(
  57. distances, self.topk, axis=-1, largest=False)
  58. topk_idxs_list.append(topk_idxs + anchors_index)
  59. topk_idxs = paddle.where(pad_gt_mask, topk_idxs,
  60. paddle.zeros_like(topk_idxs))
  61. is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
  62. is_in_topk = paddle.where(is_in_topk > 1,
  63. paddle.zeros_like(is_in_topk), is_in_topk)
  64. is_in_topk_list.append(is_in_topk.astype(gt2anchor_distances.dtype))
  65. is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
  66. topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
  67. return is_in_topk_list, topk_idxs_list
  68. @paddle.no_grad()
  69. def forward(self,
  70. anchor_bboxes,
  71. num_anchors_list,
  72. gt_labels,
  73. gt_bboxes,
  74. pad_gt_mask,
  75. bg_index,
  76. gt_scores=None,
  77. pred_bboxes=None):
  78. r"""This code is based on
  79. https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
  80. The assignment is done in following steps
  81. 1. compute iou between all bbox (bbox of all pyramid levels) and gt
  82. 2. compute center distance between all bbox and gt
  83. 3. on each pyramid level, for each gt, select k bbox whose center
  84. are closest to the gt center, so we total select k*l bbox as
  85. candidates for each gt
  86. 4. get corresponding iou for the these candidates, and compute the
  87. mean and std, set mean + std as the iou threshold
  88. 5. select these candidates whose iou are greater than or equal to
  89. the threshold as positive
  90. 6. limit the positive sample's center in gt
  91. 7. if an anchor box is assigned to multiple gts, the one with the
  92. highest iou will be selected.
  93. Args:
  94. anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
  95. "xmin, xmax, ymin, ymax" format
  96. num_anchors_list (List): num of anchors in each level
  97. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  98. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
  99. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  100. bg_index (int): background index
  101. gt_scores (Tensor|None, float32) Score of gt_bboxes,
  102. shape(B, n, 1), if None, then it will initialize with one_hot label
  103. pred_bboxes (Tensor, float32, optional): predicted bounding boxes, shape(B, L, 4)
  104. Returns:
  105. assigned_labels (Tensor): (B, L)
  106. assigned_bboxes (Tensor): (B, L, 4)
  107. assigned_scores (Tensor): (B, L, C), if pred_bboxes is not None, then output ious
  108. """
  109. assert gt_labels.ndim == gt_bboxes.ndim and \
  110. gt_bboxes.ndim == 3
  111. num_anchors, _ = anchor_bboxes.shape
  112. batch_size, num_max_boxes, _ = gt_bboxes.shape
  113. # negative batch
  114. if num_max_boxes == 0:
  115. assigned_labels = paddle.full(
  116. [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype)
  117. assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
  118. assigned_scores = paddle.zeros(
  119. [batch_size, num_anchors, self.num_classes])
  120. return assigned_labels, assigned_bboxes, assigned_scores
  121. # 1. compute iou between gt and anchor bbox, [B, n, L]
  122. ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
  123. ious = ious.reshape([batch_size, -1, num_anchors])
  124. # 2. compute center distance between all anchors and gt, [B, n, L]
  125. gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
  126. anchor_centers = bbox_center(anchor_bboxes)
  127. gt2anchor_distances = (gt_centers - anchor_centers.unsqueeze(0)) \
  128. .norm(2, axis=-1).reshape([batch_size, -1, num_anchors])
  129. # 3. on each pyramid level, selecting topk closest candidates
  130. # based on the center distance, [B, n, L]
  131. is_in_topk, topk_idxs = self._gather_topk_pyramid(
  132. gt2anchor_distances, num_anchors_list, pad_gt_mask)
  133. # 4. get corresponding iou for the these candidates, and compute the
  134. # mean and std, 5. set mean + std as the iou threshold
  135. iou_candidates = ious * is_in_topk
  136. iou_threshold = paddle.index_sample(
  137. iou_candidates.flatten(stop_axis=-2),
  138. topk_idxs.flatten(stop_axis=-2))
  139. iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
  140. iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
  141. iou_threshold.std(axis=-1, keepdim=True)
  142. is_in_topk = paddle.where(
  143. iou_candidates > iou_threshold.tile([1, 1, num_anchors]),
  144. is_in_topk, paddle.zeros_like(is_in_topk))
  145. # 6. check the positive sample's center in gt, [B, n, L]
  146. is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
  147. # select positive sample, [B, n, L]
  148. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  149. # 7. if an anchor box is assigned to multiple gts,
  150. # the one with the highest iou will be selected.
  151. mask_positive_sum = mask_positive.sum(axis=-2)
  152. if mask_positive_sum.max() > 1:
  153. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
  154. [1, num_max_boxes, 1])
  155. is_max_iou = compute_max_iou_anchor(ious)
  156. mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
  157. mask_positive)
  158. mask_positive_sum = mask_positive.sum(axis=-2)
  159. # 8. make sure every gt_bbox matches the anchor
  160. if self.force_gt_matching:
  161. is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
  162. mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile(
  163. [1, num_max_boxes, 1])
  164. mask_positive = paddle.where(mask_max_iou, is_max_iou,
  165. mask_positive)
  166. mask_positive_sum = mask_positive.sum(axis=-2)
  167. assigned_gt_index = mask_positive.argmax(axis=-2)
  168. # assigned target
  169. batch_ind = paddle.arange(
  170. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  171. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  172. assigned_labels = paddle.gather(
  173. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  174. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  175. assigned_labels = paddle.where(
  176. mask_positive_sum > 0, assigned_labels,
  177. paddle.full_like(assigned_labels, bg_index))
  178. assigned_bboxes = paddle.gather(
  179. gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
  180. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  181. assigned_scores = F.one_hot(assigned_labels, self.num_classes + 1)
  182. ind = list(range(self.num_classes + 1))
  183. ind.remove(bg_index)
  184. assigned_scores = paddle.index_select(
  185. assigned_scores, paddle.to_tensor(ind), axis=-1)
  186. if pred_bboxes is not None:
  187. # assigned iou
  188. ious = batch_iou_similarity(gt_bboxes, pred_bboxes) * mask_positive
  189. ious = ious.max(axis=-2).unsqueeze(-1)
  190. assigned_scores *= ious
  191. elif gt_scores is not None:
  192. gather_scores = paddle.gather(
  193. gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
  194. gather_scores = gather_scores.reshape([batch_size, num_anchors])
  195. gather_scores = paddle.where(mask_positive_sum > 0, gather_scores,
  196. paddle.zeros_like(gather_scores))
  197. assigned_scores *= gather_scores.unsqueeze(-1)
  198. return assigned_labels, assigned_bboxes, assigned_scores