12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import math
- import numpy as np
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from paddle import ParamAttr
- from paddle.nn.initializer import Normal, Constant
- from paddle.fluid.dygraph import parallel_helper
- from ppdet.modeling.ops import get_static_shape
- from ..initializer import normal_
- from ..assigners.utils import generate_anchors_for_grid_cell
- from ..bbox_utils import bbox_center, batch_distance2bbox, bbox2distance
- from ppdet.core.workspace import register
- from ppdet.modeling.layers import ConvNormLayer
- from .simota_head import OTAVFLHead
- from .gfl_head import Integral, GFLHead
- from ppdet.modeling.necks.csp_pan import DPModule
# Small positive constant added inside the sqrt of PicoHeadV2's aligned
# classification score, so the square root is never taken of exactly zero.
eps = 1e-9

# Public classes exported by this module.
__all__ = [
    'PicoHead',
    'PicoHeadV2',
    'PicoFeat',
]
class PicoSE(nn.Layer):
    """Channel gating (SE-style) followed by a 1x1 ConvNormLayer projection.

    Args:
        feat_channels (int): Channel count of the feature map and of the pooled
            descriptor; the output keeps the same channel count.
    """

    def __init__(self, feat_channels):
        super(PicoSE, self).__init__()
        # 1x1 conv that turns the pooled descriptor into per-channel gates.
        self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
        # 1x1 conv + norm applied to the gated features.
        self.conv = ConvNormLayer(feat_channels, feat_channels, 1, 1)
        self._init_weights()

    def _init_weights(self):
        # Small-std normal init for the gating conv weights.
        normal_(self.fc.weight, std=0.001)

    def forward(self, feat, avg_feat):
        """Scale ``feat`` by ``sigmoid(fc(avg_feat))`` and project it.

        Args:
            feat (Tensor): Feature map to re-weight.
            avg_feat (Tensor): Pooled descriptor of the features (PicoFeat
                passes ``adaptive_avg_pool2d(feat, (1, 1))``).

        Returns:
            Tensor: ``conv(feat * gate)``.
        """
        gate = F.sigmoid(self.fc(avg_feat))
        gated = feat * gate
        return self.conv(gated)
@register
class PicoFeat(nn.Layer):
    """
    PicoFeat of PicoDet

    For every FPN level, builds a tower of ``num_convs`` depthwise (5x5) +
    pointwise (1x1) ConvNormLayer pairs for classification. Unless
    ``share_cls_reg`` is set, it also builds a parallel set for regression.

    Args:
        feat_in (int): The channel number of input Tensor.
        feat_out (int): The channel number of output Tensor. The first
            depthwise conv uses ``groups=feat_out`` on ``feat_in`` input
            channels, so ``feat_in`` presumably has to be divisible by
            ``feat_out`` (in practice the two are equal) -- confirm against
            ConvNormLayer.
        num_fpn_stride (int): Number of FPN levels, i.e. number of towers.
        num_convs (int): Number of depthwise + pointwise pairs per tower.
        norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
        share_cls_reg (bool): Whether to share the cls and reg output.
        act (str): Activation applied after every conv: 'leaky_relu' or
            'hard_swish'. Any other value leaves the features unchanged.
        use_se (bool): Whether to use se module. Requires
            ``share_cls_reg=True``.
    """

    def __init__(self,
                 feat_in=256,
                 feat_out=96,
                 num_fpn_stride=3,
                 num_convs=2,
                 norm_type='bn',
                 share_cls_reg=False,
                 act='hard_swish',
                 use_se=False):
        super(PicoFeat, self).__init__()
        self.num_convs = num_convs
        self.norm_type = norm_type
        self.share_cls_reg = share_cls_reg
        self.act = act
        self.use_se = use_se
        # Plain Python lists, used only for indexing in forward(). The layers
        # themselves are registered through add_sublayer below.
        self.cls_convs = []
        self.reg_convs = []
        if use_se:
            assert share_cls_reg, \
                'In the case of using se, share_cls_reg must be True'
            self.se = nn.LayerList()
        for stage_idx in range(num_fpn_stride):
            cls_subnet_convs = []
            reg_subnet_convs = []
            for i in range(self.num_convs):
                in_c = feat_in if i == 0 else feat_out
                cls_subnet_convs.extend(
                    self._add_dw_pw_pair('cls', stage_idx, i, in_c, feat_out,
                                         norm_type))
                if not self.share_cls_reg:
                    reg_subnet_convs.extend(
                        self._add_dw_pw_pair('reg', stage_idx, i, in_c,
                                             feat_out, norm_type))
            self.cls_convs.append(cls_subnet_convs)
            self.reg_convs.append(reg_subnet_convs)
            if use_se:
                # One SE module per FPN level, indexed by stage_idx in forward.
                self.se.append(PicoSE(feat_out))

    def _add_dw_pw_pair(self, branch, stage_idx, conv_idx, in_c, feat_out,
                        norm_type):
        """Register one depthwise (5x5) + pointwise (1x1) ConvNormLayer pair.

        Sublayer names are '<branch>_conv_dw<stage>.<idx>' and
        '<branch>_conv_pw<stage>.<idx>', so existing checkpoints keep loading.

        Args:
            branch (str): 'cls' or 'reg'.
            stage_idx (int): FPN level index.
            conv_idx (int): Index of the pair within the tower.
            in_c (int): Input channels of the depthwise conv.
            feat_out (int): Output channels of both convs.
            norm_type (str): Normalization type forwarded to ConvNormLayer.

        Returns:
            list: ``[depthwise_layer, pointwise_layer]``.
        """
        conv_dw = self.add_sublayer(
            '{}_conv_dw{}.{}'.format(branch, stage_idx, conv_idx),
            ConvNormLayer(
                ch_in=in_c,
                ch_out=feat_out,
                filter_size=5,
                stride=1,
                groups=feat_out,
                norm_type=norm_type,
                bias_on=False,
                lr_scale=2.))
        # The pointwise conv consumes the depthwise output, which has feat_out
        # channels, so its input width is feat_out rather than in_c. The two
        # coincide when feat_in == feat_out.
        conv_pw = self.add_sublayer(
            '{}_conv_pw{}.{}'.format(branch, stage_idx, conv_idx),
            ConvNormLayer(
                ch_in=feat_out,
                ch_out=feat_out,
                filter_size=1,
                stride=1,
                norm_type=norm_type,
                bias_on=False,
                lr_scale=2.))
        return [conv_dw, conv_pw]

    def act_func(self, x):
        """Apply the configured activation; unknown names are the identity."""
        if self.act == "leaky_relu":
            x = F.leaky_relu(x)
        elif self.act == "hard_swish":
            x = F.hardswish(x)
        return x

    def forward(self, fpn_feat, stage_idx):
        """Run the conv tower(s) of one FPN level.

        Args:
            fpn_feat (Tensor): Input feature map of level ``stage_idx``.
            stage_idx (int): FPN level index.

        Returns:
            tuple: ``(cls_feat, se_feat)`` when ``use_se`` is set, otherwise
            ``(cls_feat, reg_feat)``. With ``share_cls_reg`` (and no SE), both
            elements are the same tensor.
        """
        assert stage_idx < len(self.cls_convs)
        cls_feat = fpn_feat
        reg_feat = fpn_feat
        reg_tower = self.reg_convs[stage_idx]
        for i, cls_conv in enumerate(self.cls_convs[stage_idx]):
            cls_feat = self.act_func(cls_conv(cls_feat))
            reg_feat = cls_feat
            if not self.share_cls_reg:
                # NOTE(review): each reg conv is applied to the current cls
                # feature, not to the previous reg output. As a result, only
                # the last reg conv affects the returned reg_feat. This is kept
                # as-is for compatibility with trained weights -- confirm
                # whether it is intended.
                reg_feat = self.act_func(reg_tower[i](reg_feat))
        if self.use_se:
            avg_feat = F.adaptive_avg_pool2d(cls_feat, (1, 1))
            se_feat = self.act_func(self.se[stage_idx](cls_feat, avg_feat))
            return cls_feat, se_feat
        return cls_feat, reg_feat
@register
class PicoHead(OTAVFLHead):
    """
    PicoHead

    Args:
        conv_feat (object): Instance of 'PicoFeat'.
        dgqp_module (object|None): Optional quality module. When set, the
            classification output becomes ``sigmoid(cls_score) * dgqp(bbox_pred)``.
        num_classes (int): Number of classes.
        fpn_stride (list): The stride of each FPN Layer.
        prior_prob (float): Used to set the bias init for the class prediction
            layer.
        loss_class (object): Instance of VariFocalLoss. Its ``use_sigmoid``
            attribute decides whether a background channel is added.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        assigner (object): Instance of label assigner.
        reg_max (int): Max value of integral set :math:`{0, ..., reg_max}`
            in QFL setting. Default: 16.
        feat_in_chan (int): Channel number of the features fed to the 1x1
            prediction convs.
        nms (object): NMS module used by ``post_process``.
        nms_pre (int): Passed through to the parent head.
        cell_offset (float): Offset added to the integer grid indices when
            generating anchor points.
        eval_size (list|None): ``[h, w]`` of the evaluation input. When set,
            anchor points are generated once in ``__init__``.
    """

    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
        'assigner', 'nms'
    ]
    __shared__ = ['num_classes', 'eval_size']

    def __init__(self,
                 conv_feat='PicoFeat',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32],
                 prior_prob=0.01,
                 loss_class='VariFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 assigner='SimOTAAssigner',
                 reg_max=16,
                 feat_in_chan=96,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0,
                 eval_size=None):
        super(PicoHead, self).__init__(
            conv_feat=conv_feat,
            dgqp_module=dgqp_module,
            num_classes=num_classes,
            fpn_stride=fpn_stride,
            prior_prob=prior_prob,
            loss_class=loss_class,
            loss_dfl=loss_dfl,
            loss_bbox=loss_bbox,
            assigner=assigner,
            reg_max=reg_max,
            feat_in_chan=feat_in_chan,
            nms=nms,
            nms_pre=nms_pre,
            cell_offset=cell_offset)
        self.conv_feat = conv_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox
        self.assigner = assigner
        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.eval_size = eval_size

        self.use_sigmoid = self.loss_vfl.use_sigmoid
        if self.use_sigmoid:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        # Bias chosen so that sigmoid(bias) == prior_prob at initialization.
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)

        # Clear the super class initialization
        self.gfl_head_cls = None
        self.gfl_head_reg = None
        self.scales_regs = None

        num_reg_channels = 4 * (self.reg_max + 1)
        share_cls_reg = self.conv_feat.share_cls_reg
        self.head_cls_list = []
        self.head_reg_list = []
        for i in range(len(fpn_stride)):
            # With share_cls_reg a single conv predicts cls and reg channels,
            # which are split apart again in _predict_level.
            head_cls = self.add_sublayer(
                "head_cls" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=self.cls_out_channels + num_reg_channels
                    if share_cls_reg else self.cls_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(
                        initializer=Constant(value=bias_init_value))))
            self.head_cls_list.append(head_cls)
            if not share_cls_reg:
                head_reg = self.add_sublayer(
                    "head_reg" + str(i),
                    nn.Conv2D(
                        in_channels=self.feat_in_chan,
                        out_channels=num_reg_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        weight_attr=ParamAttr(initializer=Normal(
                            mean=0., std=0.01)),
                        bias_attr=ParamAttr(initializer=Constant(value=0))))
                self.head_reg_list.append(head_reg)

        # initialize the anchor points
        if self.eval_size:
            self.anchor_points, self.stride_tensor = self._generate_anchors()

    def forward(self, fpn_feats, export_post_process=True):
        """Dispatch to ``forward_train`` in training mode, else ``forward_eval``."""
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"

        if self.training:
            return self.forward_train(fpn_feats)
        else:
            return self.forward_eval(
                fpn_feats, export_post_process=export_post_process)

    def _predict_level(self, fpn_feat, level_idx):
        """Run the PicoFeat tower and the 1x1 prediction convs for one level.

        Returns:
            tuple: ``(cls_score, bbox_pred)`` with channels on axis 1:
            ``cls_out_channels`` for ``cls_score`` and ``4 * (reg_max + 1)``
            for ``bbox_pred``. When ``dgqp_module`` is set, ``cls_score`` is
            ``sigmoid(logits) * quality``; otherwise it holds raw logits.
        """
        conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, level_idx)
        if self.conv_feat.share_cls_reg:
            cls_logits = self.head_cls_list[level_idx](conv_cls_feat)
            cls_score, bbox_pred = paddle.split(
                cls_logits,
                [self.cls_out_channels, 4 * (self.reg_max + 1)],
                axis=1)
        else:
            cls_score = self.head_cls_list[level_idx](conv_cls_feat)
            bbox_pred = self.head_reg_list[level_idx](conv_reg_feat)

        if self.dgqp_module:
            quality_score = self.dgqp_module(bbox_pred)
            cls_score = F.sigmoid(cls_score) * quality_score
        return cls_score, bbox_pred

    def forward_train(self, fpn_feats):
        """Return per-level raw predictions ``(cls_list, bbox_reg_list)``."""
        cls_logits_list, bboxes_reg_list = [], []
        for i, fpn_feat in enumerate(fpn_feats):
            cls_score, bbox_pred = self._predict_level(fpn_feat, i)
            cls_logits_list.append(cls_score)
            bboxes_reg_list.append(bbox_pred)

        return (cls_logits_list, bboxes_reg_list)

    def forward_eval(self, fpn_feats, export_post_process=True):
        """Inference forward pass.

        Returns:
            tuple: With ``export_post_process``, ``(scores, bboxes)`` where
            ``scores`` is ``[b, cls_out_channels, sum(h*w)]`` and ``bboxes``
            is ``[b, sum(h*w), 4]`` decoded from the anchor points and
            multiplied by the level stride. Without it, two per-level lists of
            ``[1, h*w, cls_out_channels]`` sigmoid scores and
            ``[1, h*w, 4 * (reg_max + 1)]`` raw distributions.
        """
        if self.eval_size:
            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
        cls_logits_list, bboxes_reg_list = [], []
        for i, fpn_feat in enumerate(fpn_feats):
            cls_score, bbox_pred = self._predict_level(fpn_feat, i)
            # NOTE(review): when dgqp_module is set, cls_score is already
            # sigmoid(logits) * quality, so the sigmoid below is applied a
            # second time -- confirm this is intended.
            if not export_post_process:
                # Now only supports batch size = 1 in deploy
                # TODO(ygh): support batch size > 1
                cls_score_out = F.sigmoid(cls_score).reshape(
                    [1, self.cls_out_channels, -1]).transpose([0, 2, 1])
                bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4,
                                               -1]).transpose([0, 2, 1])
            else:
                b, _, h, w = fpn_feat.shape
                num_cells = h * w
                cls_score_out = F.sigmoid(
                    cls_score.reshape([b, self.cls_out_channels, num_cells]))
                bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
                bbox_pred = self.distribution_project(bbox_pred)
                bbox_pred = bbox_pred.reshape([b, num_cells, 4])

            cls_logits_list.append(cls_score_out)
            bboxes_reg_list.append(bbox_pred)

        if export_post_process:
            cls_logits_list = paddle.concat(cls_logits_list, axis=-1)
            bboxes_reg_list = paddle.concat(bboxes_reg_list, axis=1)
            bboxes_reg_list = batch_distance2bbox(anchor_points,
                                                  bboxes_reg_list)
            bboxes_reg_list *= stride_tensor

        return (cls_logits_list, bboxes_reg_list)

    def _generate_anchors(self, feats=None):
        """Build grid anchor points and matching strides for every level.

        Sizes come from ``feats`` when given, otherwise from ``eval_size``
        (``ceil(eval_size / stride)`` per level).

        Returns:
            tuple: ``anchor_points`` float32 ``[sum(h*w), 2]`` holding (x, y)
            grid indices plus ``cell_offset``, and ``stride_tensor`` float32
            ``[sum(h*w), 1]``.
        """
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_stride):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = math.ceil(self.eval_size[0] / stride)
                w = math.ceil(self.eval_size[1] / stride)
            shift_x = paddle.arange(end=w) + self.cell_offset
            shift_y = paddle.arange(end=h) + self.cell_offset
            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
            anchor_point = paddle.cast(
                paddle.stack(
                    [shift_x, shift_y], axis=-1), dtype='float32')
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(
                paddle.full(
                    [h * w, 1], stride, dtype='float32'))
        anchor_points = paddle.concat(anchor_points)
        stride_tensor = paddle.concat(stride_tensor)
        return anchor_points, stride_tensor

    def post_process(self, head_outs, scale_factor, export_nms=True):
        """Rescale boxes to the original image and run NMS.

        Args:
            head_outs (tuple): ``(pred_scores, pred_bboxes)`` from forward_eval.
            scale_factor (Tensor): Per-image ``[h_scale, w_scale]`` on the last
                axis.
            export_nms (bool): If False, return ``(pred_bboxes, pred_scores)``
                untouched (note the swapped order).

        Returns:
            tuple: ``(bbox_pred, bbox_num)`` from ``self.nms``.
        """
        pred_scores, pred_bboxes = head_outs
        if not export_nms:
            return pred_bboxes, pred_scores
        else:
            # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
            scale_factor = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y],
                axis=-1).reshape([-1, 1, 4])
            # scale bbox to origin image size.
            pred_bboxes /= scale_factor
            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num
@register
class PicoHeadV2(GFLHead):
    """
    PicoHeadV2

    Args:
        conv_feat (object): Instance of the feature tower (e.g. 'PicoFeat').
            Called as ``conv_feat(fpn_feat, level)`` and expected to return
            ``(cls_feat, head_feat)``.
        dgqp_module (object|None): Passed through to GFLHead.
        num_classes (int): Number of classes.
        fpn_stride (list): The stride of each FPN Layer.
        prior_prob (float): Used to set the bias init for the class prediction
            layer.
        use_align_head (bool): If True, the class score is
            ``sqrt(sigmoid(cls_logit) * sigmoid(cls_align(cls_feat)) + eps)``;
            otherwise it is ``sigmoid(cls_logit)``.
        loss_class (object): Instance of VariFocalLoss.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        static_assigner_epoch (int): ``static_assigner`` is used while
            ``epoch_id < static_assigner_epoch``, ``assigner`` afterwards.
        static_assigner (object): Label assigner for the early epochs.
        assigner (object): Instance of label assigner for later epochs.
        reg_max (int): Max value of integral set :math:`{0, ..., reg_max}`
            in QFL setting. Default: 16.
        feat_in_chan (int): Channel number of the features fed to the 1x1
            prediction convs.
        nms (object): NMS module used by ``post_process``.
        nms_pre (int): Passed through to GFLHead.
        cell_offset (float): Offset added to integer grid indices for anchors.
        act (str): Activation of the alignment DPModule.
        grid_cell_scale (float): Passed as the grid-cell scale to
            ``generate_anchors_for_grid_cell`` in ``get_loss``.
        eval_size (list|None): ``[h, w]`` of the evaluation input. When set,
            anchor points are generated once in ``__init__``.
    """

    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
        'static_assigner', 'assigner', 'nms'
    ]
    __shared__ = ['num_classes', 'eval_size']

    def __init__(self,
                 conv_feat='PicoFeatV2',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32],
                 prior_prob=0.01,
                 use_align_head=True,
                 loss_class='VariFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 static_assigner_epoch=60,
                 static_assigner='ATSSAssigner',
                 assigner='TaskAlignedAssigner',
                 reg_max=16,
                 feat_in_chan=96,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0,
                 act='hard_swish',
                 grid_cell_scale=5.0,
                 eval_size=None):
        super(PicoHeadV2, self).__init__(
            conv_feat=conv_feat,
            dgqp_module=dgqp_module,
            num_classes=num_classes,
            fpn_stride=fpn_stride,
            prior_prob=prior_prob,
            loss_class=loss_class,
            loss_dfl=loss_dfl,
            loss_bbox=loss_bbox,
            reg_max=reg_max,
            feat_in_chan=feat_in_chan,
            nms=nms,
            nms_pre=nms_pre,
            cell_offset=cell_offset, )
        self.conv_feat = conv_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox

        self.static_assigner_epoch = static_assigner_epoch
        self.static_assigner = static_assigner
        self.assigner = assigner

        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.act = act
        self.grid_cell_scale = grid_cell_scale
        self.use_align_head = use_align_head
        self.cls_out_channels = self.num_classes
        self.eval_size = eval_size

        # Bias chosen so that sigmoid(bias) == prior_prob at initialization.
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        # Clear the super class initialization
        self.gfl_head_cls = None
        self.gfl_head_reg = None
        self.scales_regs = None

        self.head_cls_list = []
        self.head_reg_list = []
        self.cls_align = nn.LayerList()

        for i in range(len(fpn_stride)):
            head_cls = self.add_sublayer(
                "head_cls" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=self.cls_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(
                        initializer=Constant(value=bias_init_value))))
            self.head_cls_list.append(head_cls)
            head_reg = self.add_sublayer(
                "head_reg" + str(i),
                nn.Conv2D(
                    in_channels=self.feat_in_chan,
                    out_channels=4 * (self.reg_max + 1),
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0., std=0.01)),
                    bias_attr=ParamAttr(initializer=Constant(value=0))))
            self.head_reg_list.append(head_reg)
            if self.use_align_head:
                self.cls_align.append(
                    DPModule(
                        self.feat_in_chan,
                        1,
                        5,
                        act=self.act,
                        use_act_in_out=False))

        # initialize the anchor points
        if self.eval_size:
            self.anchor_points, self.stride_tensor = self._generate_anchors()

    def forward(self, fpn_feats, export_post_process=True):
        """Dispatch to ``forward_train`` in training mode, else ``forward_eval``."""
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"

        if self.training:
            return self.forward_train(fpn_feats)
        else:
            return self.forward_eval(
                fpn_feats, export_post_process=export_post_process)

    def _predict_level(self, fpn_feat, level_idx):
        """Run the feature tower and prediction convs for one FPN level.

        Returns:
            tuple: ``(cls_score, reg_pred)``. ``cls_score`` is already a
            probability (sigmoid, optionally aligned). ``reg_pred`` holds raw
            ``4 * (reg_max + 1)`` distribution logits on axis 1.
        """
        # task decomposition: the second tower output feeds both 1x1 heads.
        conv_cls_feat, se_feat = self.conv_feat(fpn_feat, level_idx)
        cls_logit = self.head_cls_list[level_idx](se_feat)
        reg_pred = self.head_reg_list[level_idx](se_feat)

        # cls prediction and alignment
        if self.use_align_head:
            cls_prob = F.sigmoid(self.cls_align[level_idx](conv_cls_feat))
            # eps keeps the sqrt argument strictly positive.
            cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt()
        else:
            cls_score = F.sigmoid(cls_logit)
        return cls_score, reg_pred

    def forward_train(self, fpn_feats):
        """Training forward pass.

        Returns:
            tuple: ``(cls_scores, reg_preds, boxes, fpn_feats)``, where the
            first three are concatenated over levels on axis 1. ``boxes`` are
            decoded boxes divided by their level stride.
        """
        cls_score_list, reg_list, box_list = [], [], []
        for i, (fpn_feat, stride) in enumerate(zip(fpn_feats, self.fpn_stride)):
            cls_score, reg_pred = self._predict_level(fpn_feat, i)

            bbox_pred = reg_pred.transpose([0, 2, 3, 1])
            b, cell_h, cell_w, _ = paddle.shape(bbox_pred)
            y, x = self.get_single_level_center_point(
                [cell_h, cell_w], stride, cell_offset=self.cell_offset)
            center_points = paddle.stack([x, y], axis=-1)
            # Distances are scaled by stride to match the center points, then
            # the decoded boxes are divided by stride again before returning.
            bbox_pred = self.distribution_project(bbox_pred) * stride
            bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
            bbox_pred = batch_distance2bbox(
                center_points, bbox_pred, max_shapes=None)

            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
            reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
            box_list.append(bbox_pred / stride)

        cls_score_list = paddle.concat(cls_score_list, axis=1)
        box_list = paddle.concat(box_list, axis=1)
        reg_list = paddle.concat(reg_list, axis=1)
        return cls_score_list, reg_list, box_list, fpn_feats

    def forward_eval(self, fpn_feats, export_post_process=True):
        """Inference forward pass.

        Returns:
            tuple: With ``export_post_process``, ``(scores, boxes)`` where
            ``scores`` is ``[b, num_classes, sum(h*w)]`` and ``boxes`` is
            ``[b, sum(h*w), 4]`` decoded from anchor points and multiplied by
            the level stride. Without it, two per-level lists of
            ``[1, h*w, num_classes]`` scores and ``[1, h*w, 4 * (reg_max + 1)]``
            raw distributions.
        """
        if self.eval_size:
            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points, stride_tensor = self._generate_anchors(fpn_feats)
        cls_score_list, box_list = [], []
        for i, (fpn_feat, _) in enumerate(zip(fpn_feats, self.fpn_stride)):
            cls_score, reg_pred = self._predict_level(fpn_feat, i)

            if not export_post_process:
                # Now only supports batch size = 1 in deploy
                cls_score_list.append(
                    cls_score.reshape([1, self.cls_out_channels, -1]).transpose(
                        [0, 2, 1]))
                box_list.append(
                    reg_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose(
                        [0, 2, 1]))
            else:
                b, _, h, w = fpn_feat.shape
                num_cells = h * w
                cls_score_out = cls_score.reshape(
                    [b, self.cls_out_channels, num_cells])
                bbox_pred = reg_pred.transpose([0, 2, 3, 1])
                bbox_pred = self.distribution_project(bbox_pred)
                bbox_pred = bbox_pred.reshape([b, num_cells, 4])
                cls_score_list.append(cls_score_out)
                box_list.append(bbox_pred)

        if export_post_process:
            cls_score_list = paddle.concat(cls_score_list, axis=-1)
            box_list = paddle.concat(box_list, axis=1)
            box_list = batch_distance2bbox(anchor_points, box_list)
            box_list *= stride_tensor

        return cls_score_list, box_list

    def get_loss(self, head_outs, gt_meta):
        """Compute VFL, bbox and DFL losses.

        Args:
            head_outs (tuple): Output of ``forward_train``.
            gt_meta (dict): Must provide 'gt_class', 'gt_bbox', 'im_id',
                'pad_gt_mask' and 'epoch_id'. 'gt_score' is optional.

        Returns:
            dict: ``loss_vfl``, ``loss_bbox`` and ``loss_dfl``, each normalized
            by the (clipped to >= 1) sum of assigned scores.
        """
        pred_scores, pred_regs, pred_bboxes, fpn_feats = head_outs
        gt_labels = gt_meta['gt_class']
        gt_bboxes = gt_meta['gt_bbox']
        gt_scores = gt_meta['gt_score'] if 'gt_score' in gt_meta else None
        num_imgs = gt_meta['im_id'].shape[0]
        pad_gt_mask = gt_meta['pad_gt_mask']

        anchors, _, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
            fpn_feats, self.fpn_stride, self.grid_cell_scale, self.cell_offset)

        centers = bbox_center(anchors)

        # label assignment (pred boxes are stride-normalized; scale them back)
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
                anchors,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index=self.num_classes,
                gt_scores=gt_scores,
                pred_bboxes=pred_bboxes.detach() * stride_tensor_list)
        else:
            assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
                pred_scores.detach(),
                pred_bboxes.detach() * stride_tensor_list,
                centers,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index=self.num_classes,
                gt_scores=gt_scores)

        # Bring targets into the same stride-normalized space as pred_bboxes.
        assigned_bboxes /= stride_tensor_list

        centers_shape = centers.shape
        flatten_centers = centers.expand(
            [num_imgs, centers_shape[0], centers_shape[1]]).reshape([-1, 2])
        flatten_strides = stride_tensor_list.expand(
            [num_imgs, centers_shape[0], 1]).reshape([-1, 1])
        flatten_cls_preds = pred_scores.reshape([-1, self.num_classes])
        flatten_regs = pred_regs.reshape([-1, 4 * (self.reg_max + 1)])
        flatten_bboxes = pred_bboxes.reshape([-1, 4])
        flatten_bbox_targets = assigned_bboxes.reshape([-1, 4])
        flatten_labels = assigned_labels.reshape([-1])
        flatten_assigned_scores = assigned_scores.reshape(
            [-1, self.num_classes])

        # Positives are anchors whose assigned label is a real class
        # (bg_index == num_classes marks background).
        pos_inds = paddle.nonzero(
            paddle.logical_and((flatten_labels >= 0),
                               (flatten_labels < self.num_classes)),
            as_tuple=False).squeeze(1)

        num_total_pos = len(pos_inds)
        if num_total_pos > 0:
            pos_bbox_targets = paddle.gather(
                flatten_bbox_targets, pos_inds, axis=0)
            pos_decode_bbox_pred = paddle.gather(
                flatten_bboxes, pos_inds, axis=0)
            pos_reg = paddle.gather(flatten_regs, pos_inds, axis=0)
            pos_strides = paddle.gather(flatten_strides, pos_inds, axis=0)
            pos_centers = paddle.gather(
                flatten_centers, pos_inds, axis=0) / pos_strides

            weight_targets = flatten_assigned_scores.detach()
            weight_targets = paddle.gather(
                weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)

            pred_corners = pos_reg.reshape([-1, self.reg_max + 1])
            target_corners = bbox2distance(pos_centers, pos_bbox_targets,
                                           self.reg_max).reshape([-1])
            # regression loss
            loss_bbox = paddle.sum(
                self.loss_bbox(pos_decode_bbox_pred,
                               pos_bbox_targets) * weight_targets)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets.expand([-1, 4]).reshape([-1]),
                avg_factor=4.0)
        else:
            loss_bbox = paddle.zeros([1])
            loss_dfl = paddle.zeros([1])

        avg_factor = flatten_assigned_scores.sum()
        if paddle.fluid.core.is_compiled_with_dist(
        ) and parallel_helper._is_parallel_ctx_initialized():
            paddle.distributed.all_reduce(avg_factor)
            avg_factor = avg_factor / paddle.distributed.get_world_size()
        # Clip on every path (not only the distributed one): a single-card
        # batch without positives would otherwise divide by zero below.
        avg_factor = paddle.clip(avg_factor, min=1)
        loss_vfl = self.loss_vfl(
            flatten_cls_preds, flatten_assigned_scores, avg_factor=avg_factor)

        loss_bbox = loss_bbox / avg_factor
        loss_dfl = loss_dfl / avg_factor

        loss_states = dict(
            loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)

        return loss_states

    def _generate_anchors(self, feats=None):
        """Build grid anchor points and matching strides for every level.

        Sizes come from ``feats`` when given, otherwise from ``eval_size``
        (``ceil(eval_size / stride)`` per level).

        Returns:
            tuple: ``anchor_points`` float32 ``[sum(h*w), 2]`` holding (x, y)
            grid indices plus ``cell_offset``, and ``stride_tensor`` float32
            ``[sum(h*w), 1]``.
        """
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_stride):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = math.ceil(self.eval_size[0] / stride)
                w = math.ceil(self.eval_size[1] / stride)
            shift_x = paddle.arange(end=w) + self.cell_offset
            shift_y = paddle.arange(end=h) + self.cell_offset
            shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
            anchor_point = paddle.cast(
                paddle.stack(
                    [shift_x, shift_y], axis=-1), dtype='float32')
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(
                paddle.full(
                    [h * w, 1], stride, dtype='float32'))
        anchor_points = paddle.concat(anchor_points)
        stride_tensor = paddle.concat(stride_tensor)
        return anchor_points, stride_tensor

    def post_process(self, head_outs, scale_factor, export_nms=True):
        """Rescale boxes to the original image and run NMS.

        Args:
            head_outs (tuple): ``(pred_scores, pred_bboxes)`` from forward_eval.
            scale_factor (Tensor): Per-image ``[h_scale, w_scale]`` on the last
                axis.
            export_nms (bool): If False, return ``(pred_bboxes, pred_scores)``
                untouched (note the swapped order).

        Returns:
            tuple: ``(bbox_pred, bbox_num)`` from ``self.nms``.
        """
        pred_scores, pred_bboxes = head_outs
        if not export_nms:
            return pred_bboxes, pred_scores
        else:
            # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
            scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
            scale_factor = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y],
                axis=-1).reshape([-1, 1, 4])
            # scale bbox to origin image size.
            pred_bboxes /= scale_factor
            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num
|