123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- import paddle.nn as nn
- from paddle import ParamAttr
- from paddle.nn import AdaptiveAvgPool2D, Conv2D
- from paddle.regularizer import L2Decay
- from paddle.nn.initializer import KaimingNormal
- from ppdet.core.workspace import register, serializable
- from numbers import Integral
- from ..shape_spec import ShapeSpec
- __all__ = ['LCNet']
# Per-stage layout of the depthwise-separable blocks.
# Every row describes one block as:
#   [kernel_size, in_channels, out_channels, stride, use_se]
# Channel counts are the unscaled (scale == 1.0) values.
NET_CONFIG = {
    "blocks2": [
        [3, 16, 32, 1, False],
    ],
    "blocks3": [
        [3, 32, 64, 2, False],
        [3, 64, 64, 1, False],
    ],
    "blocks4": [
        [3, 64, 128, 2, False],
        [3, 128, 128, 1, False],
    ],
    "blocks5": [
        [3, 128, 256, 2, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
    ],
    "blocks6": [
        [5, 256, 512, 2, True],
        [5, 512, 512, 1, True],
    ],
}
def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a multiple of ``divisor``.

    Args:
        v: Value to round (typically a scaled channel count).
        divisor: The result is a multiple of this number, unless
            ``min_value`` wins the lower-bound clamp.
        min_value: Lower bound for the result; defaults to ``divisor``.

    Returns:
        The rounded value, bumped up by one ``divisor`` step whenever plain
        rounding would fall below 90% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    # Round half up to the closest multiple of ``divisor``.
    snapped = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, snapped)
    # Rounding down must not shrink the value by more than 10%.
    if result < 0.9 * v:
        return result + divisor
    return result
class ConvBNLayer(nn.Layer):
    """Bias-free convolution followed by batch norm and hard-swish.

    Args:
        num_channels: Number of input channels of the convolution.
        filter_size: Kernel size of the convolution.
        num_filters: Number of output channels (also the batch-norm width).
        stride: Stride of the convolution.
        num_groups: Number of convolution groups; defaults to 1.
    """

    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 num_groups=1):
        super().__init__()

        # Half the kernel extent on each side ("same"-style padding for
        # odd kernel sizes).
        half_kernel = (filter_size - 1) // 2

        self.conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=half_kernel,
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)

        # Batch-norm scale and shift both use an L2 decay coefficient of 0.
        bn_weight_attr = ParamAttr(regularizer=L2Decay(0.0))
        bn_bias_attr = ParamAttr(regularizer=L2Decay(0.0))
        self.bn = nn.BatchNorm2D(
            num_filters, weight_attr=bn_weight_attr, bias_attr=bn_bias_attr)

        self.hardswish = nn.Hardswish()

    def forward(self, x):
        """Apply conv -> batch norm -> hard-swish to ``x``."""
        return self.hardswish(self.bn(self.conv(x)))
class DepthwiseSeparable(nn.Layer):
    """Depthwise conv, optional SE gate, then 1x1 pointwise conv.

    Args:
        num_channels: Number of input channels.
        num_filters: Number of output channels of the pointwise conv.
        stride: Stride of the depthwise conv.
        dw_size: Kernel size of the depthwise conv; defaults to 3.
        use_se: When true, an ``SEModule`` is inserted between the
            depthwise and pointwise convolutions.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 dw_size=3,
                 use_se=False):
        super().__init__()
        self.use_se = use_se

        # One group per channel: each input channel is filtered on its own.
        self.dw_conv = ConvBNLayer(
            num_channels=num_channels,
            filter_size=dw_size,
            num_filters=num_channels,
            stride=stride,
            num_groups=num_channels)

        if use_se:
            self.se = SEModule(num_channels)

        # The 1x1 conv mixes channels and sets the output width.
        self.pw_conv = ConvBNLayer(
            num_channels=num_channels,
            filter_size=1,
            num_filters=num_filters,
            stride=1)

    def forward(self, x):
        """Run the depthwise conv, the optional SE gate, and the pointwise conv."""
        features = self.dw_conv(x)
        if self.use_se:
            features = self.se(features)
        return self.pw_conv(features)
class SEModule(nn.Layer):
    """Channel gate: global pool -> 1x1 conv -> ReLU -> 1x1 conv -> hard-sigmoid.

    Args:
        channel: Number of channels of the input (and of the output).
        reduction: Divisor applied to ``channel`` to get the width of the
            intermediate 1x1 convolution; defaults to 4.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        bottleneck = channel // reduction

        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=bottleneck,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = nn.ReLU()
        self.conv2 = Conv2D(
            in_channels=bottleneck,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
        self.hardsigmoid = nn.Hardsigmoid()

    def forward(self, x):
        """Scale ``x`` by the gate computed from its pooled statistics."""
        gate = self.avg_pool(x)
        gate = self.conv1(gate)
        gate = self.relu(gate)
        gate = self.conv2(gate)
        gate = self.hardsigmoid(gate)
        return paddle.multiply(x=x, y=gate)
@register
@serializable
class LCNet(nn.Layer):
    """LCNet backbone built from the stage table in ``NET_CONFIG``.

    The network is a stem convolution (``conv1``) followed by five stages,
    ``blocks2`` .. ``blocks6``, each a sequence of ``DepthwiseSeparable``
    blocks. The outputs of ``blocks3``, ``blocks4``, ``blocks5`` and
    ``blocks6`` are candidate feature maps, numbered 2, 3, 4 and 5
    respectively.

    Args:
        scale (float): Multiplier applied to every channel count in
            ``NET_CONFIG`` (and to the 16 stem channels) before rounding
            with ``make_divisible``.
        feature_maps (list[int]): Numbers of the feature maps to return,
            taken from ``{2, 3, 4, 5}``. Other values select nothing.
    """

    # NOTE: the list default is kept as-is so the constructor signature stays
    # unchanged; the value is copied below and never stored by reference.
    def __init__(self, scale=1.0, feature_maps=[3, 4, 5]):
        super().__init__()
        self.scale = scale
        # Copy so instances never alias the shared default list (or a list
        # the caller may mutate later).
        self.feature_maps = list(feature_maps)

        self.conv1 = ConvBNLayer(
            num_channels=3,
            filter_size=3,
            num_filters=make_divisible(16 * scale),
            stride=2)

        # Stages are created and registered in the same order as before
        # (blocks2 .. blocks6) under the same attribute names.
        self.blocks2 = self._make_stage("blocks2", scale)
        self.blocks3 = self._make_stage("blocks3", scale)
        self.blocks4 = self._make_stage("blocks4", scale)
        self.blocks5 = self._make_stage("blocks5", scale)
        self.blocks6 = self._make_stage("blocks6", scale)

        # Output width of each candidate feature map: the out_c entry of the
        # last block of the stage, scaled and rounded.
        stage_channels = [
            make_divisible(NET_CONFIG[stage_name][-1][2] * scale)
            for stage_name in ("blocks3", "blocks4", "blocks5", "blocks6")
        ]
        # Position idx in the candidate list is feature map number idx + 2.
        self._out_channels = [
            ch for idx, ch in enumerate(stage_channels)
            if idx + 2 in self.feature_maps
        ]

    @staticmethod
    def _make_stage(stage_name, scale):
        """Build one stage of ``NET_CONFIG`` as an ``nn.Sequential``."""
        return nn.Sequential(*[
            DepthwiseSeparable(
                num_channels=make_divisible(in_c * scale),
                num_filters=make_divisible(out_c * scale),
                dw_size=k,
                stride=s,
                use_se=se)
            for k, in_c, out_c, s, se in NET_CONFIG[stage_name]
        ])

    def forward(self, inputs):
        """Run the backbone on ``inputs['image']``.

        Args:
            inputs: Mapping that holds the input batch under key ``'image'``.

        Returns:
            list: Outputs of the selected stages, in network order.
        """
        x = inputs['image']
        outs = []
        x = self.conv1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        outs.append(x)
        x = self.blocks4(x)
        outs.append(x)
        x = self.blocks5(x)
        outs.append(x)
        x = self.blocks6(x)
        outs.append(x)
        # Keep only the requested maps; outs[i] is feature map number i + 2.
        outs = [o for i, o in enumerate(outs) if i + 2 in self.feature_maps]
        return outs

    @property
    def out_shape(self):
        """One ``ShapeSpec`` per returned feature map, carrying its channels."""
        return [ShapeSpec(channels=c) for c in self._out_channels]
|