faster_rcnn.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. import copy
  19. from paddle import fluid
  20. from ppdet.experimental import mixed_precision_global_state
  21. from ppdet.core.workspace import register
  22. from .input_helper import multiscale_def
  23. __all__ = ['FasterRCNN']
  @register
  class FasterRCNN(object):
      """
      Faster R-CNN architecture, see https://arxiv.org/abs/1506.01497

      Args:
          backbone (object): backbone instance
          rpn_head (object): `RPNhead` instance
          bbox_assigner (object): `BBoxAssigner` instance
          roi_extractor (object): ROI extractor instance
          bbox_head (object): `BBoxHead` instance
          fpn (object): feature pyramid network instance
          rpn_only (bool): when True, a non-train `build` returns only the
              RPN proposals (divided by the third column of `im_info`) and
              skips the ROI extractor and bbox head
      """

      __category__ = 'architecture'
      __inject__ = [
          'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
          'fpn'
      ]

      def __init__(self,
                   backbone,
                   rpn_head,
                   roi_extractor,
                   bbox_head='BBoxHead',
                   bbox_assigner='BBoxAssigner',
                   rpn_only=False,
                   fpn=None):
          super(FasterRCNN, self).__init__()
          self.backbone = backbone
          self.rpn_head = rpn_head
          self.bbox_assigner = bbox_assigner
          self.roi_extractor = roi_extractor
          self.bbox_head = bbox_head
          self.fpn = fpn
          self.rpn_only = rpn_only
          # Filled in by `build_inputs(multi_scale=True)`; consumed by
          # `build_multi_scale`. Stays None until then.
          self.im_info_names = None

      def _extract_roi_feat(self, body_feats, body_feat_names, rois,
                            spatial_scale):
          """
          Run the ROI extractor on the features appropriate for this model.

          Args:
              body_feats (dict): feature maps (FPN output when `self.fpn`
                  is set, otherwise the backbone output)
              body_feat_names (list): names of the backbone feature maps,
                  in order
              rois: proposals to extract features for
              spatial_scale: value returned by `self.fpn.get_output`; only
                  used when `self.fpn` is set

          Returns:
              Whatever `self.roi_extractor` returns.
          """
          if self.fpn is None:
              # in models without FPN, roi extractor only uses the last level
              # of feature maps. And body_feat_names[-1] represents the name
              # of last feature map.
              body_feat = body_feats[body_feat_names[-1]]
              return self.roi_extractor(body_feat, rois)
          return self.roi_extractor(body_feats, rois, spatial_scale)

      def build(self, feed_vars, mode='train'):
          """
          Build the network for training or inference.

          Args:
              feed_vars (dict): input variables keyed by field name. Must
                  contain 'image' and 'im_info', plus 'gt_class', 'gt_bbox'
                  and 'is_crowd' when `mode` is 'train', or 'im_shape'
                  otherwise.
              mode (str): 'train' builds the loss; any other value builds
                  the prediction path. The value is forwarded unchanged to
                  `rpn_head.get_proposals`.

          Returns:
              dict: in train mode, the bbox head losses merged with the RPN
              losses plus their sum under key 'loss'. Otherwise
              `{'proposal': rois}` when `rpn_only` is set, else the result
              of `bbox_head.get_prediction`.

          Raises:
              AssertionError: if a required field is missing from
                  `feed_vars`.
          """
          is_train = mode == 'train'
          # 'image' is read unconditionally below, so it is validated here
          # together with the mode-specific fields.
          if is_train:
              required_fields = [
                  'image', 'gt_class', 'gt_bbox', 'is_crowd', 'im_info'
              ]
          else:
              required_fields = ['image', 'im_shape', 'im_info']
          self._input_check(required_fields, feed_vars)

          im = feed_vars['image']
          im_info = feed_vars['im_info']

          mixed_precision_enabled = mixed_precision_global_state() is not None
          # cast inputs to FP16
          if mixed_precision_enabled:
              im = fluid.layers.cast(im, 'float16')

          body_feats = self.backbone(im)
          body_feat_names = list(body_feats.keys())

          # cast features back to FP32
          if mixed_precision_enabled:
              body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32'))
                                       for k, v in body_feats.items())

          spatial_scale = None
          if self.fpn is not None:
              body_feats, spatial_scale = self.fpn.get_output(body_feats)

          rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)

          if is_train:
              gt_bbox = feed_vars['gt_bbox']
              is_crowd = feed_vars['is_crowd']
              rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
              # sampled rpn proposals
              outs = self.bbox_assigner(
                  rpn_rois=rois,
                  gt_classes=feed_vars['gt_class'],
                  is_crowd=is_crowd,
                  gt_boxes=gt_bbox,
                  im_info=im_info)
              rois = outs[0]
              labels_int32 = outs[1]
              bbox_targets = outs[2]
              bbox_inside_weights = outs[3]
              bbox_outside_weights = outs[4]
          elif self.rpn_only:
              # Divide the proposals by column 2 of im_info, expanded to
              # match the rois sequence layout.
              im_scale = fluid.layers.slice(
                  im_info, [1], starts=[2], ends=[3])
              im_scale = fluid.layers.sequence_expand(im_scale, rois)
              rois = rois / im_scale
              return {'proposal': rois}

          roi_feat = self._extract_roi_feat(body_feats, body_feat_names, rois,
                                            spatial_scale)

          if is_train:
              loss = self.bbox_head.get_loss(roi_feat, labels_int32,
                                             bbox_targets, bbox_inside_weights,
                                             bbox_outside_weights)
              loss.update(rpn_loss)
              total_loss = fluid.layers.sum(list(loss.values()))
              loss.update({'loss': total_loss})
              return loss

          im_shape = feed_vars['im_shape']
          pred = self.bbox_head.get_prediction(roi_feat, rois, im_info,
                                               im_shape)
          return pred

      def build_multi_scale(self, feed_vars):
          """
          Build the inference network once per (image, im_info) input pair.

          The pairs are taken from `self.im_info_names`, which is set by
          `build_inputs(multi_scale=True)`: consecutive entries are read as
          (image field name, im_info field name).

          Args:
              feed_vars (dict): input variables; must contain 'image',
                  'im_info', 'im_shape' and every name listed in
                  `self.im_info_names`.

          Returns:
              dict: 'im_shape' plus, for pair index i, 'bbox_<i>' and
              'score_<i>' (suffixed with '_flip' when the image variable's
              name contains 'flip').

          Raises:
              AssertionError: if a required field is missing.
              AttributeError: if `build_inputs(multi_scale=True)` has not
                  been called on this instance.
          """
          required_fields = ['image', 'im_info', 'im_shape']
          self._input_check(required_fields, feed_vars)

          names = getattr(self, 'im_info_names', None)
          if names is None:
              # Same exception type the bare attribute access used to raise,
              # but with an actionable message.
              raise AttributeError(
                  "im_info_names is not set; call "
                  "build_inputs(multi_scale=True) before build_multi_scale")

          result = {}
          im_shape = feed_vars['im_shape']
          result['im_shape'] = im_shape

          # zip() stops at the shorter slice, i.e. len(names) // 2 pairs.
          for i, (im_name, info_name) in enumerate(
                  zip(names[0::2], names[1::2])):
              im = feed_vars[im_name]
              im_info = feed_vars[info_name]

              body_feats = self.backbone(im)
              body_feat_names = list(body_feats.keys())

              spatial_scale = None
              if self.fpn is not None:
                  body_feats, spatial_scale = self.fpn.get_output(body_feats)

              rois = self.rpn_head.get_proposals(
                  body_feats, im_info, mode='test')

              roi_feat = self._extract_roi_feat(body_feats, body_feat_names,
                                                rois, spatial_scale)

              pred = self.bbox_head.get_prediction(
                  roi_feat, rois, im_info, im_shape, return_box_score=True)

              bbox_name = 'bbox_' + str(i)
              score_name = 'score_' + str(i)
              if 'flip' in im.name:
                  bbox_name += '_flip'
                  score_name += '_flip'
              result[bbox_name] = pred['bbox']
              result[score_name] = pred['score']
          return result

      def _input_check(self, require_fields, feed_vars):
          """
          Assert that every name in `require_fields` is a key of `feed_vars`.

          Raises:
              AssertionError: on the first missing field.
          """
          for var in require_fields:
              assert var in feed_vars, \
                  "{} has no {} field".format(feed_vars, var)

      def _inputs_def(self, image_shape):
          """
          Return the shape/dtype/lod_level spec for each supported input field.

          Args:
              image_shape (list): image shape without the batch dimension;
                  a leading `None` dimension is prepended for 'image'.

          Returns:
              dict: field name -> {'shape', 'dtype', 'lod_level'}.
          """
          im_shape = [None] + list(image_shape)
          # yapf: disable
          inputs_def = {
              'image':    {'shape': im_shape,  'dtype': 'float32', 'lod_level': 0},
              'im_info':  {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
              'im_id':    {'shape': [None, 1], 'dtype': 'int64',   'lod_level': 0},
              'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
              'gt_bbox':  {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
              'gt_class': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
              'is_crowd': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
              'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
          }
          # yapf: enable
          return inputs_def

      def build_inputs(
              self,
              image_shape=[3, None, None],
              fields=[
                  'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'
              ],  # for train
              multi_scale=False,
              num_scales=-1,
              use_flip=None,
              use_dataloader=True,
              iterable=False):
          """
          Create the input variables and, optionally, a DataLoader for them.

          Args:
              image_shape (list): image shape without the batch dimension.
              fields (list): names of the input fields to create, in order.
              multi_scale (bool): when True, extra fields returned by
                  `multiscale_def` are appended and `self.im_info_names`
                  is recorded for `build_multi_scale`.
              num_scales (int): forwarded to `multiscale_def`.
              use_flip: forwarded to `multiscale_def`.
              use_dataloader (bool): when False, no loader is created.
              iterable (bool): forwarded to `DataLoader.from_generator`.

          Returns:
              tuple: (OrderedDict of field name -> variable, loader or None).
          """
          # The defaults are shared mutable lists; work on private copies so
          # neither this method nor a callee can alter them across calls.
          image_shape = list(image_shape)
          fields = copy.deepcopy(fields)

          inputs_def = self._inputs_def(image_shape)
          if multi_scale:
              ms_def, ms_fields = multiscale_def(image_shape, num_scales,
                                                 use_flip)
              inputs_def.update(ms_def)
              fields += ms_fields
              self.im_info_names = ['image', 'im_info'] + ms_fields

          feed_vars = OrderedDict()
          for key in fields:
              spec = inputs_def[key]
              feed_vars[key] = fluid.data(
                  name=key,
                  shape=spec['shape'],
                  dtype=spec['dtype'],
                  lod_level=spec['lod_level'])

          loader = None
          if use_dataloader:
              loader = fluid.io.DataLoader.from_generator(
                  feed_list=list(feed_vars.values()),
                  capacity=16,
                  use_double_buffer=True,
                  iterable=iterable)
          return feed_vars, loader

      def train(self, feed_vars):
          """Build the training network; see `build`."""
          return self.build(feed_vars, 'train')

      def eval(self, feed_vars, multi_scale=None):
          """
          Build the evaluation network.

          Uses `build_multi_scale` when `multi_scale` is truthy, otherwise
          `build` in 'test' mode.
          """
          if multi_scale:
              return self.build_multi_scale(feed_vars)
          return self.build(feed_vars, 'test')

      def test(self, feed_vars, exclude_nms=False):
          """
          Build the inference network via `build` in 'test' mode.

          Raises:
              AssertionError: if `exclude_nms` is truthy.
          """
          assert not exclude_nms, "exclude_nms for {} is not support currently".format(
              self.__class__.__name__)
          return self.build(feed_vars, 'test')
  216. return self.build(feed_vars, 'test')