ssd_loss.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from ..ops import iou_similarity
  22. from ..bbox_utils import bbox2delta
  23. __all__ = ['SSDLoss']
  @register
  class SSDLoss(nn.Layer):
      """
      SSDLoss

      Matches prior boxes to ground-truth boxes by IoU, then combines a
      smooth-L1 localization loss over the matched (positive) priors with a
      cross-entropy confidence loss over the positives plus mined hard
      negatives. The summed loss is divided by the number of positive priors
      (clipped to at least 1).

      Args:
          overlap_threshold (float32, optional): IoU threshold for negative bboxes
              and positive bboxes, 0.5 by default.
          neg_pos_ratio (float): The ratio of negative samples / positive samples.
          loc_loss_weight (float): The weight of loc_loss.
          conf_loss_weight (float): The weight of conf_loss.
          prior_box_var (list): Variances corresponding to prior box coord, [0.1,
              0.1, 0.2, 0.2] by default.
      """

      def __init__(self,
                   overlap_threshold=0.5,
                   neg_pos_ratio=3.0,
                   loc_loss_weight=1.0,
                   conf_loss_weight=1.0,
                   prior_box_var=[0.1, 0.1, 0.2, 0.2]):
          super(SSDLoss, self).__init__()
          self.overlap_threshold = overlap_threshold
          self.neg_pos_ratio = neg_pos_ratio
          self.loc_loss_weight = loc_loss_weight
          self.conf_loss_weight = conf_loss_weight
          # Stored as the reciprocals of the given variances (this is what is
          # handed to bbox2delta below). The default list is only read here,
          # never mutated, so sharing it across instances is harmless.
          self.prior_box_var = [1. / a for a in prior_box_var]

      def _bipartite_match_for_batch(self, gt_bbox, gt_label, prior_boxes,
                                     bg_index):
          """
          Assign a target box and label to every prior box of every sample.

          Args:
              gt_bbox (Tensor): [B, N, 4]
              gt_label (Tensor): [B, N, 1]
              prior_boxes (Tensor): [A, 4]
              bg_index (int): Background class index

          Returns:
              tuple: (targets_bbox, targets_label), reshaped to [B, A, 4] and
              [B, A, 1]. targets_bbox holds the bbox2delta encoding of the
              matched GT box relative to each prior.
          """
          batch_size, num_priors = gt_bbox.shape[0], prior_boxes.shape[0]
          ious = iou_similarity(gt_bbox.reshape((-1, 4)), prior_boxes).reshape(
              (batch_size, -1, num_priors))

          # For each prior box, get the max IoU of all GTs.
          prior_max_iou, prior_argmax_iou = ious.max(axis=1), ious.argmax(axis=1)
          # For each GT, get the max IoU of all prior boxes.
          # (Only the argmax is used below.)
          gt_max_iou, gt_argmax_iou = ious.max(axis=2), ious.argmax(axis=2)

          # Gather target bbox and label according to 'prior_argmax_iou' index.
          # The index is turned into (batch, gt) pairs for gather_nd.
          batch_ind = paddle.arange(end=batch_size, dtype='int64').unsqueeze(-1)
          prior_argmax_iou = paddle.stack(
              [batch_ind.tile([1, num_priors]), prior_argmax_iou], axis=-1)
          targets_bbox = paddle.gather_nd(gt_bbox, prior_argmax_iou)
          targets_label = paddle.gather_nd(gt_label, prior_argmax_iou)

          # Assign negative: priors whose best IoU is below the threshold
          # become background.
          bg_index_tensor = paddle.full([batch_size, num_priors, 1], bg_index,
                                        'int64')
          targets_label = paddle.where(
              prior_max_iou.unsqueeze(-1) < self.overlap_threshold,
              bg_index_tensor, targets_label)

          # Ensure each GT can match the max IoU prior box.
          # Positions are flattened to batch * num_priors + prior so that one
          # scatter call covers the whole batch.
          flat_ind = (batch_ind * num_priors + gt_argmax_iou).flatten()
          targets_bbox = paddle.scatter(
              targets_bbox.reshape([-1, 4]), flat_ind,
              gt_bbox.reshape([-1, 4])).reshape([batch_size, -1, 4])
          targets_label = paddle.scatter(
              targets_label.reshape([-1, 1]), flat_ind,
              gt_label.reshape([-1, 1])).reshape([batch_size, -1, 1])
          # Prior 0 of every sample is forced to background.
          # NOTE(review): presumably this undoes labels scattered onto index 0
          # by padded GT rows (zero IoU everywhere, so argmax is 0) - confirm.
          targets_label[:, :1] = bg_index

          # Encode box
          prior_boxes = prior_boxes.unsqueeze(0).tile([batch_size, 1, 1])
          targets_bbox = bbox2delta(
              prior_boxes.reshape([-1, 4]),
              targets_bbox.reshape([-1, 4]), self.prior_box_var)
          targets_bbox = targets_bbox.reshape([batch_size, -1, 4])

          return targets_bbox, targets_label

      def _mine_hard_example(self,
                             conf_loss,
                             targets_label,
                             bg_index,
                             mine_neg_ratio=0.01):
          """
          Select the priors that contribute to the confidence loss.

          All positives are kept. Per sample, the negatives with the largest
          confidence loss are kept up to a quota of
          ``num_pos * self.neg_pos_ratio`` (capped at the number of priors).
          When that quota is not positive (e.g. the sample has no positives),
          ``num_priors * mine_neg_ratio`` is used instead.

          Args:
              conf_loss (Tensor): Per-prior confidence loss; forward passes it
                  with the trailing axis squeezed (presumably [B, A]).
              targets_label (Tensor): Per-prior target label, same layout as
                  conf_loss.
              bg_index (int): Background class index.
              mine_neg_ratio (float): Fraction of priors used as the negative
                  quota when the ratio-based quota is not positive.

          Returns:
              Tensor: Boolean mask with the layout of conf_loss; True for
              positives and for the selected hard negatives.
          """
          pos = (targets_label != bg_index).astype(conf_loss.dtype)
          num_pos = pos.sum(axis=1, keepdim=True)
          neg = (targets_label == bg_index).astype(conf_loss.dtype)

          # Rank only the negatives: positives get zero loss here, and the
          # ranking must not take part in back-propagation.
          conf_loss = conf_loss.detach() * neg
          loss_idx = conf_loss.argsort(axis=1, descending=True)
          # argsort of an argsort: rank of each prior in descending loss order.
          idx_rank = loss_idx.argsort(axis=1)

          # Per-sample negative quota, computed for the whole batch at once.
          # This avoids a Python loop with one tensor truth-test (host sync)
          # per sample and keeps the fallback in the same dtype as the loss.
          num_negs = paddle.clip(num_pos * self.neg_pos_ratio, max=pos.shape[1])
          fallback = paddle.full_like(num_negs, pos.shape[1] * mine_neg_ratio)
          num_negs = paddle.where(num_negs > 0, num_negs, fallback)
          num_negs = num_negs.expand_as(idx_rank)

          neg_mask = (idx_rank < num_negs).astype(conf_loss.dtype)
          return (neg_mask + pos).astype('bool')

      def forward(self, boxes, scores, gt_bbox, gt_label, prior_boxes):
          """
          Compute the SSD loss for one batch.

          Args:
              boxes (list[Tensor]): Predicted box encodings; concatenated
                  along axis 1.
              scores (list[Tensor]): Predicted class scores; concatenated
                  along axis 1. The last class index is used as background.
              gt_bbox (Tensor): Ground-truth boxes, [B, N, 4].
              gt_label (Tensor): Ground-truth labels; a trailing axis is added
                  and the values are cast to int64.
              prior_boxes (list[Tensor]): Prior boxes; concatenated along
                  axis 0.

          Returns:
              Tensor: (conf_loss + loc_loss) / max(number of positives, 1).
          """
          boxes = paddle.concat(boxes, axis=1)
          scores = paddle.concat(scores, axis=1)
          gt_label = gt_label.unsqueeze(-1).astype('int64')
          prior_boxes = paddle.concat(prior_boxes, axis=0)
          bg_index = scores.shape[-1] - 1

          # Match bbox and get targets.
          targets_bbox, targets_label = \
              self._bipartite_match_for_batch(gt_bbox, gt_label, prior_boxes, bg_index)
          targets_bbox.stop_gradient = True
          targets_label.stop_gradient = True

          # Positive priors are those not assigned to background.
          pos_mask = targets_label != bg_index

          # Compute regression loss.
          # Select positive samples.
          bbox_mask = paddle.tile(pos_mask, [1, 1, 4])
          if bbox_mask.astype(boxes.dtype).sum() > 0:
              location = paddle.masked_select(boxes, bbox_mask)
              targets_bbox = paddle.masked_select(targets_bbox, bbox_mask)
              loc_loss = F.smooth_l1_loss(location, targets_bbox, reduction='sum')
              loc_loss = loc_loss * self.loc_loss_weight
          else:
              # No positives in the whole batch: nothing to regress.
              loc_loss = paddle.zeros([1])

          # Compute confidence loss.
          conf_loss = F.cross_entropy(scores, targets_label, reduction="none")
          # Mining hard examples.
          label_mask = self._mine_hard_example(
              conf_loss.squeeze(-1), targets_label.squeeze(-1), bg_index)
          conf_loss = paddle.masked_select(conf_loss, label_mask.unsqueeze(-1))
          conf_loss = conf_loss.sum() * self.conf_loss_weight

          # Compute overall weighted loss.
          # Clipped to 1 so a batch without positives does not divide by zero.
          normalizer = pos_mask.astype('float32').sum().clip(min=1)
          loss = (conf_loss + loc_loss) / normalizer

          return loss