# blazeface_fpn.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal

from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec

__all__ = ['BlazeNeck']


def hard_swish(x):
    return x * F.relu6(x + 3) / 6.
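
# Illustrative sketch, not part of the original file: hard_swish above should
# match paddle.nn.functional.hardswish up to floating-point error (assuming a
# Paddle 2.x release, where F.hardswish uses the same x * relu6(x + 3) / 6 form):
#
#     x = paddle.linspace(-6., 6., 25)
#     assert paddle.allclose(hard_swish(x), F.hardswish(x), atol=1e-6)
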
class ConvBNLayer(nn.Layer):
    """Conv2D + BatchNorm + optional activation; the basic block of this neck."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr, initializer=KaimingNormal()),
            bias_attr=False)
        if norm_type in ['sync_bn', 'bn']:
            self._batch_norm = nn.BatchNorm2D(out_channels)

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        elif self.act == 'leaky':
            x = F.leaky_relu(x)
        elif self.act == 'hard_swish':
            x = hard_swish(x)
        return x
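
# Illustrative sketch, not part of the original file: a ConvBNLayer can be
# exercised standalone. The shapes below are assumptions for the example.
#
#     layer = ConvBNLayer(
#         in_channels=96, out_channels=48, kernel_size=1,
#         stride=1, padding=0, act='leaky')
#     y = layer(paddle.rand([1, 96, 20, 20]))  # -> [1, 48, 20, 20]
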
class FPN(nn.Layer):
    """Two-level feature pyramid: 1x1 laterals on both inputs, top-down merge."""

    def __init__(self, in_channels, out_channels, name=None):
        super(FPN, self).__init__()
        # 1x1 lateral convs halve the channel count of each input level.
        self.conv1_fpn = ConvBNLayer(
            in_channels,
            out_channels // 2,
            kernel_size=1,
            padding=0,
            stride=1,
            act='leaky',
            name=name + '_output1')
        self.conv2_fpn = ConvBNLayer(
            in_channels,
            out_channels // 2,
            kernel_size=1,
            padding=0,
            stride=1,
            act='leaky',
            name=name + '_output2')
        # 3x3 conv smooths the merged (upsampled + added) feature map.
        self.conv3_fpn = ConvBNLayer(
            out_channels // 2,
            out_channels // 2,
            kernel_size=3,
            padding=1,
            stride=1,
            act='leaky',
            name=name + '_merge')

    def forward(self, input):
        output1 = self.conv1_fpn(input[0])
        output2 = self.conv2_fpn(input[1])
        # Upsample the coarser level to the finer level's spatial size.
        up2 = F.upsample(
            output2, size=paddle.shape(output1)[-2:], mode='nearest')
        output1 = paddle.add(output1, up2)
        output1 = self.conv3_fpn(output1)
        return output1, output2
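
# Illustrative sketch, not part of the original file: with two 96-channel
# inputs (spatial sizes assumed for the example), FPN(96, 96) yields two
# 48-channel maps, the finer one enriched by the top-down pathway.
#
#     fpn = FPN(96, 96, name='fpn')
#     c4 = paddle.rand([1, 96, 80, 80])   # finer level
#     c5 = paddle.rand([1, 96, 40, 40])   # coarser level
#     p4, p5 = fpn([c4, c5])  # [1, 48, 80, 80], [1, 48, 40, 40]
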
class SSH(nn.Layer):
    """SSH-style context module: parallel 3x3 / 5x5 / 7x7 receptive-field
    branches, the larger fields built from stacked 3x3 convs."""

    def __init__(self, in_channels, out_channels, name=None):
        super(SSH, self).__init__()
        assert out_channels % 4 == 0
        # 3x3 branch keeps half of the output channels.
        self.conv0_ssh = ConvBNLayer(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            padding=1,
            stride=1,
            act=None,
            name=name + 'ssh_conv3')
        # Two stacked 3x3 convs approximate a 5x5 receptive field.
        self.conv1_ssh = ConvBNLayer(
            out_channels // 2,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act='leaky',
            name=name + 'ssh_conv5_1')
        self.conv2_ssh = ConvBNLayer(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act=None,
            name=name + 'ssh_conv5_2')
        # A third stacked 3x3 conv extends the field to roughly 7x7.
        self.conv3_ssh = ConvBNLayer(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act='leaky',
            name=name + 'ssh_conv7_1')
        self.conv4_ssh = ConvBNLayer(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            padding=1,
            stride=1,
            act=None,
            name=name + 'ssh_conv7_2')

    def forward(self, x):
        conv0 = self.conv0_ssh(x)
        conv1 = self.conv1_ssh(conv0)
        conv2 = self.conv2_ssh(conv1)
        conv3 = self.conv3_ssh(conv2)
        conv4 = self.conv4_ssh(conv3)
        # out/2 + out/4 + out/4 channels concatenate back to out_channels.
        concat = paddle.concat([conv0, conv2, conv4], axis=1)
        return F.relu(concat)
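
# Illustrative sketch, not part of the original file: channel bookkeeping for
# SSH(48, 48) -- the concat of the three branches (24 + 12 + 12) restores
# the requested 48 output channels, with spatial size unchanged.
#
#     ssh = SSH(48, 48, name='ssh1')
#     y = ssh(paddle.rand([1, 48, 80, 80]))  # -> [1, 48, 80, 80]
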
@register
@serializable
class BlazeNeck(nn.Layer):
    """Neck for BlazeFace. `neck_type` selects the behavior: 'None' passes the
    backbone features through unchanged, 'only_fpn' applies FPN, 'only_ssh'
    applies SSH, and 'fpn_ssh' chains both."""

    def __init__(self, in_channel, neck_type="None", data_format='NCHW'):
        super(BlazeNeck, self).__init__()
        self.neck_type = neck_type
        self.return_input = False
        self._out_channels = in_channel
        if self.neck_type == 'None':
            self.return_input = True
        if "fpn" in self.neck_type:
            self.fpn = FPN(self._out_channels[0],
                           self._out_channels[1],
                           name='fpn')
            # FPN laterals halve the channel count of both levels.
            self._out_channels = [
                self._out_channels[0] // 2, self._out_channels[1] // 2
            ]
        if "ssh" in self.neck_type:
            # SSH preserves the channel count of each level.
            self.ssh1 = SSH(self._out_channels[0],
                            self._out_channels[0],
                            name='ssh1')
            self.ssh2 = SSH(self._out_channels[1],
                            self._out_channels[1],
                            name='ssh2')
            self._out_channels = [self._out_channels[0], self._out_channels[1]]

    def forward(self, inputs):
        if self.return_input:
            return inputs
        output1, output2 = None, None
        if "fpn" in self.neck_type:
            backout_4, backout_1 = inputs
            output1, output2 = self.fpn([backout_4, backout_1])
        if self.neck_type == "only_fpn":
            return [output1, output2]
        if self.neck_type == "only_ssh":
            output1, output2 = inputs
        feature1 = self.ssh1(output1)
        feature2 = self.ssh2(output2)
        return [feature1, feature2]

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[0], self._out_channels[1]]
        ]
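
# Illustrative end-to-end sketch, not part of the original file. It assumes a
# BlazeNet-style backbone emitting two 96-channel feature maps (the in_channel
# value used by the BlazeFace FPN+SSH configs) and direct instantiation of the
# registered class:
#
#     neck = BlazeNeck(in_channel=[96, 96], neck_type='fpn_ssh')
#     c4 = paddle.rand([1, 96, 80, 80])
#     c5 = paddle.rand([1, 96, 40, 40])
#     f1, f2 = neck([c4, c5])  # both 48-channel after FPN halving + SSH
#     print([s.channels for s in neck.out_shape])  # [48, 48]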