# detr_loss.py
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. from .iou_loss import GIoULoss
  22. from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
  23. __all__ = ['DETRLoss']
  24. @register
  25. class DETRLoss(nn.Layer):
  26. __shared__ = ['num_classes', 'use_focal_loss']
  27. __inject__ = ['matcher']
  28. def __init__(self,
  29. num_classes=80,
  30. matcher='HungarianMatcher',
  31. loss_coeff={
  32. 'class': 1,
  33. 'bbox': 5,
  34. 'giou': 2,
  35. 'no_object': 0.1,
  36. 'mask': 1,
  37. 'dice': 1
  38. },
  39. aux_loss=True,
  40. use_focal_loss=False):
  41. r"""
  42. Args:
  43. num_classes (int): The number of classes.
  44. matcher (HungarianMatcher): It computes an assignment between the targets
  45. and the predictions of the network.
  46. loss_coeff (dict): The coefficient of loss.
  47. aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
  48. use_focal_loss (bool): Use focal loss or not.
  49. """
  50. super(DETRLoss, self).__init__()
  51. self.num_classes = num_classes
  52. self.matcher = matcher
  53. self.loss_coeff = loss_coeff
  54. self.aux_loss = aux_loss
  55. self.use_focal_loss = use_focal_loss
  56. if not self.use_focal_loss:
  57. self.loss_coeff['class'] = paddle.full([num_classes + 1],
  58. loss_coeff['class'])
  59. self.loss_coeff['class'][-1] = loss_coeff['no_object']
  60. self.giou_loss = GIoULoss()
  61. def _get_loss_class(self, logits, gt_class, match_indices, bg_index,
  62. num_gts):
  63. # logits: [b, query, num_classes], gt_class: list[[n, 1]]
  64. target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
  65. bs, num_query_objects = target_label.shape
  66. if sum(len(a) for a in gt_class) > 0:
  67. index, updates = self._get_index_updates(num_query_objects,
  68. gt_class, match_indices)
  69. target_label = paddle.scatter(
  70. target_label.reshape([-1, 1]), index, updates.astype('int64'))
  71. target_label = target_label.reshape([bs, num_query_objects])
  72. if self.use_focal_loss:
  73. target_label = F.one_hot(target_label,
  74. self.num_classes + 1)[..., :-1]
  75. return {
  76. 'loss_class': self.loss_coeff['class'] * sigmoid_focal_loss(
  77. logits, target_label, num_gts / num_query_objects)
  78. if self.use_focal_loss else F.cross_entropy(
  79. logits, target_label, weight=self.loss_coeff['class'])
  80. }
  81. def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts):
  82. # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
  83. loss = dict()
  84. if sum(len(a) for a in gt_bbox) == 0:
  85. loss['loss_bbox'] = paddle.to_tensor([0.])
  86. loss['loss_giou'] = paddle.to_tensor([0.])
  87. return loss
  88. src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
  89. match_indices)
  90. loss['loss_bbox'] = self.loss_coeff['bbox'] * F.l1_loss(
  91. src_bbox, target_bbox, reduction='sum') / num_gts
  92. loss['loss_giou'] = self.giou_loss(
  93. bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
  94. loss['loss_giou'] = loss['loss_giou'].sum() / num_gts
  95. loss['loss_giou'] = self.loss_coeff['giou'] * loss['loss_giou']
  96. return loss
  97. def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts):
  98. # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
  99. loss = dict()
  100. if sum(len(a) for a in gt_mask) == 0:
  101. loss['loss_mask'] = paddle.to_tensor([0.])
  102. loss['loss_dice'] = paddle.to_tensor([0.])
  103. return loss
  104. src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
  105. match_indices)
  106. src_masks = F.interpolate(
  107. src_masks.unsqueeze(0),
  108. size=target_masks.shape[-2:],
  109. mode="bilinear")[0]
  110. loss['loss_mask'] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
  111. src_masks,
  112. target_masks,
  113. paddle.to_tensor(
  114. [num_gts], dtype='float32'))
  115. loss['loss_dice'] = self.loss_coeff['dice'] * self._dice_loss(
  116. src_masks, target_masks, num_gts)
  117. return loss
  118. def _dice_loss(self, inputs, targets, num_gts):
  119. inputs = F.sigmoid(inputs)
  120. inputs = inputs.flatten(1)
  121. targets = targets.flatten(1)
  122. numerator = 2 * (inputs * targets).sum(1)
  123. denominator = inputs.sum(-1) + targets.sum(-1)
  124. loss = 1 - (numerator + 1) / (denominator + 1)
  125. return loss.sum() / num_gts
  126. def _get_loss_aux(self, boxes, logits, gt_bbox, gt_class, bg_index,
  127. num_gts):
  128. loss_class = []
  129. loss_bbox = []
  130. loss_giou = []
  131. for aux_boxes, aux_logits in zip(boxes, logits):
  132. match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
  133. gt_class)
  134. loss_class.append(
  135. self._get_loss_class(aux_logits, gt_class, match_indices,
  136. bg_index, num_gts)['loss_class'])
  137. loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
  138. num_gts)
  139. loss_bbox.append(loss_['loss_bbox'])
  140. loss_giou.append(loss_['loss_giou'])
  141. loss = {
  142. 'loss_class_aux': paddle.add_n(loss_class),
  143. 'loss_bbox_aux': paddle.add_n(loss_bbox),
  144. 'loss_giou_aux': paddle.add_n(loss_giou)
  145. }
  146. return loss
  147. def _get_index_updates(self, num_query_objects, target, match_indices):
  148. batch_idx = paddle.concat([
  149. paddle.full_like(src, i) for i, (src, _) in enumerate(match_indices)
  150. ])
  151. src_idx = paddle.concat([src for (src, _) in match_indices])
  152. src_idx += (batch_idx * num_query_objects)
  153. target_assign = paddle.concat([
  154. paddle.gather(
  155. t, dst, axis=0) for t, (_, dst) in zip(target, match_indices)
  156. ])
  157. return src_idx, target_assign
  158. def _get_src_target_assign(self, src, target, match_indices):
  159. src_assign = paddle.concat([
  160. paddle.gather(
  161. t, I, axis=0) if len(I) > 0 else paddle.zeros([0, t.shape[-1]])
  162. for t, (I, _) in zip(src, match_indices)
  163. ])
  164. target_assign = paddle.concat([
  165. paddle.gather(
  166. t, J, axis=0) if len(J) > 0 else paddle.zeros([0, t.shape[-1]])
  167. for t, (_, J) in zip(target, match_indices)
  168. ])
  169. return src_assign, target_assign
  170. def forward(self,
  171. boxes,
  172. logits,
  173. gt_bbox,
  174. gt_class,
  175. masks=None,
  176. gt_mask=None):
  177. r"""
  178. Args:
  179. boxes (Tensor): [l, b, query, 4]
  180. logits (Tensor): [l, b, query, num_classes]
  181. gt_bbox (List(Tensor)): list[[n, 4]]
  182. gt_class (List(Tensor)): list[[n, 1]]
  183. masks (Tensor, optional): [b, query, h, w]
  184. gt_mask (List(Tensor), optional): list[[n, H, W]]
  185. """
  186. match_indices = self.matcher(boxes[-1].detach(), logits[-1].detach(),
  187. gt_bbox, gt_class)
  188. num_gts = sum(len(a) for a in gt_bbox)
  189. try:
  190. # TODO: Paddle does not have a "paddle.distributed.is_initialized()"
  191. num_gts = paddle.to_tensor([num_gts], dtype=paddle.float32)
  192. paddle.distributed.all_reduce(num_gts)
  193. num_gts = paddle.clip(
  194. num_gts / paddle.distributed.get_world_size(), min=1).item()
  195. except:
  196. num_gts = max(num_gts.item(), 1)
  197. total_loss = dict()
  198. total_loss.update(
  199. self._get_loss_class(logits[-1], gt_class, match_indices,
  200. self.num_classes, num_gts))
  201. total_loss.update(
  202. self._get_loss_bbox(boxes[-1], gt_bbox, match_indices, num_gts))
  203. if masks is not None and gt_mask is not None:
  204. total_loss.update(
  205. self._get_loss_mask(masks, gt_mask, match_indices, num_gts))
  206. if self.aux_loss:
  207. total_loss.update(
  208. self._get_loss_aux(boxes[:-1], logits[:-1], gt_bbox, gt_class,
  209. self.num_classes, num_gts))
  210. return total_loss