custom_pan.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register, serializable
  18. from ppdet.modeling.layers import DropBlock
  19. from ppdet.modeling.ops import get_act_fn
  20. from ..backbones.cspresnet import ConvBNLayer, BasicBlock
  21. from ..shape_spec import ShapeSpec
  22. __all__ = ['CustomCSPPAN']
  23. class SPP(nn.Layer):
  24. def __init__(self,
  25. ch_in,
  26. ch_out,
  27. k,
  28. pool_size,
  29. act='swish',
  30. data_format='NCHW'):
  31. super(SPP, self).__init__()
  32. self.pool = []
  33. self.data_format = data_format
  34. for i, size in enumerate(pool_size):
  35. pool = self.add_sublayer(
  36. 'pool{}'.format(i),
  37. nn.MaxPool2D(
  38. kernel_size=size,
  39. stride=1,
  40. padding=size // 2,
  41. data_format=data_format,
  42. ceil_mode=False))
  43. self.pool.append(pool)
  44. self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act)
  45. def forward(self, x):
  46. outs = [x]
  47. for pool in self.pool:
  48. outs.append(pool(x))
  49. if self.data_format == 'NCHW':
  50. y = paddle.concat(outs, axis=1)
  51. else:
  52. y = paddle.concat(outs, axis=-1)
  53. y = self.conv(y)
  54. return y
  55. class CSPStage(nn.Layer):
  56. def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False):
  57. super(CSPStage, self).__init__()
  58. ch_mid = int(ch_out // 2)
  59. self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
  60. self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
  61. self.convs = nn.Sequential()
  62. next_ch_in = ch_mid
  63. for i in range(n):
  64. self.convs.add_sublayer(
  65. str(i),
  66. eval(block_fn)(next_ch_in, ch_mid, act=act, shortcut=False))
  67. if i == (n - 1) // 2 and spp:
  68. self.convs.add_sublayer(
  69. 'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
  70. next_ch_in = ch_mid
  71. self.conv3 = ConvBNLayer(ch_mid * 2, ch_out, 1, act=act)
  72. def forward(self, x):
  73. y1 = self.conv1(x)
  74. y2 = self.conv2(x)
  75. y2 = self.convs(y2)
  76. y = paddle.concat([y1, y2], axis=1)
  77. y = self.conv3(y)
  78. return y
  79. @register
  80. @serializable
  81. class CustomCSPPAN(nn.Layer):
  82. __shared__ = ['norm_type', 'data_format', 'width_mult', 'depth_mult', 'trt']
  83. def __init__(self,
  84. in_channels=[256, 512, 1024],
  85. out_channels=[1024, 512, 256],
  86. norm_type='bn',
  87. act='leaky',
  88. stage_fn='CSPStage',
  89. block_fn='BasicBlock',
  90. stage_num=1,
  91. block_num=3,
  92. drop_block=False,
  93. block_size=3,
  94. keep_prob=0.9,
  95. spp=False,
  96. data_format='NCHW',
  97. width_mult=1.0,
  98. depth_mult=1.0,
  99. trt=False):
  100. super(CustomCSPPAN, self).__init__()
  101. out_channels = [max(round(c * width_mult), 1) for c in out_channels]
  102. block_num = max(round(block_num * depth_mult), 1)
  103. act = get_act_fn(
  104. act, trt=trt) if act is None or isinstance(act,
  105. (str, dict)) else act
  106. self.num_blocks = len(in_channels)
  107. self.data_format = data_format
  108. self._out_channels = out_channels
  109. in_channels = in_channels[::-1]
  110. fpn_stages = []
  111. fpn_routes = []
  112. for i, (ch_in, ch_out) in enumerate(zip(in_channels, out_channels)):
  113. if i > 0:
  114. ch_in += ch_pre // 2
  115. stage = nn.Sequential()
  116. for j in range(stage_num):
  117. stage.add_sublayer(
  118. str(j),
  119. eval(stage_fn)(block_fn,
  120. ch_in if j == 0 else ch_out,
  121. ch_out,
  122. block_num,
  123. act=act,
  124. spp=(spp and i == 0)))
  125. if drop_block:
  126. stage.add_sublayer('drop', DropBlock(block_size, keep_prob))
  127. fpn_stages.append(stage)
  128. if i < self.num_blocks - 1:
  129. fpn_routes.append(
  130. ConvBNLayer(
  131. ch_in=ch_out,
  132. ch_out=ch_out // 2,
  133. filter_size=1,
  134. stride=1,
  135. padding=0,
  136. act=act))
  137. ch_pre = ch_out
  138. self.fpn_stages = nn.LayerList(fpn_stages)
  139. self.fpn_routes = nn.LayerList(fpn_routes)
  140. pan_stages = []
  141. pan_routes = []
  142. for i in reversed(range(self.num_blocks - 1)):
  143. pan_routes.append(
  144. ConvBNLayer(
  145. ch_in=out_channels[i + 1],
  146. ch_out=out_channels[i + 1],
  147. filter_size=3,
  148. stride=2,
  149. padding=1,
  150. act=act))
  151. ch_in = out_channels[i] + out_channels[i + 1]
  152. ch_out = out_channels[i]
  153. stage = nn.Sequential()
  154. for j in range(stage_num):
  155. stage.add_sublayer(
  156. str(j),
  157. eval(stage_fn)(block_fn,
  158. ch_in if j == 0 else ch_out,
  159. ch_out,
  160. block_num,
  161. act=act,
  162. spp=False))
  163. if drop_block:
  164. stage.add_sublayer('drop', DropBlock(block_size, keep_prob))
  165. pan_stages.append(stage)
  166. self.pan_stages = nn.LayerList(pan_stages[::-1])
  167. self.pan_routes = nn.LayerList(pan_routes[::-1])
  168. def forward(self, blocks, for_mot=False):
  169. blocks = blocks[::-1]
  170. fpn_feats = []
  171. for i, block in enumerate(blocks):
  172. if i > 0:
  173. block = paddle.concat([route, block], axis=1)
  174. route = self.fpn_stages[i](block)
  175. fpn_feats.append(route)
  176. if i < self.num_blocks - 1:
  177. route = self.fpn_routes[i](route)
  178. route = F.interpolate(
  179. route, scale_factor=2., data_format=self.data_format)
  180. pan_feats = [fpn_feats[-1], ]
  181. route = fpn_feats[-1]
  182. for i in reversed(range(self.num_blocks - 1)):
  183. block = fpn_feats[i]
  184. route = self.pan_routes[i](route)
  185. block = paddle.concat([route, block], axis=1)
  186. route = self.pan_stages[i](block)
  187. pan_feats.append(route)
  188. return pan_feats[::-1]
  189. @classmethod
  190. def from_config(cls, cfg, input_shape):
  191. return {'in_channels': [i.channels for i in input_shape], }
  192. @property
  193. def out_shape(self):
  194. return [ShapeSpec(channels=c) for c in self._out_channels]