123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from ppdet.core.workspace import register
- from .iou_loss import GIoULoss
- from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
- __all__ = ['DETRLoss']
@register
class DETRLoss(nn.Layer):
    """Loss layer combining classification, box and (optional) mask terms.

    Predictions are assigned to ground-truth objects by the injected
    ``matcher``; the individual loss terms are then computed on the matched
    pairs and scaled by the entries of ``loss_coeff``.
    """
    __shared__ = ['num_classes', 'use_focal_loss']
    __inject__ = ['matcher']

    def __init__(self,
                 num_classes=80,
                 matcher='HungarianMatcher',
                 loss_coeff={
                     'class': 1,
                     'bbox': 5,
                     'giou': 2,
                     'no_object': 0.1,
                     'mask': 1,
                     'dice': 1
                 },
                 aux_loss=True,
                 use_focal_loss=False):
        r"""
        Args:
            num_classes (int): The number of classes.
            matcher (HungarianMatcher): It computes an assignment between the targets
                and the predictions of the network.
            loss_coeff (dict): The coefficient of loss. The dict is copied, so
                neither the caller's dict nor the shared default is modified.
            aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
            use_focal_loss (bool): Use focal loss or not.
        """
        super(DETRLoss, self).__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        # Shallow copy: the 'class' entry is replaced by a tensor below, and
        # doing that on the argument itself would corrupt the shared default
        # dict (and any dict passed in by the caller) for later instances.
        self.loss_coeff = dict(loss_coeff)
        self.aux_loss = aux_loss
        self.use_focal_loss = use_focal_loss

        if not self.use_focal_loss:
            # Per-class weight vector for cross entropy; the extra last slot
            # is the background ("no object") class with its own weight.
            class_weight = paddle.full([num_classes + 1], loss_coeff['class'])
            class_weight[-1] = loss_coeff['no_object']
            self.loss_coeff['class'] = class_weight
        self.giou_loss = GIoULoss()

    def _get_loss_class(self, logits, gt_class, match_indices, bg_index,
                        num_gts):
        """Compute the classification loss.

        Args:
            logits (Tensor): [b, query, num_classes]
            gt_class (List(Tensor)): list[[n, 1]]
            match_indices: Per-image (prediction index, target index) pairs
                as returned by the matcher.
            bg_index (int): Label given to every unmatched query.
            num_gts: Normaliser used by the focal-loss branch.

        Returns:
            dict: ``{'loss_class': Tensor}``.
        """
        # Every query starts as background; matched queries are overwritten.
        target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
        bs, num_query_objects = target_label.shape
        if sum(len(a) for a in gt_class) > 0:
            index, updates = self._get_index_updates(num_query_objects,
                                                     gt_class, match_indices)
            target_label = paddle.scatter(
                target_label.reshape([-1, 1]), index, updates.astype('int64'))
            target_label = target_label.reshape([bs, num_query_objects])

        if self.use_focal_loss:
            # One-hot over num_classes + 1, then drop the background column.
            target_label = F.one_hot(target_label,
                                     self.num_classes + 1)[..., :-1]
            loss_class = self.loss_coeff['class'] * sigmoid_focal_loss(
                logits, target_label, num_gts / num_query_objects)
        else:
            loss_class = F.cross_entropy(
                logits, target_label, weight=self.loss_coeff['class'])
        return {'loss_class': loss_class}

    def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts):
        """Compute the L1 and GIoU box losses on the matched pairs.

        Args:
            boxes (Tensor): [b, query, 4]
            gt_bbox (List(Tensor)): list[[n, 4]]
            match_indices: Per-image (prediction index, target index) pairs.
            num_gts: Normaliser dividing both summed losses.

        Returns:
            dict: ``{'loss_bbox': Tensor, 'loss_giou': Tensor}``; both are
            zero tensors when the batch holds no ground-truth boxes.
        """
        loss = dict()
        if sum(len(a) for a in gt_bbox) == 0:
            loss['loss_bbox'] = paddle.to_tensor([0.])
            loss['loss_giou'] = paddle.to_tensor([0.])
            return loss

        src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox,
                                                            match_indices)
        loss['loss_bbox'] = self.loss_coeff['bbox'] * F.l1_loss(
            src_bbox, target_bbox, reduction='sum') / num_gts
        loss['loss_giou'] = self.giou_loss(
            bbox_cxcywh_to_xyxy(src_bbox), bbox_cxcywh_to_xyxy(target_bbox))
        loss['loss_giou'] = loss['loss_giou'].sum() / num_gts
        loss['loss_giou'] = self.loss_coeff['giou'] * loss['loss_giou']
        return loss

    def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts):
        """Compute the focal and dice mask losses on the matched pairs.

        Args:
            masks (Tensor): [b, query, h, w]
            gt_mask (List(Tensor)): list[[n, H, W]]
            match_indices: Per-image (prediction index, target index) pairs.
            num_gts: Normaliser for both mask losses.

        Returns:
            dict: ``{'loss_mask': Tensor, 'loss_dice': Tensor}``; both are
            zero tensors when the batch holds no ground-truth masks.
        """
        loss = dict()
        if sum(len(a) for a in gt_mask) == 0:
            loss['loss_mask'] = paddle.to_tensor([0.])
            loss['loss_dice'] = paddle.to_tensor([0.])
            return loss

        src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
                                                              match_indices)
        # Resize the predicted masks to the spatial size of the targets.
        src_masks = F.interpolate(
            src_masks.unsqueeze(0),
            size=target_masks.shape[-2:],
            mode="bilinear")[0]
        loss['loss_mask'] = self.loss_coeff['mask'] * F.sigmoid_focal_loss(
            src_masks,
            target_masks,
            paddle.to_tensor(
                [num_gts], dtype='float32'))
        loss['loss_dice'] = self.loss_coeff['dice'] * self._dice_loss(
            src_masks, target_masks, num_gts)
        return loss

    def _dice_loss(self, inputs, targets, num_gts):
        """Return the summed dice loss of ``inputs`` (logits) against
        ``targets``, divided by ``num_gts``."""
        inputs = F.sigmoid(inputs)
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        # The +1 on both sides keeps the ratio defined for empty masks.
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_gts

    def _get_loss_aux(self, boxes, logits, gt_bbox, gt_class, bg_index,
                      num_gts):
        """Sum the class/bbox/giou losses over the given decoder layers.

        Each (boxes, logits) pair is matched independently with
        ``self.matcher`` before its losses are computed.

        Returns:
            dict: ``loss_class_aux``, ``loss_bbox_aux`` and ``loss_giou_aux``.
        """
        loss_class = []
        loss_bbox = []
        loss_giou = []
        for aux_boxes, aux_logits in zip(boxes, logits):
            match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
                                         gt_class)
            loss_class.append(
                self._get_loss_class(aux_logits, gt_class, match_indices,
                                     bg_index, num_gts)['loss_class'])
            loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
                                        num_gts)
            loss_bbox.append(loss_['loss_bbox'])
            loss_giou.append(loss_['loss_giou'])
        loss = {
            'loss_class_aux': paddle.add_n(loss_class),
            'loss_bbox_aux': paddle.add_n(loss_bbox),
            'loss_giou_aux': paddle.add_n(loss_giou)
        }
        return loss

    def _get_index_updates(self, num_query_objects, target, match_indices):
        """Build flat scatter indices and the matching target values.

        Returns:
            tuple: ``(src_idx, target_assign)`` where ``src_idx`` indexes the
            batch-flattened ``[b * num_query_objects]`` query axis and
            ``target_assign`` holds the gathered targets in the same order.
        """
        batch_idx = paddle.concat([
            paddle.full_like(src, i) for i, (src, _) in enumerate(match_indices)
        ])
        src_idx = paddle.concat([src for (src, _) in match_indices])
        # Offset each image's query indices into the flattened batch axis.
        src_idx += (batch_idx * num_query_objects)
        target_assign = paddle.concat([
            paddle.gather(
                t, dst, axis=0) for t, (_, dst) in zip(target, match_indices)
        ])
        return src_idx, target_assign

    def _get_src_target_assign(self, src, target, match_indices):
        """Gather matched predictions and targets, concatenated over the batch.

        Images without any match contribute an empty tensor whose trailing
        dimensions equal those of the per-image tensor, so the concatenation
        is valid for boxes ([n, 4]) as well as masks ([n, H, W]).

        Returns:
            tuple: ``(src_assign, target_assign)``.
        """
        src_assign = paddle.concat([
            paddle.gather(
                t, I, axis=0)
            if len(I) > 0 else paddle.zeros([0] + list(t.shape[1:]))
            for t, (I, _) in zip(src, match_indices)
        ])
        target_assign = paddle.concat([
            paddle.gather(
                t, J, axis=0)
            if len(J) > 0 else paddle.zeros([0] + list(t.shape[1:]))
            for t, (_, J) in zip(target, match_indices)
        ])
        return src_assign, target_assign

    def _get_num_gts(self, gt_bbox):
        """Return the ground-truth box count used to normalise the losses.

        The local count is averaged across workers via ``all_reduce`` and
        clipped to at least 1. If the distributed calls raise, the local
        count (also floored at 1) is used instead.
        """
        local_count = sum(len(a) for a in gt_bbox)
        try:
            # TODO: Paddle does not have a "paddle.distributed.is_initialized()"
            reduced = paddle.to_tensor([local_count], dtype=paddle.float32)
            paddle.distributed.all_reduce(reduced)
            return paddle.clip(
                reduced / paddle.distributed.get_world_size(), min=1).item()
        except Exception:
            # Best-effort fallback; computed from the Python count so it does
            # not depend on how far the try-body got before failing.
            return max(float(local_count), 1)

    def forward(self,
                boxes,
                logits,
                gt_bbox,
                gt_class,
                masks=None,
                gt_mask=None):
        r"""
        Args:
            boxes (Tensor): [l, b, query, 4]
            logits (Tensor): [l, b, query, num_classes]
            gt_bbox (List(Tensor)): list[[n, 4]]
            gt_class (List(Tensor)): list[[n, 1]]
            masks (Tensor, optional): [b, query, h, w]
            gt_mask (List(Tensor), optional): list[[n, H, W]]

        Returns:
            dict: Loss terms of the last layer (``loss_class``, ``loss_bbox``,
            ``loss_giou``; plus ``loss_mask``/``loss_dice`` when both mask
            arguments are given) and, if ``aux_loss`` is set, the ``*_aux``
            sums over the preceding layers.
        """
        # Matching uses the last layer's outputs, detached from the graph.
        match_indices = self.matcher(boxes[-1].detach(), logits[-1].detach(),
                                     gt_bbox, gt_class)
        num_gts = self._get_num_gts(gt_bbox)

        total_loss = dict()
        total_loss.update(
            self._get_loss_class(logits[-1], gt_class, match_indices,
                                 self.num_classes, num_gts))
        total_loss.update(
            self._get_loss_bbox(boxes[-1], gt_bbox, match_indices, num_gts))
        if masks is not None and gt_mask is not None:
            total_loss.update(
                self._get_loss_mask(masks, gt_mask, match_indices, num_gts))

        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(boxes[:-1], logits[:-1], gt_bbox, gt_class,
                                   self.num_classes, num_gts))

        return total_loss
|