solov2.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. from paddle import fluid
  19. from ppdet.experimental import mixed_precision_global_state
  20. from ppdet.core.workspace import register
  21. from ppdet.utils.check import check_version
  22. __all__ = ['SOLOv2']
  23. @register
  24. class SOLOv2(object):
  25. """
  26. SOLOv2 network, see https://arxiv.org/abs/2003.10152
  27. Args:
  28. backbone (object): an backbone instance
  29. fpn (object): feature pyramid network instance
  30. bbox_head (object): an `SOLOv2Head` instance
  31. mask_head (object): an `SOLOv2MaskHead` instance
  32. """
  33. __category__ = 'architecture'
  34. __inject__ = ['backbone', 'fpn', 'bbox_head', 'mask_head']
  35. def __init__(self,
  36. backbone,
  37. fpn=None,
  38. bbox_head='SOLOv2Head',
  39. mask_head='SOLOv2MaskHead'):
  40. super(SOLOv2, self).__init__()
  41. check_version('2.0.0-rc0')
  42. self.backbone = backbone
  43. self.fpn = fpn
  44. self.bbox_head = bbox_head
  45. self.mask_head = mask_head
  46. def build(self, feed_vars, mode='train'):
  47. im = feed_vars['image']
  48. mixed_precision_enabled = mixed_precision_global_state() is not None
  49. # cast inputs to FP16
  50. if mixed_precision_enabled:
  51. im = fluid.layers.cast(im, 'float16')
  52. body_feats = self.backbone(im)
  53. if self.fpn is not None:
  54. body_feats, spatial_scale = self.fpn.get_output(body_feats)
  55. if isinstance(body_feats, OrderedDict):
  56. body_feat_names = list(body_feats.keys())
  57. body_feats = [body_feats[name] for name in body_feat_names]
  58. # cast features back to FP32
  59. if mixed_precision_enabled:
  60. body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats]
  61. mask_feat_pred = self.mask_head.get_output(body_feats)
  62. if mode == 'train':
  63. ins_labels = []
  64. cate_labels = []
  65. grid_orders = []
  66. fg_num = feed_vars['fg_num']
  67. for i in range(self.num_level):
  68. ins_label = 'ins_label{}'.format(i)
  69. if ins_label in feed_vars:
  70. ins_labels.append(feed_vars[ins_label])
  71. cate_label = 'cate_label{}'.format(i)
  72. if cate_label in feed_vars:
  73. cate_labels.append(feed_vars[cate_label])
  74. grid_order = 'grid_order{}'.format(i)
  75. if grid_order in feed_vars:
  76. grid_orders.append(feed_vars[grid_order])
  77. cate_preds, kernel_preds = self.bbox_head.get_outputs(body_feats)
  78. losses = self.bbox_head.get_loss(cate_preds, kernel_preds,
  79. mask_feat_pred, ins_labels,
  80. cate_labels, grid_orders, fg_num)
  81. total_loss = fluid.layers.sum(list(losses.values()))
  82. losses.update({'loss': total_loss})
  83. return losses
  84. else:
  85. im_info = feed_vars['im_info']
  86. outs = self.bbox_head.get_outputs(body_feats, is_eval=True)
  87. seg_inputs = outs + (mask_feat_pred, im_info)
  88. return self.bbox_head.get_prediction(*seg_inputs)
  89. def _inputs_def(self, image_shape, fields):
  90. im_shape = [None] + image_shape
  91. # yapf: disable
  92. inputs_def = {
  93. 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
  94. 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  95. 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
  96. 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  97. }
  98. if 'gt_segm' in fields:
  99. for i in range(self.num_level):
  100. targets_def = {
  101. 'ins_label%d' % i: {'shape': [None, None, None], 'dtype': 'int32', 'lod_level': 1},
  102. 'cate_label%d' % i: {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
  103. 'grid_order%d' % i: {'shape': [None], 'dtype': 'int32', 'lod_level': 1},
  104. }
  105. inputs_def.update(targets_def)
  106. targets_def = {
  107. 'fg_num': {'shape': [None], 'dtype': 'int32', 'lod_level': 0},
  108. }
  109. # yapf: enable
  110. inputs_def.update(targets_def)
  111. return inputs_def
  112. def build_inputs(
  113. self,
  114. image_shape=[3, None, None],
  115. fields=['image', 'im_id', 'gt_segm'], # for train
  116. num_level=5,
  117. use_dataloader=True,
  118. iterable=False):
  119. self.num_level = num_level
  120. inputs_def = self._inputs_def(image_shape, fields)
  121. if 'gt_segm' in fields:
  122. fields.remove('gt_segm')
  123. fields.extend(['fg_num'])
  124. for i in range(num_level):
  125. fields.extend([
  126. 'ins_label%d' % i, 'cate_label%d' % i, 'grid_order%d' % i
  127. ])
  128. feed_vars = OrderedDict([(key, fluid.data(
  129. name=key,
  130. shape=inputs_def[key]['shape'],
  131. dtype=inputs_def[key]['dtype'],
  132. lod_level=inputs_def[key]['lod_level'])) for key in fields])
  133. loader = fluid.io.DataLoader.from_generator(
  134. feed_list=list(feed_vars.values()),
  135. capacity=16,
  136. use_double_buffer=True,
  137. iterable=iterable) if use_dataloader else None
  138. return feed_vars, loader
  139. def train(self, feed_vars):
  140. return self.build(feed_vars, mode='train')
  141. def eval(self, feed_vars):
  142. return self.build(feed_vars, mode='test')
  143. def test(self, feed_vars, exclude_nms=False):
  144. assert not exclude_nms, "exclude_nms for {} is not support currently".format(
  145. self.__class__.__name__)
  146. return self.build(feed_vars, mode='test')