matchers.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Modified from DETR (https://github.com/facebookresearch/detr)
  16. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import paddle
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from scipy.optimize import linear_sum_assignment
  24. from ppdet.core.workspace import register, serializable
  25. from ..losses.iou_loss import GIoULoss
  26. from .utils import bbox_cxcywh_to_xyxy
  27. __all__ = ['HungarianMatcher']
  28. @register
  29. @serializable
  30. class HungarianMatcher(nn.Layer):
  31. __shared__ = ['use_focal_loss']
  32. def __init__(self,
  33. matcher_coeff={'class': 1,
  34. 'bbox': 5,
  35. 'giou': 2},
  36. use_focal_loss=False,
  37. alpha=0.25,
  38. gamma=2.0):
  39. r"""
  40. Args:
  41. matcher_coeff (dict): The coefficient of hungarian matcher cost.
  42. """
  43. super(HungarianMatcher, self).__init__()
  44. self.matcher_coeff = matcher_coeff
  45. self.use_focal_loss = use_focal_loss
  46. self.alpha = alpha
  47. self.gamma = gamma
  48. self.giou_loss = GIoULoss()
  49. def forward(self, boxes, logits, gt_bbox, gt_class):
  50. r"""
  51. Args:
  52. boxes (Tensor): [b, query, 4]
  53. logits (Tensor): [b, query, num_classes]
  54. gt_bbox (List(Tensor)): list[[n, 4]]
  55. gt_class (List(Tensor)): list[[n, 1]]
  56. Returns:
  57. A list of size batch_size, containing tuples of (index_i, index_j) where:
  58. - index_i is the indices of the selected predictions (in order)
  59. - index_j is the indices of the corresponding selected targets (in order)
  60. For each batch element, it holds:
  61. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  62. """
  63. bs, num_queries = boxes.shape[:2]
  64. num_gts = sum(len(a) for a in gt_class)
  65. if num_gts == 0:
  66. return [(paddle.to_tensor(
  67. [], dtype=paddle.int64), paddle.to_tensor(
  68. [], dtype=paddle.int64)) for _ in range(bs)]
  69. # We flatten to compute the cost matrices in a batch
  70. # [batch_size * num_queries, num_classes]
  71. out_prob = F.sigmoid(logits.flatten(
  72. 0, 1)) if self.use_focal_loss else F.softmax(logits.flatten(0, 1))
  73. # [batch_size * num_queries, 4]
  74. out_bbox = boxes.flatten(0, 1)
  75. # Also concat the target labels and boxes
  76. tgt_ids = paddle.concat(gt_class).flatten()
  77. tgt_bbox = paddle.concat(gt_bbox)
  78. # Compute the classification cost
  79. if self.use_focal_loss:
  80. neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(
  81. 1 - out_prob + 1e-8).log())
  82. pos_cost_class = self.alpha * (
  83. (1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log())
  84. cost_class = paddle.gather(
  85. pos_cost_class, tgt_ids, axis=1) - paddle.gather(
  86. neg_cost_class, tgt_ids, axis=1)
  87. else:
  88. cost_class = -paddle.gather(out_prob, tgt_ids, axis=1)
  89. # Compute the L1 cost between boxes
  90. cost_bbox = (
  91. out_bbox.unsqueeze(1) - tgt_bbox.unsqueeze(0)).abs().sum(-1)
  92. # Compute the giou cost betwen boxes
  93. cost_giou = self.giou_loss(
  94. bbox_cxcywh_to_xyxy(out_bbox.unsqueeze(1)),
  95. bbox_cxcywh_to_xyxy(tgt_bbox.unsqueeze(0))).squeeze(-1)
  96. # Final cost matrix
  97. C = self.matcher_coeff['class'] * cost_class + self.matcher_coeff['bbox'] * cost_bbox + \
  98. self.matcher_coeff['giou'] * cost_giou
  99. C = C.reshape([bs, num_queries, -1])
  100. C = [a.squeeze(0) for a in C.chunk(bs)]
  101. sizes = [a.shape[0] for a in gt_bbox]
  102. indices = [
  103. linear_sum_assignment(c.split(sizes, -1)[i].numpy())
  104. for i, c in enumerate(C)
  105. ]
  106. return [(paddle.to_tensor(
  107. i, dtype=paddle.int64), paddle.to_tensor(
  108. j, dtype=paddle.int64)) for i, j in indices]