# gc_block.py
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import paddle.fluid as fluid
  19. from paddle.fluid import ParamAttr
  20. from paddle.fluid.initializer import ConstantInitializer
  21. def spatial_pool(x, pooling_type, name):
  22. _, channel, height, width = x.shape
  23. if pooling_type == 'att':
  24. input_x = x
  25. # [N, 1, C, H * W]
  26. input_x = fluid.layers.reshape(input_x, shape=(0, 1, channel, -1))
  27. context_mask = fluid.layers.conv2d(
  28. input=x,
  29. num_filters=1,
  30. filter_size=1,
  31. stride=1,
  32. padding=0,
  33. param_attr=ParamAttr(name=name + "_weights"),
  34. bias_attr=ParamAttr(name=name + "_bias"))
  35. # [N, 1, H * W]
  36. context_mask = fluid.layers.reshape(context_mask, shape=(0, 0, -1))
  37. # [N, 1, H * W]
  38. context_mask = fluid.layers.softmax(context_mask, axis=2)
  39. # [N, 1, H * W, 1]
  40. context_mask = fluid.layers.reshape(context_mask, shape=(0, 0, -1, 1))
  41. # [N, 1, C, 1]
  42. context = fluid.layers.matmul(input_x, context_mask)
  43. # [N, C, 1, 1]
  44. context = fluid.layers.reshape(context, shape=(0, channel, 1, 1))
  45. else:
  46. # [N, C, 1, 1]
  47. context = fluid.layers.pool2d(
  48. input=x, pool_type='avg', global_pooling=True)
  49. return context
  50. def channel_conv(input, inner_ch, out_ch, name):
  51. conv = fluid.layers.conv2d(
  52. input=input,
  53. num_filters=inner_ch,
  54. filter_size=1,
  55. stride=1,
  56. padding=0,
  57. param_attr=ParamAttr(name=name + "_conv1_weights"),
  58. bias_attr=ParamAttr(name=name + "_conv1_bias"),
  59. name=name + "_conv1", )
  60. conv = fluid.layers.layer_norm(
  61. conv,
  62. begin_norm_axis=1,
  63. param_attr=ParamAttr(name=name + "_ln_weights"),
  64. bias_attr=ParamAttr(name=name + "_ln_bias"),
  65. act="relu",
  66. name=name + "_ln")
  67. conv = fluid.layers.conv2d(
  68. input=conv,
  69. num_filters=out_ch,
  70. filter_size=1,
  71. stride=1,
  72. padding=0,
  73. param_attr=ParamAttr(
  74. name=name + "_conv2_weights",
  75. initializer=ConstantInitializer(value=0.0), ),
  76. bias_attr=ParamAttr(
  77. name=name + "_conv2_bias",
  78. initializer=ConstantInitializer(value=0.0), ),
  79. name=name + "_conv2")
  80. return conv
  81. def add_gc_block(x,
  82. ratio=1.0 / 16,
  83. pooling_type='att',
  84. fusion_types=['channel_add'],
  85. name=None):
  86. '''
  87. GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, see https://arxiv.org/abs/1904.11492
  88. Args:
  89. ratio (float): channel reduction ratio
  90. pooling_type (str): pooling type, support att and avg
  91. fusion_types (list): fusion types, support channel_add and channel_mul
  92. name (str): prefix name of gc block
  93. '''
  94. assert pooling_type in ['avg', 'att']
  95. assert isinstance(fusion_types, (list, tuple))
  96. valid_fusion_types = ['channel_add', 'channel_mul']
  97. assert all([f in valid_fusion_types for f in fusion_types])
  98. assert len(fusion_types) > 0, 'at least one fusion should be used'
  99. inner_ch = int(ratio * x.shape[1])
  100. out_ch = x.shape[1]
  101. context = spatial_pool(x, pooling_type, name + "_spatial_pool")
  102. out = x
  103. if 'channel_mul' in fusion_types:
  104. inner_out = channel_conv(context, inner_ch, out_ch, name + "_mul")
  105. channel_mul_term = fluid.layers.sigmoid(inner_out)
  106. out = out * channel_mul_term
  107. if 'channel_add' in fusion_types:
  108. channel_add_term = channel_conv(context, inner_ch, out_ch,
  109. name + "_add")
  110. out = out + channel_add_term
  111. return out