mask_head.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import KaimingNormal
  18. from ppdet.core.workspace import register, create
  19. from ppdet.modeling.layers import ConvNormLayer
  20. from .roi_extractor import RoIAlign
  21. @register
  22. class MaskFeat(nn.Layer):
  23. """
  24. Feature extraction in Mask head
  25. Args:
  26. in_channel (int): Input channels
  27. out_channel (int): Output channels
  28. num_convs (int): The number of conv layers, default 4
  29. norm_type (string | None): Norm type, bn, gn, sync_bn are available,
  30. default None
  31. """
  32. def __init__(self,
  33. in_channel=256,
  34. out_channel=256,
  35. num_convs=4,
  36. norm_type=None):
  37. super(MaskFeat, self).__init__()
  38. self.num_convs = num_convs
  39. self.in_channel = in_channel
  40. self.out_channel = out_channel
  41. self.norm_type = norm_type
  42. fan_conv = out_channel * 3 * 3
  43. fan_deconv = out_channel * 2 * 2
  44. mask_conv = nn.Sequential()
  45. if norm_type == 'gn':
  46. for i in range(self.num_convs):
  47. conv_name = 'mask_inter_feat_{}'.format(i + 1)
  48. mask_conv.add_sublayer(
  49. conv_name,
  50. ConvNormLayer(
  51. ch_in=in_channel if i == 0 else out_channel,
  52. ch_out=out_channel,
  53. filter_size=3,
  54. stride=1,
  55. norm_type=self.norm_type,
  56. initializer=KaimingNormal(fan_in=fan_conv),
  57. skip_quant=True))
  58. mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
  59. else:
  60. for i in range(self.num_convs):
  61. conv_name = 'mask_inter_feat_{}'.format(i + 1)
  62. conv = nn.Conv2D(
  63. in_channels=in_channel if i == 0 else out_channel,
  64. out_channels=out_channel,
  65. kernel_size=3,
  66. padding=1,
  67. weight_attr=paddle.ParamAttr(
  68. initializer=KaimingNormal(fan_in=fan_conv)))
  69. conv.skip_quant = True
  70. mask_conv.add_sublayer(conv_name, conv)
  71. mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
  72. mask_conv.add_sublayer(
  73. 'conv5_mask',
  74. nn.Conv2DTranspose(
  75. in_channels=self.in_channel,
  76. out_channels=self.out_channel,
  77. kernel_size=2,
  78. stride=2,
  79. weight_attr=paddle.ParamAttr(
  80. initializer=KaimingNormal(fan_in=fan_deconv))))
  81. mask_conv.add_sublayer('conv5_mask' + 'act', nn.ReLU())
  82. self.upsample = mask_conv
  83. @classmethod
  84. def from_config(cls, cfg, input_shape):
  85. if isinstance(input_shape, (list, tuple)):
  86. input_shape = input_shape[0]
  87. return {'in_channel': input_shape.channels, }
  88. def out_channels(self):
  89. return self.out_channel
  90. def forward(self, feats):
  91. return self.upsample(feats)
  92. @register
  93. class MaskHead(nn.Layer):
  94. __shared__ = ['num_classes', 'export_onnx']
  95. __inject__ = ['mask_assigner']
  96. """
  97. RCNN mask head
  98. Args:
  99. head (nn.Layer): Extract feature in mask head
  100. roi_extractor (object): The module of RoI Extractor
  101. mask_assigner (object): The module of Mask Assigner,
  102. label and sample the mask
  103. num_classes (int): The number of classes
  104. share_bbox_feat (bool): Whether to share the feature from bbox head,
  105. default false
  106. """
  107. def __init__(self,
  108. head,
  109. roi_extractor=RoIAlign().__dict__,
  110. mask_assigner='MaskAssigner',
  111. num_classes=80,
  112. share_bbox_feat=False,
  113. export_onnx=False):
  114. super(MaskHead, self).__init__()
  115. self.num_classes = num_classes
  116. self.export_onnx = export_onnx
  117. self.roi_extractor = roi_extractor
  118. if isinstance(roi_extractor, dict):
  119. self.roi_extractor = RoIAlign(**roi_extractor)
  120. self.head = head
  121. self.in_channels = head.out_channels()
  122. self.mask_assigner = mask_assigner
  123. self.share_bbox_feat = share_bbox_feat
  124. self.bbox_head = None
  125. self.mask_fcn_logits = nn.Conv2D(
  126. in_channels=self.in_channels,
  127. out_channels=self.num_classes,
  128. kernel_size=1,
  129. weight_attr=paddle.ParamAttr(initializer=KaimingNormal(
  130. fan_in=self.num_classes)))
  131. self.mask_fcn_logits.skip_quant = True
  132. @classmethod
  133. def from_config(cls, cfg, input_shape):
  134. roi_pooler = cfg['roi_extractor']
  135. assert isinstance(roi_pooler, dict)
  136. kwargs = RoIAlign.from_config(cfg, input_shape)
  137. roi_pooler.update(kwargs)
  138. kwargs = {'input_shape': input_shape}
  139. head = create(cfg['head'], **kwargs)
  140. return {
  141. 'roi_extractor': roi_pooler,
  142. 'head': head,
  143. }
  144. def get_loss(self, mask_logits, mask_label, mask_target, mask_weight):
  145. mask_label = F.one_hot(mask_label, self.num_classes).unsqueeze([2, 3])
  146. mask_label = paddle.expand_as(mask_label, mask_logits)
  147. mask_label.stop_gradient = True
  148. mask_pred = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label))
  149. shape = mask_logits.shape
  150. mask_pred = paddle.reshape(mask_pred, [shape[0], shape[2], shape[3]])
  151. mask_target = mask_target.cast('float32')
  152. mask_weight = mask_weight.unsqueeze([1, 2])
  153. loss_mask = F.binary_cross_entropy_with_logits(
  154. mask_pred, mask_target, weight=mask_weight, reduction="mean")
  155. return loss_mask
  156. def forward_train(self, body_feats, rois, rois_num, inputs, targets,
  157. bbox_feat):
  158. """
  159. body_feats (list[Tensor]): Multi-level backbone features
  160. rois (list[Tensor]): Proposals for each batch with shape [N, 4]
  161. rois_num (Tensor): The number of proposals for each batch
  162. inputs (dict): ground truth info
  163. """
  164. tgt_labels, _, tgt_gt_inds = targets
  165. rois, rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights = self.mask_assigner(
  166. rois, tgt_labels, tgt_gt_inds, inputs)
  167. if self.share_bbox_feat:
  168. rois_feat = paddle.gather(bbox_feat, mask_index)
  169. else:
  170. rois_feat = self.roi_extractor(body_feats, rois, rois_num)
  171. mask_feat = self.head(rois_feat)
  172. mask_logits = self.mask_fcn_logits(mask_feat)
  173. loss_mask = self.get_loss(mask_logits, tgt_classes, tgt_masks,
  174. tgt_weights)
  175. return {'loss_mask': loss_mask}
  176. def forward_test(self,
  177. body_feats,
  178. rois,
  179. rois_num,
  180. scale_factor,
  181. feat_func=None):
  182. """
  183. body_feats (list[Tensor]): Multi-level backbone features
  184. rois (Tensor): Prediction from bbox head with shape [N, 6]
  185. rois_num (Tensor): The number of prediction for each batch
  186. scale_factor (Tensor): The scale factor from origin size to input size
  187. """
  188. if not self.export_onnx and rois.shape[0] == 0:
  189. mask_out = paddle.full([1, 1, 1], -1)
  190. else:
  191. bbox = [rois[:, 2:]]
  192. labels = rois[:, 0].cast('int32')
  193. rois_feat = self.roi_extractor(body_feats, bbox, rois_num)
  194. if self.share_bbox_feat:
  195. assert feat_func is not None
  196. rois_feat = feat_func(rois_feat)
  197. mask_feat = self.head(rois_feat)
  198. mask_logit = self.mask_fcn_logits(mask_feat)
  199. if self.num_classes == 1:
  200. mask_out = F.sigmoid(mask_logit)[:, 0, :, :]
  201. else:
  202. num_masks = paddle.shape(mask_logit)[0]
  203. index = paddle.arange(num_masks).cast('int32')
  204. mask_out = mask_logit[index, labels]
  205. mask_out_shape = paddle.shape(mask_out)
  206. mask_out = paddle.reshape(mask_out, [
  207. paddle.shape(index), mask_out_shape[-2], mask_out_shape[-1]
  208. ])
  209. mask_out = F.sigmoid(mask_out)
  210. return mask_out
  211. def forward(self,
  212. body_feats,
  213. rois,
  214. rois_num,
  215. inputs,
  216. targets=None,
  217. bbox_feat=None,
  218. feat_func=None):
  219. if self.training:
  220. return self.forward_train(body_feats, rois, rois_num, inputs,
  221. targets, bbox_feat)
  222. else:
  223. im_scale = inputs['scale_factor']
  224. return self.forward_test(body_feats, rois, rois_num, im_scale,
  225. feat_func)