yolo_loss.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from ..bbox_utils import decode_yolo, xywh2xyxy, iou_similarity
  22. __all__ = ['YOLOv3Loss']
  23. def bbox_transform(pbox, anchor, downsample):
  24. pbox = decode_yolo(pbox, anchor, downsample)
  25. pbox = xywh2xyxy(pbox)
  26. return pbox
  27. @register
  28. class YOLOv3Loss(nn.Layer):
  29. __inject__ = ['iou_loss', 'iou_aware_loss']
  30. __shared__ = ['num_classes']
  31. def __init__(self,
  32. num_classes=80,
  33. ignore_thresh=0.7,
  34. label_smooth=False,
  35. downsample=[32, 16, 8],
  36. scale_x_y=1.,
  37. iou_loss=None,
  38. iou_aware_loss=None):
  39. """
  40. YOLOv3Loss layer
  41. Args:
  42. num_calsses (int): number of foreground classes
  43. ignore_thresh (float): threshold to ignore confidence loss
  44. label_smooth (bool): whether to use label smoothing
  45. downsample (list): downsample ratio for each detection block
  46. scale_x_y (float): scale_x_y factor
  47. iou_loss (object): IoULoss instance
  48. iou_aware_loss (object): IouAwareLoss instance
  49. """
  50. super(YOLOv3Loss, self).__init__()
  51. self.num_classes = num_classes
  52. self.ignore_thresh = ignore_thresh
  53. self.label_smooth = label_smooth
  54. self.downsample = downsample
  55. self.scale_x_y = scale_x_y
  56. self.iou_loss = iou_loss
  57. self.iou_aware_loss = iou_aware_loss
  58. self.distill_pairs = []
  59. def obj_loss(self, pbox, gbox, pobj, tobj, anchor, downsample):
  60. # pbox
  61. pbox = decode_yolo(pbox, anchor, downsample)
  62. pbox = xywh2xyxy(pbox)
  63. pbox = paddle.concat(pbox, axis=-1)
  64. b = pbox.shape[0]
  65. pbox = pbox.reshape((b, -1, 4))
  66. # gbox
  67. gxy = gbox[:, :, 0:2] - gbox[:, :, 2:4] * 0.5
  68. gwh = gbox[:, :, 0:2] + gbox[:, :, 2:4] * 0.5
  69. gbox = paddle.concat([gxy, gwh], axis=-1)
  70. iou = iou_similarity(pbox, gbox)
  71. iou.stop_gradient = True
  72. iou_max = iou.max(2) # [N, M1]
  73. iou_mask = paddle.cast(iou_max <= self.ignore_thresh, dtype=pbox.dtype)
  74. iou_mask.stop_gradient = True
  75. pobj = pobj.reshape((b, -1))
  76. tobj = tobj.reshape((b, -1))
  77. obj_mask = paddle.cast(tobj > 0, dtype=pbox.dtype)
  78. obj_mask.stop_gradient = True
  79. loss_obj = F.binary_cross_entropy_with_logits(
  80. pobj, obj_mask, reduction='none')
  81. loss_obj_pos = (loss_obj * tobj)
  82. loss_obj_neg = (loss_obj * (1 - obj_mask) * iou_mask)
  83. return loss_obj_pos + loss_obj_neg
  84. def cls_loss(self, pcls, tcls):
  85. if self.label_smooth:
  86. delta = min(1. / self.num_classes, 1. / 40)
  87. pos, neg = 1 - delta, delta
  88. # 1 for positive, 0 for negative
  89. tcls = pos * paddle.cast(
  90. tcls > 0., dtype=tcls.dtype) + neg * paddle.cast(
  91. tcls <= 0., dtype=tcls.dtype)
  92. loss_cls = F.binary_cross_entropy_with_logits(
  93. pcls, tcls, reduction='none')
  94. return loss_cls
  95. def yolov3_loss(self, p, t, gt_box, anchor, downsample, scale=1.,
  96. eps=1e-10):
  97. na = len(anchor)
  98. b, c, h, w = p.shape
  99. if self.iou_aware_loss:
  100. ioup, p = p[:, 0:na, :, :], p[:, na:, :, :]
  101. ioup = ioup.unsqueeze(-1)
  102. p = p.reshape((b, na, -1, h, w)).transpose((0, 1, 3, 4, 2))
  103. x, y = p[:, :, :, :, 0:1], p[:, :, :, :, 1:2]
  104. w, h = p[:, :, :, :, 2:3], p[:, :, :, :, 3:4]
  105. obj, pcls = p[:, :, :, :, 4:5], p[:, :, :, :, 5:]
  106. self.distill_pairs.append([x, y, w, h, obj, pcls])
  107. t = t.transpose((0, 1, 3, 4, 2))
  108. tx, ty = t[:, :, :, :, 0:1], t[:, :, :, :, 1:2]
  109. tw, th = t[:, :, :, :, 2:3], t[:, :, :, :, 3:4]
  110. tscale = t[:, :, :, :, 4:5]
  111. tobj, tcls = t[:, :, :, :, 5:6], t[:, :, :, :, 6:]
  112. tscale_obj = tscale * tobj
  113. loss = dict()
  114. x = scale * F.sigmoid(x) - 0.5 * (scale - 1.)
  115. y = scale * F.sigmoid(y) - 0.5 * (scale - 1.)
  116. if abs(scale - 1.) < eps:
  117. loss_x = F.binary_cross_entropy(x, tx, reduction='none')
  118. loss_y = F.binary_cross_entropy(y, ty, reduction='none')
  119. loss_xy = tscale_obj * (loss_x + loss_y)
  120. else:
  121. loss_x = paddle.abs(x - tx)
  122. loss_y = paddle.abs(y - ty)
  123. loss_xy = tscale_obj * (loss_x + loss_y)
  124. loss_xy = loss_xy.sum([1, 2, 3, 4]).mean()
  125. loss_w = paddle.abs(w - tw)
  126. loss_h = paddle.abs(h - th)
  127. loss_wh = tscale_obj * (loss_w + loss_h)
  128. loss_wh = loss_wh.sum([1, 2, 3, 4]).mean()
  129. loss['loss_xy'] = loss_xy
  130. loss['loss_wh'] = loss_wh
  131. if self.iou_loss is not None:
  132. # warn: do not modify x, y, w, h in place
  133. box, tbox = [x, y, w, h], [tx, ty, tw, th]
  134. pbox = bbox_transform(box, anchor, downsample)
  135. gbox = bbox_transform(tbox, anchor, downsample)
  136. loss_iou = self.iou_loss(pbox, gbox)
  137. loss_iou = loss_iou * tscale_obj
  138. loss_iou = loss_iou.sum([1, 2, 3, 4]).mean()
  139. loss['loss_iou'] = loss_iou
  140. if self.iou_aware_loss is not None:
  141. box, tbox = [x, y, w, h], [tx, ty, tw, th]
  142. pbox = bbox_transform(box, anchor, downsample)
  143. gbox = bbox_transform(tbox, anchor, downsample)
  144. loss_iou_aware = self.iou_aware_loss(ioup, pbox, gbox)
  145. loss_iou_aware = loss_iou_aware * tobj
  146. loss_iou_aware = loss_iou_aware.sum([1, 2, 3, 4]).mean()
  147. loss['loss_iou_aware'] = loss_iou_aware
  148. box = [x, y, w, h]
  149. loss_obj = self.obj_loss(box, gt_box, obj, tobj, anchor, downsample)
  150. loss_obj = loss_obj.sum(-1).mean()
  151. loss['loss_obj'] = loss_obj
  152. loss_cls = self.cls_loss(pcls, tcls) * tobj
  153. loss_cls = loss_cls.sum([1, 2, 3, 4]).mean()
  154. loss['loss_cls'] = loss_cls
  155. return loss
  156. def forward(self, inputs, targets, anchors):
  157. np = len(inputs)
  158. gt_targets = [targets['target{}'.format(i)] for i in range(np)]
  159. gt_box = targets['gt_bbox']
  160. yolo_losses = dict()
  161. self.distill_pairs.clear()
  162. for x, t, anchor, downsample in zip(inputs, gt_targets, anchors,
  163. self.downsample):
  164. yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample,
  165. self.scale_x_y)
  166. for k, v in yolo_loss.items():
  167. if k in yolo_losses:
  168. yolo_losses[k] += v
  169. else:
  170. yolo_losses[k] = v
  171. loss = 0
  172. for k, v in yolo_losses.items():
  173. loss += v
  174. yolo_losses['loss'] = loss
  175. return yolo_losses