cascade_head.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import Normal
  18. from ppdet.core.workspace import register
  19. from .bbox_head import BBoxHead, TwoFCHead, XConvNormHead
  20. from .roi_extractor import RoIAlign
  21. from ..shape_spec import ShapeSpec
  22. from ..bbox_utils import delta2bbox, clip_bbox, nonempty_bbox
  23. __all__ = ['CascadeTwoFCHead', 'CascadeXConvNormHead', 'CascadeHead']
  24. @register
  25. class CascadeTwoFCHead(nn.Layer):
  26. __shared__ = ['num_cascade_stage']
  27. """
  28. Cascade RCNN bbox head with Two fc layers to extract feature
  29. Args:
  30. in_channel (int): Input channel which can be derived by from_config
  31. out_channel (int): Output channel
  32. resolution (int): Resolution of input feature map, default 7
  33. num_cascade_stage (int): The number of cascade stage, default 3
  34. """
  35. def __init__(self,
  36. in_channel=256,
  37. out_channel=1024,
  38. resolution=7,
  39. num_cascade_stage=3):
  40. super(CascadeTwoFCHead, self).__init__()
  41. self.in_channel = in_channel
  42. self.out_channel = out_channel
  43. self.head_list = []
  44. for stage in range(num_cascade_stage):
  45. head_per_stage = self.add_sublayer(
  46. str(stage), TwoFCHead(in_channel, out_channel, resolution))
  47. self.head_list.append(head_per_stage)
  48. @classmethod
  49. def from_config(cls, cfg, input_shape):
  50. s = input_shape
  51. s = s[0] if isinstance(s, (list, tuple)) else s
  52. return {'in_channel': s.channels}
  53. @property
  54. def out_shape(self):
  55. return [ShapeSpec(channels=self.out_channel, )]
  56. def forward(self, rois_feat, stage=0):
  57. out = self.head_list[stage](rois_feat)
  58. return out
  59. @register
  60. class CascadeXConvNormHead(nn.Layer):
  61. __shared__ = ['norm_type', 'freeze_norm', 'num_cascade_stage']
  62. """
  63. Cascade RCNN bbox head with serveral convolution layers
  64. Args:
  65. in_channel (int): Input channels which can be derived by from_config
  66. num_convs (int): The number of conv layers
  67. conv_dim (int): The number of channels for the conv layers
  68. out_channel (int): Output channels
  69. resolution (int): Resolution of input feature map
  70. norm_type (string): Norm type, bn, gn, sync_bn are available,
  71. default `gn`
  72. freeze_norm (bool): Whether to freeze the norm
  73. num_cascade_stage (int): The number of cascade stage, default 3
  74. """
  75. def __init__(self,
  76. in_channel=256,
  77. num_convs=4,
  78. conv_dim=256,
  79. out_channel=1024,
  80. resolution=7,
  81. norm_type='gn',
  82. freeze_norm=False,
  83. num_cascade_stage=3):
  84. super(CascadeXConvNormHead, self).__init__()
  85. self.in_channel = in_channel
  86. self.out_channel = out_channel
  87. self.head_list = []
  88. for stage in range(num_cascade_stage):
  89. head_per_stage = self.add_sublayer(
  90. str(stage),
  91. XConvNormHead(
  92. in_channel,
  93. num_convs,
  94. conv_dim,
  95. out_channel,
  96. resolution,
  97. norm_type,
  98. freeze_norm,
  99. stage_name='stage{}_'.format(stage)))
  100. self.head_list.append(head_per_stage)
  101. @classmethod
  102. def from_config(cls, cfg, input_shape):
  103. s = input_shape
  104. s = s[0] if isinstance(s, (list, tuple)) else s
  105. return {'in_channel': s.channels}
  106. @property
  107. def out_shape(self):
  108. return [ShapeSpec(channels=self.out_channel, )]
  109. def forward(self, rois_feat, stage=0):
  110. out = self.head_list[stage](rois_feat)
  111. return out
  112. @register
  113. class CascadeHead(BBoxHead):
  114. __shared__ = ['num_classes', 'num_cascade_stages']
  115. __inject__ = ['bbox_assigner', 'bbox_loss']
  116. """
  117. Cascade RCNN bbox head
  118. Args:
  119. head (nn.Layer): Extract feature in bbox head
  120. in_channel (int): Input channel after RoI extractor
  121. roi_extractor (object): The module of RoI Extractor
  122. bbox_assigner (object): The module of Box Assigner, label and sample the
  123. box.
  124. num_classes (int): The number of classes
  125. bbox_weight (List[List[float]]): The weight to get the decode box and the
  126. length of weight is the number of cascade stage
  127. num_cascade_stages (int): THe number of stage to refine the box
  128. """
  129. def __init__(self,
  130. head,
  131. in_channel,
  132. roi_extractor=RoIAlign().__dict__,
  133. bbox_assigner='BboxAssigner',
  134. num_classes=80,
  135. bbox_weight=[[10., 10., 5., 5.], [20.0, 20.0, 10.0, 10.0],
  136. [30.0, 30.0, 15.0, 15.0]],
  137. num_cascade_stages=3,
  138. bbox_loss=None):
  139. nn.Layer.__init__(self, )
  140. self.head = head
  141. self.roi_extractor = roi_extractor
  142. if isinstance(roi_extractor, dict):
  143. self.roi_extractor = RoIAlign(**roi_extractor)
  144. self.bbox_assigner = bbox_assigner
  145. self.num_classes = num_classes
  146. self.bbox_weight = bbox_weight
  147. self.num_cascade_stages = num_cascade_stages
  148. self.bbox_loss = bbox_loss
  149. self.bbox_score_list = []
  150. self.bbox_delta_list = []
  151. for i in range(num_cascade_stages):
  152. score_name = 'bbox_score_stage{}'.format(i)
  153. delta_name = 'bbox_delta_stage{}'.format(i)
  154. bbox_score = self.add_sublayer(
  155. score_name,
  156. nn.Linear(
  157. in_channel,
  158. self.num_classes + 1,
  159. weight_attr=paddle.ParamAttr(initializer=Normal(
  160. mean=0.0, std=0.01))))
  161. bbox_delta = self.add_sublayer(
  162. delta_name,
  163. nn.Linear(
  164. in_channel,
  165. 4,
  166. weight_attr=paddle.ParamAttr(initializer=Normal(
  167. mean=0.0, std=0.001))))
  168. self.bbox_score_list.append(bbox_score)
  169. self.bbox_delta_list.append(bbox_delta)
  170. self.assigned_label = None
  171. self.assigned_rois = None
  172. def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
  173. """
  174. body_feats (list[Tensor]): Feature maps from backbone
  175. rois (Tensor): RoIs generated from RPN module
  176. rois_num (Tensor): The number of RoIs in each image
  177. inputs (dict{Tensor}): The ground-truth of image
  178. """
  179. targets = []
  180. if self.training:
  181. rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
  182. targets_list = [targets]
  183. self.assigned_rois = (rois, rois_num)
  184. self.assigned_targets = targets
  185. pred_bbox = None
  186. head_out_list = []
  187. for i in range(self.num_cascade_stages):
  188. if i > 0:
  189. rois, rois_num = self._get_rois_from_boxes(pred_bbox,
  190. inputs['im_shape'])
  191. if self.training:
  192. rois, rois_num, targets = self.bbox_assigner(
  193. rois, rois_num, inputs, i, is_cascade=True)
  194. targets_list.append(targets)
  195. rois_feat = self.roi_extractor(body_feats, rois, rois_num)
  196. bbox_feat = self.head(rois_feat, i)
  197. scores = self.bbox_score_list[i](bbox_feat)
  198. deltas = self.bbox_delta_list[i](bbox_feat)
  199. head_out_list.append([scores, deltas, rois])
  200. pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i])
  201. if self.training:
  202. loss = {}
  203. for stage, value in enumerate(zip(head_out_list, targets_list)):
  204. (scores, deltas, rois), targets = value
  205. loss_stage = self.get_loss(scores, deltas, targets, rois,
  206. self.bbox_weight[stage])
  207. for k, v in loss_stage.items():
  208. loss[k + "_stage{}".format(
  209. stage)] = v / self.num_cascade_stages
  210. return loss, bbox_feat
  211. else:
  212. scores, deltas, self.refined_rois = self.get_prediction(
  213. head_out_list)
  214. return (deltas, scores), self.head
  215. def _get_rois_from_boxes(self, boxes, im_shape):
  216. rois = []
  217. for i, boxes_per_image in enumerate(boxes):
  218. clip_box = clip_bbox(boxes_per_image, im_shape[i])
  219. if self.training:
  220. keep = nonempty_bbox(clip_box)
  221. if keep.shape[0] == 0:
  222. keep = paddle.zeros([1], dtype='int32')
  223. clip_box = paddle.gather(clip_box, keep)
  224. rois.append(clip_box)
  225. rois_num = paddle.concat([paddle.shape(r)[0] for r in rois])
  226. return rois, rois_num
  227. def _get_pred_bbox(self, deltas, proposals, weights):
  228. pred_proposals = paddle.concat(proposals) if len(
  229. proposals) > 1 else proposals[0]
  230. pred_bbox = delta2bbox(deltas, pred_proposals, weights)
  231. pred_bbox = paddle.reshape(pred_bbox, [-1, deltas.shape[-1]])
  232. num_prop = []
  233. for p in proposals:
  234. num_prop.append(p.shape[0])
  235. return pred_bbox.split(num_prop)
  236. def get_prediction(self, head_out_list):
  237. """
  238. head_out_list(List[Tensor]): scores, deltas, rois
  239. """
  240. pred_list = []
  241. scores_list = [F.softmax(head[0]) for head in head_out_list]
  242. scores = paddle.add_n(scores_list) / self.num_cascade_stages
  243. # Get deltas and rois from the last stage
  244. _, deltas, rois = head_out_list[-1]
  245. return scores, deltas, rois
  246. def get_refined_rois(self, ):
  247. return self.refined_rois