atss_assigner.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # The code is based on:
  15. # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. from ppdet.utils.logger import setup_logger
  21. logger = setup_logger(__name__)
  22. def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
  23. """Calculate overlap between two set of bboxes.
  24. If ``is_aligned `` is ``False``, then calculate the overlaps between each
  25. bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
  26. pair of bboxes1 and bboxes2.
  27. Args:
  28. bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
  29. bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
  30. B indicates the batch dim, in shape (B1, B2, ..., Bn).
  31. If ``is_aligned `` is ``True``, then m and n must be equal.
  32. mode (str): "iou" (intersection over union) or "iof" (intersection over
  33. foreground).
  34. is_aligned (bool, optional): If True, then m and n must be equal.
  35. Default False.
  36. eps (float, optional): A value added to the denominator for numerical
  37. stability. Default 1e-6.
  38. Returns:
  39. Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
  40. """
  41. assert mode in ['iou', 'iof', 'giou'], 'Unsupported mode {}'.format(mode)
  42. # Either the boxes are empty or the length of boxes's last dimenstion is 4
  43. assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0)
  44. assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0)
  45. # Batch dim must be the same
  46. # Batch dim: (B1, B2, ... Bn)
  47. assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
  48. batch_shape = bboxes1.shape[:-2]
  49. rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0
  50. cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0
  51. if is_aligned:
  52. assert rows == cols
  53. if rows * cols == 0:
  54. if is_aligned:
  55. return np.random.random(batch_shape + (rows, ))
  56. else:
  57. return np.random.random(batch_shape + (rows, cols))
  58. area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
  59. bboxes1[..., 3] - bboxes1[..., 1])
  60. area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
  61. bboxes2[..., 3] - bboxes2[..., 1])
  62. if is_aligned:
  63. lt = np.maximum(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2]
  64. rb = np.minimum(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2]
  65. wh = (rb - lt).clip(min=0) # [B, rows, 2]
  66. overlap = wh[..., 0] * wh[..., 1]
  67. if mode in ['iou', 'giou']:
  68. union = area1 + area2 - overlap
  69. else:
  70. union = area1
  71. if mode == 'giou':
  72. enclosed_lt = np.minimum(bboxes1[..., :2], bboxes2[..., :2])
  73. enclosed_rb = np.maximum(bboxes1[..., 2:], bboxes2[..., 2:])
  74. else:
  75. lt = np.maximum(bboxes1[..., :, None, :2],
  76. bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
  77. rb = np.minimum(bboxes1[..., :, None, 2:],
  78. bboxes2[..., None, :, 2:]) # [B, rows, cols, 2]
  79. wh = (rb - lt).clip(min=0) # [B, rows, cols, 2]
  80. overlap = wh[..., 0] * wh[..., 1]
  81. if mode in ['iou', 'giou']:
  82. union = area1[..., None] + area2[..., None, :] - overlap
  83. else:
  84. union = area1[..., None]
  85. if mode == 'giou':
  86. enclosed_lt = np.minimum(bboxes1[..., :, None, :2],
  87. bboxes2[..., None, :, :2])
  88. enclosed_rb = np.maximum(bboxes1[..., :, None, 2:],
  89. bboxes2[..., None, :, 2:])
  90. eps = np.array([eps])
  91. union = np.maximum(union, eps)
  92. ious = overlap / union
  93. if mode in ['iou', 'iof']:
  94. return ious
  95. # calculate gious
  96. enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
  97. enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
  98. enclose_area = np.maximum(enclose_area, eps)
  99. gious = ious - (enclose_area - union) / enclose_area
  100. return gious
  101. def topk_(input, k, axis=1, largest=True):
  102. x = -input if largest else input
  103. if axis == 0:
  104. row_index = np.arange(input.shape[1 - axis])
  105. topk_index = np.argpartition(x, k, axis=axis)[0:k, :]
  106. topk_data = x[topk_index, row_index]
  107. topk_index_sort = np.argsort(topk_data, axis=axis)
  108. topk_data_sort = topk_data[topk_index_sort, row_index]
  109. topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
  110. else:
  111. column_index = np.arange(x.shape[1 - axis])[:, None]
  112. topk_index = np.argpartition(x, k, axis=axis)[:, 0:k]
  113. topk_data = x[column_index, topk_index]
  114. topk_data = -topk_data if largest else topk_data
  115. topk_index_sort = np.argsort(topk_data, axis=axis)
  116. topk_data_sort = topk_data[column_index, topk_index_sort]
  117. topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
  118. return topk_data_sort, topk_index_sort
  119. class ATSSAssigner(object):
  120. """Assign a corresponding gt bbox or background to each bbox.
  121. Each proposals will be assigned with `0` or a positive integer
  122. indicating the ground truth index.
  123. - 0: negative sample, no assigned gt
  124. - positive integer: positive sample, index (1-based) of assigned gt
  125. Args:
  126. topk (float): number of bbox selected in each level
  127. """
  128. def __init__(self, topk=9):
  129. self.topk = topk
  130. def __call__(self,
  131. bboxes,
  132. num_level_bboxes,
  133. gt_bboxes,
  134. gt_bboxes_ignore=None,
  135. gt_labels=None):
  136. """Assign gt to bboxes.
  137. The assignment is done in following steps
  138. 1. compute iou between all bbox (bbox of all pyramid levels) and gt
  139. 2. compute center distance between all bbox and gt
  140. 3. on each pyramid level, for each gt, select k bbox whose center
  141. are closest to the gt center, so we total select k*l bbox as
  142. candidates for each gt
  143. 4. get corresponding iou for the these candidates, and compute the
  144. mean and std, set mean + std as the iou threshold
  145. 5. select these candidates whose iou are greater than or equal to
  146. the threshold as postive
  147. 6. limit the positive sample's center in gt
  148. Args:
  149. bboxes (np.array): Bounding boxes to be assigned, shape(n, 4).
  150. num_level_bboxes (List): num of bboxes in each level
  151. gt_bboxes (np.array): Groundtruth boxes, shape (k, 4).
  152. gt_bboxes_ignore (np.array, optional): Ground truth bboxes that are
  153. labelled as `ignored`, e.g., crowd boxes in COCO.
  154. gt_labels (np.array, optional): Label of gt_bboxes, shape (k, ).
  155. """
  156. bboxes = bboxes[:, :4]
  157. num_gt, num_bboxes = gt_bboxes.shape[0], bboxes.shape[0]
  158. # assign 0 by default
  159. assigned_gt_inds = np.zeros((num_bboxes, ), dtype=np.int64)
  160. if num_gt == 0 or num_bboxes == 0:
  161. # No ground truth or boxes, return empty assignment
  162. max_overlaps = np.zeros((num_bboxes, ))
  163. if num_gt == 0:
  164. # No truth, assign everything to background
  165. assigned_gt_inds[:] = 0
  166. if not np.any(gt_labels):
  167. assigned_labels = None
  168. else:
  169. assigned_labels = -np.ones((num_bboxes, ), dtype=np.int64)
  170. return assigned_gt_inds, max_overlaps
  171. # compute iou between all bbox and gt
  172. overlaps = bbox_overlaps(bboxes, gt_bboxes)
  173. # compute center distance between all bbox and gt
  174. gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
  175. gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
  176. gt_points = np.stack((gt_cx, gt_cy), axis=1)
  177. bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
  178. bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
  179. bboxes_points = np.stack((bboxes_cx, bboxes_cy), axis=1)
  180. distances = np.sqrt(
  181. np.power((bboxes_points[:, None, :] - gt_points[None, :, :]), 2)
  182. .sum(-1))
  183. # Selecting candidates based on the center distance
  184. candidate_idxs = []
  185. start_idx = 0
  186. for bboxes_per_level in num_level_bboxes:
  187. # on each pyramid level, for each gt,
  188. # select k bbox whose center are closest to the gt center
  189. end_idx = start_idx + bboxes_per_level
  190. distances_per_level = distances[start_idx:end_idx, :]
  191. selectable_k = min(self.topk, bboxes_per_level)
  192. _, topk_idxs_per_level = topk_(
  193. distances_per_level, selectable_k, axis=0, largest=False)
  194. candidate_idxs.append(topk_idxs_per_level + start_idx)
  195. start_idx = end_idx
  196. candidate_idxs = np.concatenate(candidate_idxs, axis=0)
  197. # get corresponding iou for the these candidates, and compute the
  198. # mean and std, set mean + std as the iou threshold
  199. candidate_overlaps = overlaps[candidate_idxs, np.arange(num_gt)]
  200. overlaps_mean_per_gt = candidate_overlaps.mean(0)
  201. overlaps_std_per_gt = candidate_overlaps.std(0)
  202. overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
  203. is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
  204. # limit the positive sample's center in gt
  205. for gt_idx in range(num_gt):
  206. candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
  207. ep_bboxes_cx = np.broadcast_to(
  208. bboxes_cx.reshape(1, -1), [num_gt, num_bboxes]).reshape(-1)
  209. ep_bboxes_cy = np.broadcast_to(
  210. bboxes_cy.reshape(1, -1), [num_gt, num_bboxes]).reshape(-1)
  211. candidate_idxs = candidate_idxs.reshape(-1)
  212. # calculate the left, top, right, bottom distance between positive
  213. # bbox center and gt side
  214. l_ = ep_bboxes_cx[candidate_idxs].reshape(-1, num_gt) - gt_bboxes[:, 0]
  215. t_ = ep_bboxes_cy[candidate_idxs].reshape(-1, num_gt) - gt_bboxes[:, 1]
  216. r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].reshape(-1, num_gt)
  217. b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].reshape(-1, num_gt)
  218. is_in_gts = np.stack([l_, t_, r_, b_], axis=1).min(axis=1) > 0.01
  219. is_pos = is_pos & is_in_gts
  220. # if an anchor box is assigned to multiple gts,
  221. # the one with the highest IoU will be selected.
  222. overlaps_inf = -np.inf * np.ones_like(overlaps).T.reshape(-1)
  223. index = candidate_idxs.reshape(-1)[is_pos.reshape(-1)]
  224. overlaps_inf[index] = overlaps.T.reshape(-1)[index]
  225. overlaps_inf = overlaps_inf.reshape(num_gt, -1).T
  226. max_overlaps = overlaps_inf.max(axis=1)
  227. argmax_overlaps = overlaps_inf.argmax(axis=1)
  228. assigned_gt_inds[max_overlaps !=
  229. -np.inf] = argmax_overlaps[max_overlaps != -np.inf] + 1
  230. return assigned_gt_inds, max_overlaps