iou_loss.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from ppdet.core.workspace import register, serializable
  20. from ..bbox_utils import bbox_iou
  21. __all__ = ['IouLoss', 'GIoULoss', 'DIouLoss']
  22. @register
  23. @serializable
  24. class IouLoss(object):
  25. """
  26. iou loss, see https://arxiv.org/abs/1908.03851
  27. loss = 1.0 - iou * iou
  28. Args:
  29. loss_weight (float): iou loss weight, default is 2.5
  30. max_height (int): max height of input to support random shape input
  31. max_width (int): max width of input to support random shape input
  32. ciou_term (bool): whether to add ciou_term
  33. loss_square (bool): whether to square the iou term
  34. """
  35. def __init__(self,
  36. loss_weight=2.5,
  37. giou=False,
  38. diou=False,
  39. ciou=False,
  40. loss_square=True):
  41. self.loss_weight = loss_weight
  42. self.giou = giou
  43. self.diou = diou
  44. self.ciou = ciou
  45. self.loss_square = loss_square
  46. def __call__(self, pbox, gbox):
  47. iou = bbox_iou(
  48. pbox, gbox, giou=self.giou, diou=self.diou, ciou=self.ciou)
  49. if self.loss_square:
  50. loss_iou = 1 - iou * iou
  51. else:
  52. loss_iou = 1 - iou
  53. loss_iou = loss_iou * self.loss_weight
  54. return loss_iou
  55. @register
  56. @serializable
  57. class GIoULoss(object):
  58. """
  59. Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630
  60. Args:
  61. loss_weight (float): giou loss weight, default as 1
  62. eps (float): epsilon to avoid divide by zero, default as 1e-10
  63. reduction (string): Options are "none", "mean" and "sum". default as none
  64. """
  65. def __init__(self, loss_weight=1., eps=1e-10, reduction='none'):
  66. self.loss_weight = loss_weight
  67. self.eps = eps
  68. assert reduction in ('none', 'mean', 'sum')
  69. self.reduction = reduction
  70. def bbox_overlap(self, box1, box2, eps=1e-10):
  71. """calculate the iou of box1 and box2
  72. Args:
  73. box1 (Tensor): box1 with the shape (..., 4)
  74. box2 (Tensor): box1 with the shape (..., 4)
  75. eps (float): epsilon to avoid divide by zero
  76. Return:
  77. iou (Tensor): iou of box1 and box2
  78. overlap (Tensor): overlap of box1 and box2
  79. union (Tensor): union of box1 and box2
  80. """
  81. x1, y1, x2, y2 = box1
  82. x1g, y1g, x2g, y2g = box2
  83. xkis1 = paddle.maximum(x1, x1g)
  84. ykis1 = paddle.maximum(y1, y1g)
  85. xkis2 = paddle.minimum(x2, x2g)
  86. ykis2 = paddle.minimum(y2, y2g)
  87. w_inter = (xkis2 - xkis1).clip(0)
  88. h_inter = (ykis2 - ykis1).clip(0)
  89. overlap = w_inter * h_inter
  90. area1 = (x2 - x1) * (y2 - y1)
  91. area2 = (x2g - x1g) * (y2g - y1g)
  92. union = area1 + area2 - overlap + eps
  93. iou = overlap / union
  94. return iou, overlap, union
  95. def __call__(self, pbox, gbox, iou_weight=1., loc_reweight=None):
  96. x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1)
  97. x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1)
  98. box1 = [x1, y1, x2, y2]
  99. box2 = [x1g, y1g, x2g, y2g]
  100. iou, overlap, union = self.bbox_overlap(box1, box2, self.eps)
  101. xc1 = paddle.minimum(x1, x1g)
  102. yc1 = paddle.minimum(y1, y1g)
  103. xc2 = paddle.maximum(x2, x2g)
  104. yc2 = paddle.maximum(y2, y2g)
  105. area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
  106. miou = iou - ((area_c - union) / area_c)
  107. if loc_reweight is not None:
  108. loc_reweight = paddle.reshape(loc_reweight, shape=(-1, 1))
  109. loc_thresh = 0.9
  110. giou = 1 - (1 - loc_thresh
  111. ) * miou - loc_thresh * miou * loc_reweight
  112. else:
  113. giou = 1 - miou
  114. if self.reduction == 'none':
  115. loss = giou
  116. elif self.reduction == 'sum':
  117. loss = paddle.sum(giou * iou_weight)
  118. else:
  119. loss = paddle.mean(giou * iou_weight)
  120. return loss * self.loss_weight
  121. @register
  122. @serializable
  123. class DIouLoss(GIoULoss):
  124. """
  125. Distance-IoU Loss, see https://arxiv.org/abs/1911.08287
  126. Args:
  127. loss_weight (float): giou loss weight, default as 1
  128. eps (float): epsilon to avoid divide by zero, default as 1e-10
  129. use_complete_iou_loss (bool): whether to use complete iou loss
  130. """
  131. def __init__(self, loss_weight=1., eps=1e-10, use_complete_iou_loss=True):
  132. super(DIouLoss, self).__init__(loss_weight=loss_weight, eps=eps)
  133. self.use_complete_iou_loss = use_complete_iou_loss
  134. def __call__(self, pbox, gbox, iou_weight=1.):
  135. x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1)
  136. x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1)
  137. cx = (x1 + x2) / 2
  138. cy = (y1 + y2) / 2
  139. w = x2 - x1
  140. h = y2 - y1
  141. cxg = (x1g + x2g) / 2
  142. cyg = (y1g + y2g) / 2
  143. wg = x2g - x1g
  144. hg = y2g - y1g
  145. x2 = paddle.maximum(x1, x2)
  146. y2 = paddle.maximum(y1, y2)
  147. # A and B
  148. xkis1 = paddle.maximum(x1, x1g)
  149. ykis1 = paddle.maximum(y1, y1g)
  150. xkis2 = paddle.minimum(x2, x2g)
  151. ykis2 = paddle.minimum(y2, y2g)
  152. # A or B
  153. xc1 = paddle.minimum(x1, x1g)
  154. yc1 = paddle.minimum(y1, y1g)
  155. xc2 = paddle.maximum(x2, x2g)
  156. yc2 = paddle.maximum(y2, y2g)
  157. intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
  158. intsctk = intsctk * paddle.greater_than(
  159. xkis2, xkis1) * paddle.greater_than(ykis2, ykis1)
  160. unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
  161. ) - intsctk + self.eps
  162. iouk = intsctk / unionk
  163. # DIOU term
  164. dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
  165. dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
  166. diou_term = (dist_intersection + self.eps) / (dist_union + self.eps)
  167. # CIOU term
  168. ciou_term = 0
  169. if self.use_complete_iou_loss:
  170. ar_gt = wg / hg
  171. ar_pred = w / h
  172. arctan = paddle.atan(ar_gt) - paddle.atan(ar_pred)
  173. ar_loss = 4. / np.pi / np.pi * arctan * arctan
  174. alpha = ar_loss / (1 - iouk + ar_loss + self.eps)
  175. alpha.stop_gradient = True
  176. ciou_term = alpha * ar_loss
  177. diou = paddle.mean((1 - iouk + ciou_term + diou_term) * iou_weight)
  178. return diou * self.loss_weight