dla.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register, serializable
  18. from ppdet.modeling.layers import ConvNormLayer
  19. from ..shape_spec import ShapeSpec
  20. DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512])}
  21. class BasicBlock(nn.Layer):
  22. def __init__(self, ch_in, ch_out, stride=1):
  23. super(BasicBlock, self).__init__()
  24. self.conv1 = ConvNormLayer(
  25. ch_in,
  26. ch_out,
  27. filter_size=3,
  28. stride=stride,
  29. bias_on=False,
  30. norm_decay=None)
  31. self.conv2 = ConvNormLayer(
  32. ch_out,
  33. ch_out,
  34. filter_size=3,
  35. stride=1,
  36. bias_on=False,
  37. norm_decay=None)
  38. def forward(self, inputs, residual=None):
  39. if residual is None:
  40. residual = inputs
  41. out = self.conv1(inputs)
  42. out = F.relu(out)
  43. out = self.conv2(out)
  44. out = paddle.add(x=out, y=residual)
  45. out = F.relu(out)
  46. return out
  47. class Root(nn.Layer):
  48. def __init__(self, ch_in, ch_out, kernel_size, residual):
  49. super(Root, self).__init__()
  50. self.conv = ConvNormLayer(
  51. ch_in,
  52. ch_out,
  53. filter_size=1,
  54. stride=1,
  55. bias_on=False,
  56. norm_decay=None)
  57. self.residual = residual
  58. def forward(self, inputs):
  59. children = inputs
  60. out = self.conv(paddle.concat(inputs, axis=1))
  61. if self.residual:
  62. out = paddle.add(x=out, y=children[0])
  63. out = F.relu(out)
  64. return out
  65. class Tree(nn.Layer):
  66. def __init__(self,
  67. level,
  68. block,
  69. ch_in,
  70. ch_out,
  71. stride=1,
  72. level_root=False,
  73. root_dim=0,
  74. root_kernel_size=1,
  75. root_residual=False):
  76. super(Tree, self).__init__()
  77. if root_dim == 0:
  78. root_dim = 2 * ch_out
  79. if level_root:
  80. root_dim += ch_in
  81. if level == 1:
  82. self.tree1 = block(ch_in, ch_out, stride)
  83. self.tree2 = block(ch_out, ch_out, 1)
  84. else:
  85. self.tree1 = Tree(
  86. level - 1,
  87. block,
  88. ch_in,
  89. ch_out,
  90. stride,
  91. root_dim=0,
  92. root_kernel_size=root_kernel_size,
  93. root_residual=root_residual)
  94. self.tree2 = Tree(
  95. level - 1,
  96. block,
  97. ch_out,
  98. ch_out,
  99. 1,
  100. root_dim=root_dim + ch_out,
  101. root_kernel_size=root_kernel_size,
  102. root_residual=root_residual)
  103. if level == 1:
  104. self.root = Root(root_dim, ch_out, root_kernel_size, root_residual)
  105. self.level_root = level_root
  106. self.root_dim = root_dim
  107. self.downsample = None
  108. self.project = None
  109. self.level = level
  110. if stride > 1:
  111. self.downsample = nn.MaxPool2D(stride, stride=stride)
  112. if ch_in != ch_out:
  113. self.project = ConvNormLayer(
  114. ch_in,
  115. ch_out,
  116. filter_size=1,
  117. stride=1,
  118. bias_on=False,
  119. norm_decay=None)
  120. def forward(self, x, residual=None, children=None):
  121. children = [] if children is None else children
  122. bottom = self.downsample(x) if self.downsample else x
  123. residual = self.project(bottom) if self.project else bottom
  124. if self.level_root:
  125. children.append(bottom)
  126. x1 = self.tree1(x, residual)
  127. if self.level == 1:
  128. x2 = self.tree2(x1)
  129. x = self.root([x2, x1] + children)
  130. else:
  131. children.append(x1)
  132. x = self.tree2(x1, children=children)
  133. return x
  134. @register
  135. @serializable
  136. class DLA(nn.Layer):
  137. """
  138. DLA, see https://arxiv.org/pdf/1707.06484.pdf
  139. Args:
  140. depth (int): DLA depth, should be 34.
  141. residual_root (bool): whether use a reidual layer in the root block
  142. """
  143. def __init__(self, depth=34, residual_root=False):
  144. super(DLA, self).__init__()
  145. levels, channels = DLA_cfg[depth]
  146. if depth == 34:
  147. block = BasicBlock
  148. self.channels = channels
  149. self.base_layer = nn.Sequential(
  150. ConvNormLayer(
  151. 3,
  152. channels[0],
  153. filter_size=7,
  154. stride=1,
  155. bias_on=False,
  156. norm_decay=None),
  157. nn.ReLU())
  158. self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
  159. self.level1 = self._make_conv_level(
  160. channels[0], channels[1], levels[1], stride=2)
  161. self.level2 = Tree(
  162. levels[2],
  163. block,
  164. channels[1],
  165. channels[2],
  166. 2,
  167. level_root=False,
  168. root_residual=residual_root)
  169. self.level3 = Tree(
  170. levels[3],
  171. block,
  172. channels[2],
  173. channels[3],
  174. 2,
  175. level_root=True,
  176. root_residual=residual_root)
  177. self.level4 = Tree(
  178. levels[4],
  179. block,
  180. channels[3],
  181. channels[4],
  182. 2,
  183. level_root=True,
  184. root_residual=residual_root)
  185. self.level5 = Tree(
  186. levels[5],
  187. block,
  188. channels[4],
  189. channels[5],
  190. 2,
  191. level_root=True,
  192. root_residual=residual_root)
  193. def _make_conv_level(self, ch_in, ch_out, conv_num, stride=1):
  194. modules = []
  195. for i in range(conv_num):
  196. modules.extend([
  197. ConvNormLayer(
  198. ch_in,
  199. ch_out,
  200. filter_size=3,
  201. stride=stride if i == 0 else 1,
  202. bias_on=False,
  203. norm_decay=None), nn.ReLU()
  204. ])
  205. ch_in = ch_out
  206. return nn.Sequential(*modules)
  207. @property
  208. def out_shape(self):
  209. return [ShapeSpec(channels=self.channels[i]) for i in range(6)]
  210. def forward(self, inputs):
  211. outs = []
  212. im = inputs['image']
  213. feats = self.base_layer(im)
  214. for i in range(6):
  215. feats = getattr(self, 'level{}'.format(i))(feats)
  216. outs.append(feats)
  217. return outs