jde_loss.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register

__all__ = ['JDEDetectionLoss', 'JDEEmbeddingLoss', 'JDELoss']


@register
class JDEDetectionLoss(nn.Layer):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, for_mot=True):
        super(JDEDetectionLoss, self).__init__()
        self.num_classes = num_classes
        self.for_mot = for_mot

    def det_loss(self, p_det, anchor, t_conf, t_box):
        pshape = paddle.shape(p_det)
        pshape.stop_gradient = True
        nB, nGh, nGw = pshape[0], pshape[-2], pshape[-1]
        nA = len(anchor)
        # [nB, nA * (num_classes + 5), nGh, nGw]
        # -> [nB, nA, nGh, nGw, num_classes + 5]
        p_det = paddle.reshape(
            p_det, [nB, nA, self.num_classes + 5, nGh, nGw]).transpose(
                (0, 1, 3, 4, 2))

        # 1. loss_conf: cross_entropy on the bg/fg logits (channels 4:6)
        p_conf = p_det[:, :, :, :, 4:6]
        p_conf_flatten = paddle.reshape(p_conf, [-1, 2])
        t_conf_flatten = t_conf.flatten()
        t_conf_flatten = paddle.cast(t_conf_flatten, dtype="int64")
        t_conf_flatten.stop_gradient = True
        loss_conf = F.cross_entropy(
            p_conf_flatten, t_conf_flatten, ignore_index=-1, reduction='mean')
        loss_conf.stop_gradient = False

        # 2. loss_box: smooth_l1_loss on the box deltas (channels 0:4),
        # computed only over foreground cells (t_conf > 0)
        p_box = p_det[:, :, :, :, :4]
        p_box_flatten = paddle.reshape(p_box, [-1, 4])
        t_box_flatten = paddle.reshape(t_box, [-1, 4])
        fg_inds = paddle.nonzero(t_conf_flatten > 0).flatten()
        if fg_inds.numel() > 0:
            reg_delta = paddle.gather(p_box_flatten, fg_inds)
            reg_target = paddle.gather(t_box_flatten, fg_inds)
        else:
            # no foreground cells: fall back to a dummy zero box so the
            # loss stays well-defined and backward does not break
            reg_delta = paddle.to_tensor([0, 0, 0, 0], dtype='float32')
            reg_delta.stop_gradient = False
            reg_target = paddle.to_tensor([0, 0, 0, 0], dtype='float32')
            reg_target.stop_gradient = True
        loss_box = F.smooth_l1_loss(
            reg_delta, reg_target, reduction='mean', delta=1.0)
        loss_box.stop_gradient = False

        return loss_conf, loss_box
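
    # Shape note for det_loss (illustrative numbers, not from the original
    # file): with num_classes=1 and nA=4 anchors, the raw head output p_det
    # is [nB, 24, nGh, nGw]; after the reshape/transpose above it becomes
    # [nB, 4, nGh, nGw, 6], with 4 box deltas followed by 2 bg/fg logits
    # along the last axis.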

    def forward(self, det_outs, targets, anchors):
        """
        Args:
            det_outs (list[Tensor]): outputs of the detection head, one 4-D
                Tensor of shape [N, C, H, W] per FPN level.
            targets (dict): contains 'im_id', 'gt_bbox', 'gt_ide', 'image',
                'im_shape', 'scale_factor', plus the 'tbox', 'tconf' and
                'tide' targets of each FPN level.
            anchors (list[list]): anchor settings of the JDE model, an N x M
                nested list where N is the number of anchor levels (FPN
                levels) and M is the number of anchor scales per level.
        """
        assert len(det_outs) == len(anchors)
        loss_confs = []
        loss_boxes = []
        for i, (p_det, anchor) in enumerate(zip(det_outs, anchors)):
            t_conf = targets['tconf{}'.format(i)]
            t_box = targets['tbox{}'.format(i)]

            loss_conf, loss_box = self.det_loss(p_det, anchor, t_conf, t_box)
            loss_confs.append(loss_conf)
            loss_boxes.append(loss_box)
        if self.for_mot:
            # keep per-level losses so JDELoss can reweight them
            return {'loss_confs': loss_confs, 'loss_boxes': loss_boxes}
        else:
            jde_conf_losses = sum(loss_confs)
            jde_box_losses = sum(loss_boxes)
            jde_det_losses = {
                "loss_conf": jde_conf_losses,
                "loss_box": jde_box_losses,
                "loss": jde_conf_losses + jde_box_losses,
            }
            return jde_det_losses
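

# Illustrative call pattern for JDEDetectionLoss (a hedged sketch, not part
# of the original module): with three FPN levels and four anchors per level,
# `anchors` is a 3 x 4 nested list of [w, h] pairs and `targets` carries the
# per-level keys 'tconf0'/'tbox0' ... 'tconf2'/'tbox2' produced by the JDE
# target assigner. Names and shapes below are assumptions for illustration:
#
#   det_loss_layer = JDEDetectionLoss(num_classes=1, for_mot=True)
#   det_outs = [p8, p16, p32]       # each [N, nA * 6, H_l, W_l]
#   losses = det_loss_layer(det_outs, targets, anchors)
#   # -> {'loss_confs': [...], 'loss_boxes': [...]}, one entry per level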


@register
class JDEEmbeddingLoss(nn.Layer):
    def __init__(self, ):
        super(JDEEmbeddingLoss, self).__init__()
        # phony parameter so that a batch without any identity target can
        # still return a differentiable zero loss
        self.phony = self.create_parameter(shape=[1], dtype="float32")

    def emb_loss(self, p_ide, t_conf, t_ide, emb_scale, classifier):
        emb_dim = p_ide.shape[1]
        p_ide = p_ide.transpose((0, 2, 3, 1))
        p_ide_flatten = paddle.reshape(p_ide, [-1, emb_dim])
        mask = t_conf > 0
        mask = paddle.cast(mask, dtype="int64")
        mask.stop_gradient = True
        emb_mask = mask.max(1).flatten()
        emb_mask_inds = paddle.nonzero(emb_mask > 0).flatten()
        emb_mask_inds.stop_gradient = True
        # use max(1) over anchors to decide the id, TODO: more reasonable strategy
        t_ide_flatten = t_ide.max(1).flatten()
        t_ide_flatten = paddle.cast(t_ide_flatten, dtype="int64")
        valid_inds = paddle.nonzero(t_ide_flatten != -1).flatten()

        if emb_mask_inds.numel() == 0 or valid_inds.numel() == 0:
            # loss_ide = paddle.to_tensor([0]) would raise an error in
            # gradient backward, so use the phony parameter instead
            loss_ide = self.phony * 0
        else:
            embedding = paddle.gather(p_ide_flatten, emb_mask_inds)
            embedding = emb_scale * F.normalize(embedding)
            logits = classifier(embedding)

            ide_target = paddle.gather(t_ide_flatten, emb_mask_inds)

            loss_ide = F.cross_entropy(
                logits, ide_target, ignore_index=-1, reduction='mean')
        loss_ide.stop_gradient = False

        return loss_ide

    def forward(self, ide_outs, targets, emb_scale, classifier):
        loss_ides = []
        for i, p_ide in enumerate(ide_outs):
            t_conf = targets['tconf{}'.format(i)]
            t_ide = targets['tide{}'.format(i)]

            loss_ide = self.emb_loss(p_ide, t_conf, t_ide, emb_scale,
                                     classifier)
            loss_ides.append(loss_ide)
        return loss_ides
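

# Note on the arguments (a hedged sketch, not from this file): `classifier`
# is the shared identity-classification head, typically an nn.Linear mapping
# the emb_dim-dimensional embedding to the number of identities nID, and in
# the reference JDE implementation `emb_scale` is computed as
# math.sqrt(2) * math.log(nID - 1) for nID > 1. For example:
#
#   classifier = nn.Linear(emb_dim, nID)
#   emb_scale = math.sqrt(2) * math.log(nID - 1)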


@register
class JDELoss(nn.Layer):
    def __init__(self):
        super(JDELoss, self).__init__()

    def forward(self, loss_confs, loss_boxes, loss_ides, loss_params_cls,
                loss_params_reg, loss_params_ide, targets):
        assert len(loss_confs) == len(loss_boxes) == len(loss_ides)
        assert len(loss_params_cls) == len(loss_params_reg) == len(
            loss_params_ide)
        assert len(loss_confs) == len(loss_params_cls)

        # average number of non-empty gt boxes per image, kept for logging
        batchsize = targets['gt_bbox'].shape[0]
        nTargets = paddle.nonzero(paddle.sum(targets['gt_bbox'], axis=2)).shape[
            0] / batchsize
        nTargets = paddle.to_tensor(nTargets, dtype='float32')
        nTargets.stop_gradient = True

        jde_losses = []
        for i, (loss_conf, loss_box, loss_ide, l_conf_p, l_box_p,
                l_ide_p) in enumerate(
                    zip(loss_confs, loss_boxes, loss_ides, loss_params_cls,
                        loss_params_reg, loss_params_ide)):
            # combine the three losses of each level with learnable weights
            jde_loss = l_conf_p(loss_conf) + l_box_p(loss_box) + l_ide_p(
                loss_ide)
            jde_losses.append(jde_loss)

        loss_all = {
            "loss_conf": sum(loss_confs),
            "loss_box": sum(loss_boxes),
            "loss_ide": sum(loss_ides),
            "loss": sum(jde_losses),
            "nTargets": nTargets,
        }
        return loss_all
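

# The loss_params_* passed to JDELoss are per-level learnable weighting
# layers. A minimal sketch of one plausible implementation is given below,
# mirroring the task-uncertainty loss balancing used by the JDE paper
# ("Towards Real-Time Multi-Object Tracking"); the actual layer lives in the
# JDE embedding head, so treat this as an assumption for illustration:
#
#   class LossParam(nn.Layer):
#       def __init__(self, init_value=0.):
#           super(LossParam, self).__init__()
#           self.loss_param = self.create_parameter(
#               shape=[1], dtype="float32",
#               default_initializer=nn.initializer.Constant(init_value))
#
#       def forward(self, inputs):
#           # loss * exp(-s) + s, halved; s is learned per loss term
#           return (paddle.exp(-self.loss_param) * inputs +
#                   self.loss_param) * 0.5


# Hedged smoke test (not part of the original module): it exercises the two
# per-level losses on random tensors with assumed sizes -- one FPN level,
# nA=4 anchors, num_classes=1 (so the head emits nA * (1 + 5) = 24 channels),
# a 64-dim embedding and a hypothetical pool of nID=100 identities. The
# anchor sizes below are placeholders, not the JDE defaults.
if __name__ == '__main__':
    import math

    nB, nA, nGh, nGw, emb_dim, nID = 2, 4, 19, 34, 64, 100
    anchor = [[85, 255], [120, 320], [170, 380], [340, 400]]  # placeholder

    p_det = paddle.rand([nB, nA * 6, nGh, nGw])
    t_conf = paddle.randint(-1, 2, [nB, nA, nGh, nGw])  # -1 ignore, 0 bg, 1 fg
    t_box = paddle.rand([nB, nA, nGh, nGw, 4])
    loss_conf, loss_box = JDEDetectionLoss(num_classes=1).det_loss(
        p_det, anchor, t_conf, t_box)

    p_ide = paddle.rand([nB, emb_dim, nGh, nGw])
    t_ide = paddle.randint(-1, nID, [nB, nA, nGh, nGw])
    classifier = nn.Linear(emb_dim, nID)
    emb_scale = math.sqrt(2) * math.log(nID - 1)
    loss_ide = JDEEmbeddingLoss().emb_loss(p_ide, t_conf, t_ide, emb_scale,
                                           classifier)
    print(float(loss_conf), float(loss_box), float(loss_ide))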