bifpn.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant

from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ..shape_spec import ShapeSpec

__all__ = ['BiFPN']


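# Note: SeparableConvLayer is the EfficientDet-style building block: a depthwise
# conv (one filter per channel) followed by a 1x1 pointwise conv, which
# approximates a full convolution at a fraction of the parameters and FLOPs.
# As written below, the activation is applied before the convs and the norm
# after them.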
class SeparableConvLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels=None,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(SeparableConvLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]
        self.in_channels = in_channels
        if out_channels is None:
            self.out_channels = self.in_channels
        else:
            self.out_channels = out_channels
        self.norm_type = norm_type
        self.norm_groups = norm_groups

        # depthwise conv filters each channel separately; the 1x1 pointwise
        # conv then mixes channels
        self.depthwise_conv = nn.Conv2D(
            in_channels,
            in_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias_attr=False)
        self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

        # norm type
        if self.norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(self.out_channels)
        elif self.norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=self.norm_groups, num_channels=self.out_channels)

        # activation
        if act == 'swish':
            self.act = nn.Swish()
        elif act == 'relu':
            self.act = nn.ReLU()
        else:
            self.act = None

    def forward(self, x):
        if self.act is not None:
            x = self.act(x)
        out = self.depthwise_conv(x)
        out = self.pointwise_conv(out)
        if self.norm_type is not None:
            out = self.norm(out)
        return out


class BiFPNCell(nn.Layer):
    def __init__(self,
                 channels=256,
                 num_levels=5,
                 eps=1e-5,
                 use_weighted_fusion=True,
                 kernel_size=3,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPNCell, self).__init__()
        self.channels = channels
        self.num_levels = num_levels
        self.eps = eps
        self.use_weighted_fusion = use_weighted_fusion

        # up
        self.conv_up = nn.LayerList([
            SeparableConvLayer(
                self.channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                norm_groups=norm_groups,
                act=act) for _ in range(self.num_levels - 1)
        ])
        # down
        self.conv_down = nn.LayerList([
            SeparableConvLayer(
                self.channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                norm_groups=norm_groups,
                act=act) for _ in range(self.num_levels - 1)
        ])

        if self.use_weighted_fusion:
            self.up_weights = self.create_parameter(
                shape=[self.num_levels - 1, 2],
                attr=ParamAttr(initializer=Constant(1.)))
            self.down_weights = self.create_parameter(
                shape=[self.num_levels - 1, 3],
                attr=ParamAttr(initializer=Constant(1.)))
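
    # Fast normalized fusion (as in the EfficientDet paper): each input gets a
    # learnable weight w_i, clipped by ReLU and normalized, so
    #     out = sum_i(relu(w_i) * feat_i) / (eps + sum_j relu(w_j)).
    # With the constant init of 1., two inputs start at roughly [0.5, 0.5].
    # The up path fuses 2 inputs (lateral + upsampled) and the down path up to 3
    # (lateral + downsampled + input route), hence the [N-1, 2] / [N-1, 3] shapes.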
    def _feature_fusion_cell(self,
                             conv_layer,
                             lateral_feat,
                             sampling_feat,
                             route_feat=None,
                             weights=None):
        if self.use_weighted_fusion:
            weights = F.relu(weights)
            weights = weights / (weights.sum() + self.eps)
            if route_feat is not None:
                out_feat = weights[0] * lateral_feat + \
                           weights[1] * sampling_feat + \
                           weights[2] * route_feat
            else:
                out_feat = weights[0] * lateral_feat + \
                           weights[1] * sampling_feat
        else:
            if route_feat is not None:
                out_feat = lateral_feat + sampling_feat + route_feat
            else:
                out_feat = lateral_feat + sampling_feat
        out_feat = conv_layer(out_feat)
        return out_feat

    def forward(self, feats):
        # feats: [P3 - P7]
        lateral_feats = []

        # up
        up_feature = feats[-1]
        for i, feature in enumerate(feats[::-1]):
            if i == 0:
                lateral_feats.append(feature)
            else:
                shape = paddle.shape(feature)
                up_feature = F.interpolate(
                    up_feature, size=[shape[2], shape[3]])
                lateral_feature = self._feature_fusion_cell(
                    self.conv_up[i - 1],
                    feature,
                    up_feature,
                    weights=self.up_weights[i - 1]
                    if self.use_weighted_fusion else None)
                lateral_feats.append(lateral_feature)
                up_feature = lateral_feature

        out_feats = []
        # down
        down_feature = lateral_feats[-1]
        for i, (lateral_feature,
                route_feature) in enumerate(zip(lateral_feats[::-1], feats)):
            if i == 0:
                out_feats.append(lateral_feature)
            else:
                down_feature = F.max_pool2d(down_feature, 3, 2, 1)
                if i == len(feats) - 1:
                    route_feature = None
                    weights = self.down_weights[
                        i - 1][:2] if self.use_weighted_fusion else None
                else:
                    weights = self.down_weights[
                        i - 1] if self.use_weighted_fusion else None
                out_feature = self._feature_fusion_cell(
                    self.conv_down[i - 1],
                    lateral_feature,
                    down_feature,
                    route_feature,
                    weights=weights)
                out_feats.append(out_feature)
                down_feature = out_feature
        return out_feats


@register
@serializable
class BiFPN(nn.Layer):
    """
    Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070

    Args:
        in_channels (list[int]): input channels of each level, which can be
            derived from the output shape of the backbone by from_config.
        out_channel (int): output channel of each level.
        num_extra_levels (int): the number of extra stages added after the last
            input level. default: 2.
        fpn_strides (list): the stride of each level.
        num_stacks (int): the number of stacked BiFPN cells, default: 1.
        use_weighted_fusion (bool): whether to use weighted feature fusion in
            BiFPN, default: True.
        norm_type (string|None): the normalization type used in BiFPN. If
            norm_type is None, no norm is applied after conv; otherwise bn,
            gn and sync_bn are available. default: bn.
        norm_groups (int): the number of groups when norm_type is gn.
        act (string|None): the activation function of BiFPN.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 out_channel=256,
                 num_extra_levels=2,
                 fpn_strides=[8, 16, 32, 64, 128],
                 num_stacks=1,
                 use_weighted_fusion=True,
                 norm_type='bn',
                 norm_groups=32,
                 act='swish'):
        super(BiFPN, self).__init__()
        assert num_stacks > 0, "The number of stacks of BiFPN must be at least 1."
        assert norm_type in ['bn', 'sync_bn', 'gn', None]
        assert act in ['swish', 'relu', None]
        assert num_extra_levels >= 0, \
            "The `num_extra_levels` must be non-negative (>= 0)."
        self.in_channels = in_channels
        self.out_channel = out_channel
        self.num_extra_levels = num_extra_levels
        self.num_stacks = num_stacks
        self.use_weighted_fusion = use_weighted_fusion
        self.norm_type = norm_type
        self.norm_groups = norm_groups
        self.act = act
        self.num_levels = len(self.in_channels) + self.num_extra_levels
        if len(fpn_strides) != self.num_levels:
            for i in range(self.num_extra_levels):
                fpn_strides += [fpn_strides[-1] * 2]
        self.fpn_strides = fpn_strides

        self.lateral_convs = nn.LayerList()
        for in_c in in_channels:
            self.lateral_convs.append(
                ConvNormLayer(in_c, self.out_channel, 1, 1))
        if self.num_extra_levels > 0:
            self.extra_convs = nn.LayerList()
            for i in range(self.num_extra_levels):
                if i == 0:
                    self.extra_convs.append(
                        ConvNormLayer(self.in_channels[-1], self.out_channel,
                                      3, 2))
                else:
                    self.extra_convs.append(nn.MaxPool2D(3, 2, 1))

        self.bifpn_cells = nn.LayerList()
        for i in range(self.num_stacks):
            self.bifpn_cells.append(
                BiFPNCell(
                    self.out_channel,
                    self.num_levels,
                    use_weighted_fusion=self.use_weighted_fusion,
                    norm_type=self.norm_type,
                    norm_groups=self.norm_groups,
                    act=self.act))

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {
            'in_channels': [i.channels for i in input_shape],
            'fpn_strides': [i.stride for i in input_shape]
        }

    @property
    def out_shape(self):
        return [
            ShapeSpec(
                channels=self.out_channel, stride=s) for s in self.fpn_strides
        ]

    def forward(self, feats):
        assert len(feats) == len(self.in_channels)
        fpn_feats = []
        for conv_layer, feature in zip(self.lateral_convs, feats):
            fpn_feats.append(conv_layer(feature))
        if self.num_extra_levels > 0:
            feat = feats[-1]
            for conv_layer in self.extra_convs:
                feat = conv_layer(feat)
                fpn_feats.append(feat)

        for bifpn_cell in self.bifpn_cells:
            fpn_feats = bifpn_cell(fpn_feats)
        return fpn_feats
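

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). It builds a
    # BiFPN on the default ResNet-style channels and runs dummy C3-C5 features
    # from a 512x512 input through it; the batch size and input resolution are
    # illustrative assumptions only.
    fpn = BiFPN(in_channels=[512, 1024, 2048], fpn_strides=[8, 16, 32])
    feats = [
        paddle.rand([1, 512, 64, 64]),  # stride 8
        paddle.rand([1, 1024, 32, 32]),  # stride 16
        paddle.rand([1, 2048, 16, 16]),  # stride 32
    ]
    outs = fpn(feats)
    # Two extra levels are appended, so five 256-channel maps come back.
    print([o.shape for o in outs])
    print(fpn.out_shape)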