htc.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. import copy
  19. import paddle.fluid as fluid
  20. from ppdet.core.workspace import register
  21. from ppdet.utils.check import check_version
  22. from .input_helper import multiscale_def
  # Public API of this module: only the HTC architecture class is exported.
  __all__ = ["HybridTaskCascade"]
  @register
  class HybridTaskCascade(object):
      """
      Hybrid Task Cascade Mask R-CNN architecture, see https://arxiv.org/abs/1901.07518

      Args:
          backbone (object): backbone instance
          rpn_head (object): `RPNhead` instance
          bbox_assigner (object): `CascadeBBoxAssigner` instance
          roi_extractor (object): ROI extractor instance
          bbox_head (object): `HTCBBoxHead` instance
          mask_assigner (object): `MaskAssigner` instance
          mask_head (object): `HTCMaskHead` instance
          fpn (object): feature pyramid network instance
          semantic_roi_extractor(object): ROI extractor instance
          fused_semantic_head (object): `FusedSemanticHead` instance
          rpn_only (bool): if True, inference returns only the RPN proposals
      """

      __category__ = 'architecture'
      __inject__ = [
          'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
          'mask_assigner', 'mask_head', 'fpn', 'semantic_roi_extractor',
          'fused_semantic_head'
      ]

      def __init__(self,
                   backbone,
                   rpn_head,
                   roi_extractor='FPNRoIAlign',
                   semantic_roi_extractor='RoIAlign',
                   fused_semantic_head='FusedSemanticHead',
                   bbox_head='HTCBBoxHead',
                   bbox_assigner='CascadeBBoxAssigner',
                   mask_assigner='MaskAssigner',
                   mask_head='HTCMaskHead',
                   rpn_only=False,
                   fpn='FPN'):
          super(HybridTaskCascade, self).__init__()
          check_version('2.0.0-rc0')
          assert fpn is not None, "HTC requires FPN"
          self.backbone = backbone
          self.fpn = fpn
          self.rpn_head = rpn_head
          self.bbox_assigner = bbox_assigner
          self.roi_extractor = roi_extractor
          self.semantic_roi_extractor = semantic_roi_extractor
          self.fused_semantic_head = fused_semantic_head
          self.bbox_head = bbox_head
          self.mask_assigner = mask_assigner
          self.mask_head = mask_head
          self.rpn_only = rpn_only

          # Cascade local cfg
          # bbox_pred carries 2 delta sets per RoI; `_decode_box` only uses
          # the second (foreground) one.
          self.cls_agnostic_bbox_reg = 2
          (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
          # Per-stage `prior_box_var` for decoding: inverse of the assigner's
          # regression weights (x2 for the width/height terms).
          self.cascade_bbox_reg_weights = [
              [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
              [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
              [1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
          ]
          self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]
          self.num_stage = 3
          self.with_mask = True
          self.interleaved = True
          self.mask_info_flow = True
          self.with_semantic = True
          self.use_bias_scalar = True

      def build(self, feed_vars, mode='train'):
          """Build the HTC network.

          Args:
              feed_vars (dict): input variables keyed by field name.
              mode (str): 'train' builds the losses; any other value builds
                  the inference outputs.

          Returns:
              dict: in 'train' mode the individual losses plus their sum under
              'loss'; otherwise {'bbox', 'mask'} predictions, or only
              {'proposal'} when `rpn_only` is set.
          """
          if mode == 'train':
              required_fields = [
                  'gt_class', 'gt_bbox', 'gt_mask', 'is_crowd', 'im_info',
                  'semantic'
              ]
          else:
              required_fields = ['im_shape', 'im_info']
          self._input_check(required_fields, feed_vars)

          im = feed_vars['image']
          if mode == 'train':
              gt_bbox = feed_vars['gt_bbox']
              is_crowd = feed_vars['is_crowd']
          im_info = feed_vars['im_info']

          # backbone
          body_feats = self.backbone(im)
          loss = {}

          # FPN
          if self.fpn is not None:
              body_feats, spatial_scale = self.fpn.get_output(body_feats)

          if self.with_semantic:
              # TODO: use cfg
              semantic_feat, seg_pred = self.fused_semantic_head.get_out(
                  body_feats)
              if mode == 'train':
                  s_label = feed_vars['semantic']
                  # the semantic segmentation loss weight is hard-coded to 0.2
                  semantic_loss = self.fused_semantic_head.get_loss(
                      seg_pred, s_label) * 0.2
                  loss.update({"semantic_loss": semantic_loss})
          else:
              semantic_feat = None

          # rpn proposals
          rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
          if mode == 'train':
              rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
              loss.update(rpn_loss)
          elif self.rpn_only:
              # rescale proposals by im_info[:, 2] and return them directly
              im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
              im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois)
              rois = rpn_rois / im_scale
              return {'proposal': rois}

          proposal_list = []
          roi_feat_list = []
          rcnn_pred_list = []
          rcnn_target_list = []
          mask_logits_list = []
          mask_target_list = []
          proposals = None
          bbox_pred = None
          outs = None
          refined_bbox = rpn_rois
          max_overlap = None
          # In interleaved training the mask branch refines the boxes itself
          # (before assigning mask targets), so the end-of-stage refinement
          # must be skipped only in that case.
          refine_in_mask_branch = (mode == 'train' and self.with_mask and
                                   self.interleaved)

          for i in range(self.num_stage):
              # BBox Branch
              if mode == 'train':
                  outs = self.bbox_assigner(
                      input_rois=refined_bbox,
                      feed_vars=feed_vars,
                      curr_stage=i,
                      max_overlap=max_overlap)
                  proposals = outs[0]
                  max_overlap = outs[-1]
                  rcnn_target_list.append(outs[:-1])
              else:
                  proposals = refined_bbox
              proposal_list.append(proposals)

              # extract roi features
              roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale)
              roi_feat = self._fuse_semantic_feat(
                  roi_feat, semantic_feat, proposals, downsample=True)
              roi_feat_list.append(roi_feat)

              # bbox head
              cls_score, bbox_pred = self.bbox_head.get_output(
                  roi_feat,
                  wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
                  name='_' + str(i))
              rcnn_pred_list.append((cls_score, bbox_pred))

              # Mask Branch: only built here for training; inference masks
              # are produced by `single_scale_eval`.
              if self.with_mask and mode == 'train':
                  labels_int32 = outs[1]
                  if self.interleaved:
                      refined_bbox = self._decode_box(
                          proposals, bbox_pred, curr_stage=i)
                      proposals = refined_bbox
                  mask_rois, _, mask_int32 = self.mask_assigner(
                      rois=proposals,
                      gt_classes=feed_vars['gt_class'],
                      is_crowd=feed_vars['is_crowd'],
                      gt_segms=feed_vars['gt_mask'],
                      im_info=feed_vars['im_info'],
                      labels_int32=labels_int32)
                  mask_target_list.append(mask_int32)
                  mask_feat = self.roi_extractor(
                      body_feats, mask_rois, spatial_scale, is_mask=True)
                  mask_feat = self._fuse_semantic_feat(mask_feat, semantic_feat,
                                                       mask_rois)
                  mask_logits_list.append(self._get_mask_logits(mask_feat, i))

              if i < self.num_stage - 1 and not refine_in_mask_branch:
                  refined_bbox = self._decode_box(
                      proposals, bbox_pred, curr_stage=i)

          if mode == 'train':
              bbox_loss = self.bbox_head.get_loss(
                  rcnn_pred_list, rcnn_target_list, self.cascade_rcnn_loss_weight)
              loss.update(bbox_loss)
              if self.with_mask:
                  mask_loss = self.mask_head.get_loss(
                      mask_logits_list, mask_target_list,
                      self.cascade_rcnn_loss_weight)
                  loss.update(mask_loss)
              total_loss = fluid.layers.sum(list(loss.values()))
              loss.update({'loss': total_loss})
              return loss

          mask_name = 'mask_pred'
          mask_pred, bbox_pred = self.single_scale_eval(
              body_feats,
              spatial_scale,
              im_info,
              mask_name,
              bbox_pred,
              roi_feat_list,
              rcnn_pred_list,
              proposal_list,
              feed_vars['im_shape'],
              semantic_feat=semantic_feat if self.with_semantic else None)
          return {'bbox': bbox_pred, 'mask': mask_pred}

      def single_scale_eval(self,
                            body_feats,
                            spatial_scale,
                            im_info,
                            mask_name,
                            bbox_pred,
                            roi_feat_list=None,
                            rcnn_pred_list=None,
                            proposal_list=None,
                            im_shape=None,
                            use_multi_test=False,
                            semantic_feat=None):
          """Build inference box and mask predictions.

          Args:
              body_feats: FPN feature maps.
              spatial_scale: FPN spatial scales.
              im_info (Variable): image info; column 2 is used as the scale.
              mask_name (str): name of the global variable holding mask_pred.
              bbox_pred (Variable): final boxes, used as-is only when
                  `use_multi_test` is True; otherwise recomputed by
                  `bbox_head.get_prediction`.
              roi_feat_list, rcnn_pred_list, proposal_list: per-stage outputs
                  collected by `build`.
              im_shape (Variable): image shape input.
              use_multi_test (bool): skip the box prediction step.
              semantic_feat (Variable|None): semantic branch feature.

          Returns:
              tuple: (mask_pred, bbox_pred).
          """
          if not use_multi_test:
              bbox_pred = self.bbox_head.get_prediction(
                  im_info, im_shape, roi_feat_list, rcnn_pred_list, proposal_list,
                  self.cascade_bbox_reg_weights)
              bbox_pred = bbox_pred['bbox']

          # bbox_pred rows are at least 6 wide (columns 2..6 are the box);
          # fewer than 6 elements means there is no detection row, presumably
          # the empty-result placeholder, so masks are skipped.
          bbox_shape = fluid.layers.shape(bbox_pred)
          bbox_size = fluid.layers.reduce_prod(bbox_shape)
          bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
          size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
          has_no_box = fluid.layers.less_than(x=bbox_size, y=size)
          mask_pred = fluid.layers.create_global_var(
              shape=[1],
              value=0.0,
              dtype='float32',
              persistable=False,
              name=mask_name)

          def noop():
              fluid.layers.assign(input=bbox_pred, output=mask_pred)

          def process_boxes():
              bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
              # multiply by im_info[:, 2] (the rpn_only path divides by the
              # same factor), presumably mapping boxes back to input scale
              im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
              im_scale = fluid.layers.sequence_expand(im_scale, bbox)
              bbox = fluid.layers.cast(bbox, dtype='float32')
              im_scale = fluid.layers.cast(im_scale, dtype='float32')
              mask_rois = bbox * im_scale

              mask_feat = self.roi_extractor(
                  body_feats, mask_rois, spatial_scale, is_mask=True)
              mask_feat = self._fuse_semantic_feat(mask_feat, semantic_feat,
                                                   mask_rois)

              # average the mask predictions of all cascade stages
              mask_pred_list = []
              for i in range(self.num_stage):
                  mask_logits = self._get_mask_logits(mask_feat, i)
                  mask_pred_list.append(
                      self.mask_head.get_prediction(mask_logits, bbox))
              mask_pred_out = fluid.layers.sum(mask_pred_list) / float(
                  len(mask_pred_list))
              fluid.layers.assign(input=mask_pred_out, output=mask_pred)

          fluid.layers.cond(has_no_box, noop, process_boxes)
          return mask_pred, bbox_pred

      def _fuse_semantic_feat(self,
                              roi_feat,
                              semantic_feat,
                              rois,
                              downsample=False):
          """Add semantic-branch RoI features to `roi_feat` when enabled.

          Args:
              roi_feat (Variable): RoI features from `roi_extractor`.
              semantic_feat (Variable|None): feature from
                  `fused_semantic_head`; ignored unless `with_semantic`.
              rois (Variable): RoIs pooled from `semantic_feat`.
              downsample (bool): apply a 2x2, stride-2 `pool2d` to the semantic
                  RoI feature first (used on the bbox branch, presumably to
                  match its smaller RoI resolution).

          Returns:
              Variable: `roi_feat`, or its element-wise sum with the semantic
              RoI feature when that feature is available.
          """
          if not self.with_semantic:
              return roi_feat
          semantic_roi_feat = self.semantic_roi_extractor(semantic_feat, rois)
          if semantic_roi_feat is None:
              return roi_feat
          if downsample:
              semantic_roi_feat = fluid.layers.pool2d(
                  semantic_roi_feat,
                  pool_size=2,
                  pool_stride=2,
                  pool_padding='SAME')
          return fluid.layers.sum([roi_feat, semantic_roi_feat])

      def _get_mask_logits(self, mask_feat, stage):
          """Build the mask head of cascade stage `stage`; return its logits.

          With `mask_info_flow`, heads named `_<stage>_<j>` (j < stage) are
          chained, each feeding its feature to the next, and the final
          `_<stage>` head consumes the last one. Training and inference share
          this helper, so both use identical head arguments.
          """
          wb_scalar = (1.0 / self.cascade_rcnn_loss_weight[stage]
                       if self.use_bias_scalar else 1.0)
          name = '_' + str(stage)
          if not self.mask_info_flow:
              return self.mask_head.get_output(
                  mask_feat,
                  return_logits=True,
                  return_feat=False,
                  wb_scalar=wb_scalar,
                  name=name)
          last_feat = None
          for j in range(stage):
              last_feat = self.mask_head.get_output(
                  mask_feat,
                  last_feat,
                  return_logits=False,
                  return_feat=True,
                  wb_scalar=wb_scalar,
                  name=name + '_' + str(j))
          return self.mask_head.get_output(
              mask_feat,
              last_feat,
              return_logits=True,
              return_feat=False,
              wb_scalar=wb_scalar,
              name=name)

      def _input_check(self, require_fields, feed_vars):
          """Assert that every name in `require_fields` is in `feed_vars`."""
          for var in require_fields:
              assert var in feed_vars, \
                  "{} has no {} field".format(feed_vars, var)

      def _decode_box(self, proposals, bbox_pred, curr_stage):
          """Decode the foreground deltas of `bbox_pred` onto `proposals`.

          Args:
              proposals (Variable): boxes the deltas are relative to.
              bbox_pred (Variable): bbox head regression output, reshaped to
                  [-1, cls_agnostic_bbox_reg, 4].
              curr_stage (int): index into `cascade_bbox_reg_weights`.

          Returns:
              Variable: refined boxes reshaped to [-1, 4].
          """
          rcnn_loc_delta_r = fluid.layers.reshape(
              bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4))
          # only use fg box delta to decode box
          rcnn_loc_delta_s = fluid.layers.slice(
              rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2])
          refined_bbox = fluid.layers.box_coder(
              prior_box=proposals,
              prior_box_var=self.cascade_bbox_reg_weights[curr_stage],
              target_box=rcnn_loc_delta_s,
              code_type='decode_center_size',
              box_normalized=False,
              axis=1)
          refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4])
          return refined_bbox

      def _inputs_def(self, image_shape):
          """Return the shape/dtype/lod_level spec of every supported input."""
          im_shape = [None] + image_shape
          # yapf: disable
          inputs_def = {
              'image':    {'shape': im_shape,  'dtype': 'float32', 'lod_level': 0},
              'im_info':  {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
              'im_id':    {'shape': [None, 1], 'dtype': 'int64',   'lod_level': 0},
              'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
              'gt_bbox':  {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
              'gt_class': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
              'is_crowd': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
              'gt_mask':  {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 3},  # polygon coordinates
              'semantic': {'shape': [None, 1, None, None], 'dtype': 'int32', 'lod_level': 0},
          }
          # yapf: enable
          return inputs_def

      def build_inputs(self,
                       image_shape=None,
                       fields=None,
                       multi_scale=False,
                       num_scales=-1,
                       use_flip=None,
                       use_dataloader=True,
                       iterable=False,
                       mask_branch=False):
          """Create the input variables and, optionally, a DataLoader.

          Args:
              image_shape (list|None): [C, H, W] of the image input; defaults
                  to [3, None, None].
              fields (list|None): input names to create; defaults to the
                  training fields. The caller's list is never modified.
              multi_scale (bool): add multi-scale test inputs.
              num_scales (int): number of test scales for `multiscale_def`.
              use_flip (bool|None): whether flipped inputs are added.
              use_dataloader (bool): build a DataLoader (never when
                  `mask_branch` is set).
              iterable (bool): passed to `DataLoader.from_generator`.
              mask_branch (bool): add the 'bbox' (and 'bbox_flip') inputs.

          Returns:
              tuple: (OrderedDict of feed variables, DataLoader or None).
          """
          # None sentinels avoid mutable default arguments.
          if image_shape is None:
              image_shape = [3, None, None]
          if fields is None:
              fields = [
                  'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class',
                  'is_crowd', 'gt_mask', 'semantic'
              ]
          else:
              fields = list(fields)
          inputs_def = self._inputs_def(image_shape)
          if multi_scale:
              ms_def, ms_fields = multiscale_def(image_shape, num_scales,
                                                 use_flip)
              inputs_def.update(ms_def)
              fields += ms_fields
              self.im_info_names = ['image', 'im_info'] + ms_fields
              if mask_branch:
                  box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox']
                  for key in box_fields:
                      inputs_def[key] = {
                          'shape': [6],
                          'dtype': 'float32',
                          'lod_level': 1
                      }
                  fields += box_fields
          feed_vars = OrderedDict([(key, fluid.data(
              name=key,
              shape=inputs_def[key]['shape'],
              dtype=inputs_def[key]['dtype'],
              lod_level=inputs_def[key]['lod_level'])) for key in fields])
          use_dataloader = use_dataloader and not mask_branch
          loader = fluid.io.DataLoader.from_generator(
              feed_list=list(feed_vars.values()),
              capacity=64,
              use_double_buffer=True,
              iterable=iterable) if use_dataloader else None
          return feed_vars, loader

      def train(self, feed_vars):
          """Build the training network; return the loss dict."""
          return self.build(feed_vars, 'train')

      def eval(self, feed_vars, multi_scale=None, mask_branch=False):
          """Build the evaluation network.

          Multi-scale evaluation requires a `build_multi_scale` method, which
          this class does not define; without one a clear NotImplementedError
          is raised instead of an AttributeError.
          """
          if multi_scale:
              build_multi_scale = getattr(self, 'build_multi_scale', None)
              if build_multi_scale is None:
                  raise NotImplementedError(
                      "multi-scale test is not supported by {}".format(
                          self.__class__.__name__))
              return build_multi_scale(feed_vars, mask_branch)
          return self.build(feed_vars, 'test')

      def test(self, feed_vars, exclude_nms=False):
          """Build the inference network (`exclude_nms` is unsupported)."""
          assert not exclude_nms, "exclude_nms for {} is not support currently".format(
              self.__class__.__name__)
          return self.build(feed_vars, 'test')