task_aligned_assigner.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from ..bbox_utils import iou_similarity
  22. from .utils import (gather_topk_anchors, check_points_inside_bboxes,
  23. compute_max_iou_anchor)
# Public API of this module: only the assigner class is exported.
__all__ = ['TaskAlignedAssigner']
  25. @register
  26. class TaskAlignedAssigner(nn.Layer):
  27. """TOOD: Task-aligned One-stage Object Detection
  28. """
  29. def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
  30. super(TaskAlignedAssigner, self).__init__()
  31. self.topk = topk
  32. self.alpha = alpha
  33. self.beta = beta
  34. self.eps = eps
  35. @paddle.no_grad()
  36. def forward(self,
  37. pred_scores,
  38. pred_bboxes,
  39. anchor_points,
  40. num_anchors_list,
  41. gt_labels,
  42. gt_bboxes,
  43. pad_gt_mask,
  44. bg_index,
  45. gt_scores=None):
  46. r"""This code is based on
  47. https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/task_aligned_assigner.py
  48. The assignment is done in following steps
  49. 1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt
  50. 2. select top-k bbox as candidates for each gt
  51. 3. limit the positive sample's center in gt (because the anchor-free detector
  52. only can predict positive distance)
  53. 4. if an anchor box is assigned to multiple gts, the one with the
  54. highest iou will be selected.
  55. Args:
  56. pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
  57. pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 4)
  58. anchor_points (Tensor, float32): pre-defined anchors, shape(L, 2), "cxcy" format
  59. num_anchors_list (List): num of anchors in each level, shape(L)
  60. gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  61. gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
  62. pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  63. bg_index (int): background index
  64. gt_scores (Tensor|None, float32) Score of gt_bboxes, shape(B, n, 1)
  65. Returns:
  66. assigned_labels (Tensor): (B, L)
  67. assigned_bboxes (Tensor): (B, L, 4)
  68. assigned_scores (Tensor): (B, L, C)
  69. """
  70. assert pred_scores.ndim == pred_bboxes.ndim
  71. assert gt_labels.ndim == gt_bboxes.ndim and \
  72. gt_bboxes.ndim == 3
  73. batch_size, num_anchors, num_classes = pred_scores.shape
  74. _, num_max_boxes, _ = gt_bboxes.shape
  75. # negative batch
  76. if num_max_boxes == 0:
  77. assigned_labels = paddle.full(
  78. [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype)
  79. assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
  80. assigned_scores = paddle.zeros(
  81. [batch_size, num_anchors, num_classes])
  82. return assigned_labels, assigned_bboxes, assigned_scores
  83. # compute iou between gt and pred bbox, [B, n, L]
  84. ious = iou_similarity(gt_bboxes, pred_bboxes)
  85. # gather pred bboxes class score
  86. pred_scores = pred_scores.transpose([0, 2, 1])
  87. batch_ind = paddle.arange(
  88. end=batch_size, dtype=gt_labels.dtype).unsqueeze(-1)
  89. gt_labels_ind = paddle.stack(
  90. [batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)],
  91. axis=-1)
  92. bbox_cls_scores = paddle.gather_nd(pred_scores, gt_labels_ind)
  93. # compute alignment metrics, [B, n, L]
  94. alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(
  95. self.beta)
  96. # check the positive sample's center in gt, [B, n, L]
  97. is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
  98. # select topk largest alignment metrics pred bbox as candidates
  99. # for each gt, [B, n, L]
  100. is_in_topk = gather_topk_anchors(
  101. alignment_metrics * is_in_gts,
  102. self.topk,
  103. topk_mask=pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool))
  104. # select positive sample, [B, n, L]
  105. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  106. # if an anchor box is assigned to multiple gts,
  107. # the one with the highest iou will be selected, [B, n, L]
  108. mask_positive_sum = mask_positive.sum(axis=-2)
  109. if mask_positive_sum.max() > 1:
  110. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile(
  111. [1, num_max_boxes, 1])
  112. is_max_iou = compute_max_iou_anchor(ious)
  113. mask_positive = paddle.where(mask_multiple_gts, is_max_iou,
  114. mask_positive)
  115. mask_positive_sum = mask_positive.sum(axis=-2)
  116. assigned_gt_index = mask_positive.argmax(axis=-2)
  117. # assigned target
  118. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  119. assigned_labels = paddle.gather(
  120. gt_labels.flatten(), assigned_gt_index.flatten(), axis=0)
  121. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  122. assigned_labels = paddle.where(
  123. mask_positive_sum > 0, assigned_labels,
  124. paddle.full_like(assigned_labels, bg_index))
  125. assigned_bboxes = paddle.gather(
  126. gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
  127. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  128. assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
  129. ind = list(range(num_classes + 1))
  130. ind.remove(bg_index)
  131. assigned_scores = paddle.index_select(
  132. assigned_scores, paddle.to_tensor(ind), axis=-1)
  133. # rescale alignment metrics
  134. alignment_metrics *= mask_positive
  135. max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True)
  136. max_ious_per_instance = (ious * mask_positive).max(axis=-1,
  137. keepdim=True)
  138. alignment_metrics = alignment_metrics / (
  139. max_metrics_per_instance + self.eps) * max_ious_per_instance
  140. alignment_metrics = alignment_metrics.max(-2).unsqueeze(-1)
  141. assigned_scores = assigned_scores * alignment_metrics
  142. return assigned_labels, assigned_bboxes, assigned_scores