# cascade_rcnn.py (12 KB)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import copy
  18. from collections import OrderedDict
  19. import paddle.fluid as fluid
  20. from ppdet.experimental import mixed_precision_global_state
  21. from ppdet.core.workspace import register
  22. from ppdet.utils.check import check_version
  23. from .input_helper import multiscale_def
  24. __all__ = ['CascadeRCNN']
  25. @register
  26. class CascadeRCNN(object):
  27. """
  28. Cascade R-CNN architecture, see https://arxiv.org/abs/1712.00726
  29. Args:
  30. backbone (object): backbone instance
  31. rpn_head (object): `RPNhead` instance
  32. bbox_assigner (object): `BBoxAssigner` instance
  33. roi_extractor (object): ROI extractor instance
  34. bbox_head (object): `BBoxHead` instance
  35. fpn (object): feature pyramid network instance
  36. """
  37. __category__ = 'architecture'
  38. __inject__ = [
  39. 'backbone', 'fpn', 'rpn_head', 'bbox_assigner', 'roi_extractor',
  40. 'bbox_head'
  41. ]
  42. def __init__(self,
  43. backbone,
  44. rpn_head,
  45. roi_extractor='FPNRoIAlign',
  46. bbox_head='CascadeBBoxHead',
  47. bbox_assigner='CascadeBBoxAssigner',
  48. rpn_only=False,
  49. fpn='FPN'):
  50. super(CascadeRCNN, self).__init__()
  51. check_version('2.0.0-rc0')
  52. assert fpn is not None, "cascade RCNN requires FPN"
  53. self.backbone = backbone
  54. self.fpn = fpn
  55. self.rpn_head = rpn_head
  56. self.bbox_assigner = bbox_assigner
  57. self.roi_extractor = roi_extractor
  58. self.bbox_head = bbox_head
  59. self.rpn_only = rpn_only
  60. # Cascade local cfg
  61. self.cls_agnostic_bbox_reg = 2
  62. (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
  63. self.cascade_bbox_reg_weights = [
  64. [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
  65. [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
  66. [1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
  67. ]
  68. self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]
  69. def build(self, feed_vars, mode='train'):
  70. if mode == 'train':
  71. required_fields = ['gt_class', 'gt_bbox', 'is_crowd', 'im_info']
  72. else:
  73. required_fields = ['im_shape', 'im_info']
  74. self._input_check(required_fields, feed_vars)
  75. im = feed_vars['image']
  76. im_info = feed_vars['im_info']
  77. if mode == 'train':
  78. gt_bbox = feed_vars['gt_bbox']
  79. is_crowd = feed_vars['is_crowd']
  80. mixed_precision_enabled = mixed_precision_global_state() is not None
  81. # cast inputs to FP16
  82. if mixed_precision_enabled:
  83. im = fluid.layers.cast(im, 'float16')
  84. # backbone
  85. body_feats = self.backbone(im)
  86. # cast features back to FP32
  87. if mixed_precision_enabled:
  88. body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32'))
  89. for k, v in body_feats.items())
  90. # FPN
  91. if self.fpn is not None:
  92. body_feats, spatial_scale = self.fpn.get_output(body_feats)
  93. # rpn proposals
  94. rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
  95. if mode == 'train':
  96. #fluid.layers.Print(gt_bbox)
  97. #fluid.layers.Print(is_crowd)
  98. rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
  99. else:
  100. if self.rpn_only:
  101. im_scale = fluid.layers.slice(
  102. im_info, [1], starts=[2], ends=[3])
  103. im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois)
  104. rois = rpn_rois / im_scale
  105. return {'proposal': rois}
  106. proposal_list = []
  107. roi_feat_list = []
  108. rcnn_pred_list = []
  109. rcnn_target_list = []
  110. proposals = None
  111. bbox_pred = None
  112. max_overlap = None
  113. for i in range(3):
  114. if i > 0:
  115. refined_bbox = self._decode_box(
  116. proposals,
  117. bbox_pred,
  118. curr_stage=i - 1, )
  119. else:
  120. refined_bbox = rpn_rois
  121. if mode == 'train':
  122. outs = self.bbox_assigner(
  123. input_rois=refined_bbox,
  124. feed_vars=feed_vars,
  125. curr_stage=i,
  126. max_overlap=max_overlap)
  127. proposals = outs[0]
  128. max_overlap = outs[-1]
  129. rcnn_target_list.append(outs[:-1])
  130. else:
  131. proposals = refined_bbox
  132. proposal_list.append(proposals)
  133. # extract roi features
  134. roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale)
  135. roi_feat_list.append(roi_feat)
  136. # bbox head
  137. cls_score, bbox_pred = self.bbox_head.get_output(
  138. roi_feat,
  139. wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
  140. name='_' + str(i + 1) if i > 0 else '')
  141. rcnn_pred_list.append((cls_score, bbox_pred))
  142. if mode == 'train':
  143. loss = self.bbox_head.get_loss(rcnn_pred_list, rcnn_target_list,
  144. self.cascade_rcnn_loss_weight)
  145. loss.update(rpn_loss)
  146. total_loss = fluid.layers.sum(list(loss.values()))
  147. loss.update({'loss': total_loss})
  148. return loss
  149. else:
  150. pred = self.bbox_head.get_prediction(
  151. im_info, feed_vars['im_shape'], roi_feat_list, rcnn_pred_list,
  152. proposal_list, self.cascade_bbox_reg_weights,
  153. self.cls_agnostic_bbox_reg)
  154. return pred
  155. def build_multi_scale(self, feed_vars):
  156. required_fields = ['image', 'im_shape', 'im_info']
  157. self._input_check(required_fields, feed_vars)
  158. result = {}
  159. im_shape = feed_vars['im_shape']
  160. result['im_shape'] = im_shape
  161. for i in range(len(self.im_info_names) // 2):
  162. im = feed_vars[self.im_info_names[2 * i]]
  163. im_info = feed_vars[self.im_info_names[2 * i + 1]]
  164. # backbone
  165. body_feats = self.backbone(im)
  166. result.update(body_feats)
  167. # FPN
  168. if self.fpn is not None:
  169. body_feats, spatial_scale = self.fpn.get_output(body_feats)
  170. # rpn proposals
  171. rpn_rois = self.rpn_head.get_proposals(
  172. body_feats, im_info, mode='test')
  173. proposal_list = []
  174. roi_feat_list = []
  175. rcnn_pred_list = []
  176. proposals = None
  177. bbox_pred = None
  178. for i in range(3):
  179. if i > 0:
  180. refined_bbox = self._decode_box(
  181. proposals,
  182. bbox_pred,
  183. curr_stage=i - 1, )
  184. else:
  185. refined_bbox = rpn_rois
  186. proposals = refined_bbox
  187. proposal_list.append(proposals)
  188. # extract roi features
  189. roi_feat = self.roi_extractor(body_feats, proposals,
  190. spatial_scale)
  191. roi_feat_list.append(roi_feat)
  192. # bbox head
  193. cls_score, bbox_pred = self.bbox_head.get_output(
  194. roi_feat,
  195. wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
  196. name='_' + str(i + 1) if i > 0 else '')
  197. rcnn_pred_list.append((cls_score, bbox_pred))
  198. # get mask rois
  199. rois = proposal_list[2]
  200. if self.fpn is None:
  201. last_feat = body_feats[list(body_feats.keys())[-1]]
  202. roi_feat = self.roi_extractor(last_feat, rois)
  203. else:
  204. roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)
  205. pred = self.bbox_head.get_prediction(
  206. im_info,
  207. im_shape,
  208. roi_feat_list,
  209. rcnn_pred_list,
  210. proposal_list,
  211. self.cascade_bbox_reg_weights,
  212. self.cls_agnostic_bbox_reg,
  213. return_box_score=True)
  214. bbox_name = 'bbox_' + str(i)
  215. score_name = 'score_' + str(i)
  216. if 'flip' in im.name:
  217. bbox_name += '_flip'
  218. score_name += '_flip'
  219. result[bbox_name] = pred['bbox']
  220. result[score_name] = pred['score']
  221. return result
  222. def _input_check(self, require_fields, feed_vars):
  223. for var in require_fields:
  224. assert var in feed_vars, \
  225. "{} has no {} field".format(feed_vars, var)
  226. def _decode_box(self, proposals, bbox_pred, curr_stage):
  227. rcnn_loc_delta_r = fluid.layers.reshape(
  228. bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4))
  229. # only use fg box delta to decode box
  230. rcnn_loc_delta_s = fluid.layers.slice(
  231. rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2])
  232. refined_bbox = fluid.layers.box_coder(
  233. prior_box=proposals,
  234. prior_box_var=self.cascade_bbox_reg_weights[curr_stage],
  235. target_box=rcnn_loc_delta_s,
  236. code_type='decode_center_size',
  237. box_normalized=False,
  238. axis=1, )
  239. refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4])
  240. return refined_bbox
  241. def _inputs_def(self, image_shape):
  242. im_shape = [None] + image_shape
  243. # yapf: disable
  244. inputs_def = {
  245. 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
  246. 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  247. 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  248. 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
  249. 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
  250. 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  251. 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  252. 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  253. }
  254. # yapf: enable
  255. return inputs_def
  256. def build_inputs(self,
  257. image_shape=[3, None, None],
  258. fields=[
  259. 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class',
  260. 'is_crowd'
  261. ],
  262. multi_scale=False,
  263. num_scales=-1,
  264. use_flip=None,
  265. use_dataloader=True,
  266. iterable=False):
  267. inputs_def = self._inputs_def(image_shape)
  268. fields = copy.deepcopy(fields)
  269. if multi_scale:
  270. ms_def, ms_fields = multiscale_def(image_shape, num_scales,
  271. use_flip)
  272. inputs_def.update(ms_def)
  273. fields += ms_fields
  274. self.im_info_names = ['image', 'im_info'] + ms_fields
  275. feed_vars = OrderedDict([(key, fluid.data(
  276. name=key,
  277. shape=inputs_def[key]['shape'],
  278. dtype=inputs_def[key]['dtype'],
  279. lod_level=inputs_def[key]['lod_level'])) for key in fields])
  280. loader = fluid.io.DataLoader.from_generator(
  281. feed_list=list(feed_vars.values()),
  282. capacity=16,
  283. use_double_buffer=True,
  284. iterable=iterable) if use_dataloader else None
  285. return feed_vars, loader
  286. def train(self, feed_vars):
  287. return self.build(feed_vars, 'train')
  288. def eval(self, feed_vars, multi_scale=None):
  289. if multi_scale:
  290. return self.build_multi_scale(feed_vars)
  291. return self.build(feed_vars, 'test')
  292. def test(self, feed_vars, exclude_nms=False):
  293. assert not exclude_nms, "exclude_nms for {} is not support currently".format(
  294. self.__class__.__name__)
  295. return self.build(feed_vars, 'test')