# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant

from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec

__all__ = ['BiFPN']

class SeparableConvLayer(nn.Layer):
    """Depthwise separable convolution: a depthwise conv followed by a 1x1
    pointwise conv, with optional normalization and pre-activation."""

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(SeparableConvLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]

        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_type = norm_type
        self.norm_groups = norm_groups
        self.depthwise_conv = nn.Conv2D(
            in_channels,
            in_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias_attr=False)
        self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

        # norm type
        if self.norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(self.out_channels)
        elif self.norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=self.norm_groups, num_channels=self.out_channels)

        # activation (kept as None when disabled so forward() can test it)
        self.act = None
        if act == 'swish':
            self.act = nn.Swish()
        elif act == 'relu':
            self.act = nn.ReLU()

    def forward(self, x):
        if self.act is not None:
            x = self.act(x)
        out = self.depthwise_conv(x)
        out = self.pointwise_conv(out)
        if self.norm_type is not None:
            out = self.norm(out)
        return out
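
# Usage sketch (illustrative only; the shapes below are assumptions, not taken
# from this file): a SeparableConvLayer keeps the spatial size and, by
# default, the channel count, using far fewer parameters than a dense 3x3
# convolution.
#
#   layer = SeparableConvLayer(64, kernel_size=3, norm_type='bn', act='swish')
#   y = layer(paddle.randn([2, 64, 32, 32]))  # -> shape [2, 64, 32, 32]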

class BiFPNCell(nn.Layer):
    """A single BiFPN layer: one top-down (up) pass followed by one bottom-up
    (down) pass, with optional fast normalized feature fusion."""

    def __init__(self,
                 channels=256,
                 num_levels=5,
                 eps=1e-5,
                 use_weighted_fusion=True,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPNCell, self).__init__()
        self.channels = channels
        self.num_levels = num_levels
        self.eps = eps
        self.use_weighted_fusion = use_weighted_fusion

        # top-down (up) path: one separable conv per fusion node
        self.conv_up = nn.LayerList([
            SeparableConvLayer(
                self.channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                norm_groups=norm_groups,
                act=act) for _ in range(self.num_levels - 1)
        ])
        # bottom-up (down) path
        self.conv_down = nn.LayerList([
            SeparableConvLayer(
                self.channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                norm_groups=norm_groups,
                act=act) for _ in range(self.num_levels - 1)
        ])

        if self.use_weighted_fusion:
            # two inputs feed each top-down node, three feed each bottom-up node
            self.up_weights = self.create_parameter(
                shape=[self.num_levels - 1, 2],
                attr=ParamAttr(initializer=Constant(1.)))
            self.down_weights = self.create_parameter(
                shape=[self.num_levels - 1, 3],
                attr=ParamAttr(initializer=Constant(1.)))
    def _feature_fusion_cell(self,
                             conv_layer,
                             lateral_feat,
                             sampling_feat,
                             route_feat=None,
                             weights=None):
        if self.use_weighted_fusion:
            # fast normalized fusion: out = sum(w_i * I_i) / (eps + sum(w_j))
            weights = F.relu(weights)
            weights = weights / (weights.sum() + self.eps)
            if route_feat is not None:
                out_feat = weights[0] * lateral_feat + \
                           weights[1] * sampling_feat + \
                           weights[2] * route_feat
            else:
                out_feat = weights[0] * lateral_feat + \
                           weights[1] * sampling_feat
        else:
            if route_feat is not None:
                out_feat = lateral_feat + sampling_feat + route_feat
            else:
                out_feat = lateral_feat + sampling_feat
        out_feat = conv_layer(out_feat)
        return out_feat
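
    # Fusion sketch (illustrative, values assumed): with
    # use_weighted_fusion=True and raw weights w = [2.0, 1.0], F.relu leaves
    # them unchanged and normalization gives roughly [0.667, 0.333], so
    #
    #   out = conv(0.667 * lateral_feat + 0.333 * sampling_feat)
    #
    # This is the "fast normalized fusion" of the EfficientDet paper
    # (https://arxiv.org/abs/1911.09070), which replaces a softmax over the
    # fusion weights with a cheap ReLU plus a normalizing sum.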
    def forward(self, feats):
        # feats: [P3, ..., P7], ordered finest (highest resolution) first
        lateral_feats = []

        # top-down path: start from the coarsest level and upsample
        up_feature = feats[-1]
        for i, feature in enumerate(feats[::-1]):
            if i == 0:
                lateral_feats.append(feature)
            else:
                shape = paddle.shape(feature)
                up_feature = F.interpolate(
                    up_feature, size=[shape[2], shape[3]])
                lateral_feature = self._feature_fusion_cell(
                    self.conv_up[i - 1],
                    feature,
                    up_feature,
                    weights=self.up_weights[i - 1]
                    if self.use_weighted_fusion else None)
                lateral_feats.append(lateral_feature)
                up_feature = lateral_feature

        out_feats = []
        # bottom-up path: start from the finest level and downsample
        down_feature = lateral_feats[-1]
        for i, (lateral_feature,
                route_feature) in enumerate(zip(lateral_feats[::-1], feats)):
            if i == 0:
                out_feats.append(lateral_feature)
            else:
                down_feature = F.max_pool2d(down_feature, 3, 2, 1)
                if i == len(feats) - 1:
                    # the coarsest output has no separate route connection
                    route_feature = None
                    weights = self.down_weights[
                        i - 1][:2] if self.use_weighted_fusion else None
                else:
                    weights = self.down_weights[
                        i - 1] if self.use_weighted_fusion else None
                out_feature = self._feature_fusion_cell(
                    self.conv_down[i - 1],
                    lateral_feature,
                    down_feature,
                    route_feature,
                    weights=weights)
                out_feats.append(out_feature)
                down_feature = out_feature

        return out_feats
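
# Level-flow trace for num_levels=5 (illustrative): the top-down pass builds
# lateral_feats = [P7, P6', P5', P4', P3'] (coarsest first), and the bottom-up
# pass then emits out_feats = [P3', P4'', P5'', P6'', P7''] (finest first), so
# the output ordering matches the input `feats` ordering.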

@register
@serializable
class BiFPN(nn.Layer):
    """
    Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070

    Args:
        in_channels (list[int]): input channels of each level, which can be
            derived from the output shape of the backbone by from_config.
        out_channel (int): output channels of each level.
        num_extra_levels (int): the number of extra stages added after the
            last input level. Default: 2.
        fpn_strides (list[int]): the stride of each level.
        num_stacks (int): the number of stacked BiFPN cells. Default: 1.
        use_weighted_fusion (bool): use weighted feature fusion in BiFPN.
            Default: True.
        norm_type (str|None): the normalization type used after each conv;
            one of 'bn', 'sync_bn' or 'gn', or None to disable normalization.
            Default: 'bn'.
        norm_groups (int): the number of groups when norm_type is 'gn'.
        act (str|None): the activation function, 'swish', 'relu' or None.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 out_channel=256,
                 num_extra_levels=2,
                 fpn_strides=(8, 16, 32, 64, 128),
                 num_stacks=1,
                 use_weighted_fusion=True,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPN, self).__init__()
        assert num_stacks > 0, "The number of stacks of BiFPN is at least 1."
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]
        assert num_extra_levels >= 0, \
            "The `num_extra_levels` must be non-negative (>= 0)."

        self.in_channels = in_channels
        self.out_channel = out_channel
        self.num_extra_levels = num_extra_levels
        self.num_stacks = num_stacks
        self.use_weighted_fusion = use_weighted_fusion
        self.norm_type = norm_type
        self.norm_groups = norm_groups
        self.act = act
        self.num_levels = len(self.in_channels) + self.num_extra_levels

        # copy before extending so a shared default argument is never mutated
        fpn_strides = list(fpn_strides)
        if len(fpn_strides) != self.num_levels:
            for i in range(self.num_extra_levels):
                fpn_strides += [fpn_strides[-1] * 2]
        self.fpn_strides = fpn_strides

        self.lateral_convs = nn.LayerList()
        for in_c in in_channels:
            self.lateral_convs.append(
                ConvNormLayer(in_c, self.out_channel, 1, 1))
        if self.num_extra_levels > 0:
            self.extra_convs = nn.LayerList()
            for i in range(self.num_extra_levels):
                if i == 0:
                    # first extra level: strided 3x3 conv on the backbone's
                    # last feature map
                    self.extra_convs.append(
                        ConvNormLayer(self.in_channels[-1], self.out_channel,
                                      3, 2))
                else:
                    # further extra levels: plain stride-2 max pooling
                    self.extra_convs.append(nn.MaxPool2D(3, 2, 1))

        self.bifpn_cells = nn.LayerList()
        for i in range(self.num_stacks):
            self.bifpn_cells.append(
                BiFPNCell(
                    self.out_channel,
                    self.num_levels,
                    use_weighted_fusion=self.use_weighted_fusion,
                    norm_type=self.norm_type,
                    norm_groups=self.norm_groups,
                    act=self.act))

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {
            'in_channels': [i.channels for i in input_shape],
            'fpn_strides': [i.stride for i in input_shape]
        }

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self.out_channel, stride=s) for s in self.fpn_strides
        ]

    def forward(self, feats):
        assert len(feats) == len(self.in_channels)
        fpn_feats = []
        for conv_layer, feature in zip(self.lateral_convs, feats):
            fpn_feats.append(conv_layer(feature))
        if self.num_extra_levels > 0:
            feat = feats[-1]
            for conv_layer in self.extra_convs:
                feat = conv_layer(feat)
                fpn_feats.append(feat)
        for bifpn_cell in self.bifpn_cells:
            fpn_feats = bifpn_cell(fpn_feats)
        return fpn_feats
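
# Usage sketch (illustrative; run from code that can import ppdet, since this
# module uses a relative import). The channel counts and spatial sizes below
# are assumptions matching the defaults, not values taken from this file:
#
#   neck = BiFPN(in_channels=(512, 1024, 2048), out_channel=256)
#   feats = [
#       paddle.randn([1, 512, 80, 80]),   # C3, stride 8
#       paddle.randn([1, 1024, 40, 40]),  # C4, stride 16
#       paddle.randn([1, 2048, 20, 20]),  # C5, stride 32
#   ]
#   outs = neck(feats)  # five levels (P3-P7), each with 256 channels
#   # [o.shape for o in outs] ->
#   #   [[1, 256, 80, 80], [1, 256, 40, 40], [1, 256, 20, 20],
#   #    [1, 256, 10, 10], [1, 256, 5, 5]]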