# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. import paddle.fluid as fluid
  19. from ppdet.experimental import mixed_precision_global_state
  20. from ppdet.core.workspace import register
  21. __all__ = ['FCOS']
  22. @register
  23. class FCOS(object):
  24. """
  25. FCOS architecture, see https://arxiv.org/abs/1904.01355
  26. Args:
  27. backbone (object): backbone instance
  28. fpn (object): feature pyramid network instance
  29. fcos_head (object): `FCOSHead` instance
  30. """
  31. __category__ = 'architecture'
  32. __inject__ = ['backbone', 'fpn', 'fcos_head']
  33. def __init__(self, backbone, fpn, fcos_head):
  34. super(FCOS, self).__init__()
  35. self.backbone = backbone
  36. self.fpn = fpn
  37. self.fcos_head = fcos_head
  38. def build(self, feed_vars, mode='train'):
  39. im = feed_vars['image']
  40. im_info = feed_vars['im_info']
  41. mixed_precision_enabled = mixed_precision_global_state() is not None
  42. # cast inputs to FP16
  43. if mixed_precision_enabled:
  44. im = fluid.layers.cast(im, 'float16')
  45. # backbone
  46. body_feats = self.backbone(im)
  47. # cast features back to FP32
  48. if mixed_precision_enabled:
  49. body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32'))
  50. for k, v in body_feats.items())
  51. # FPN
  52. body_feats, spatial_scale = self.fpn.get_output(body_feats)
  53. # fcosnet head
  54. if mode == 'train':
  55. tag_labels = []
  56. tag_bboxes = []
  57. tag_centerness = []
  58. for i in range(len(self.fcos_head.fpn_stride)):
  59. # reg_target, labels, scores, centerness
  60. k_lbl = 'labels{}'.format(i)
  61. if k_lbl in feed_vars:
  62. tag_labels.append(feed_vars[k_lbl])
  63. k_box = 'reg_target{}'.format(i)
  64. if k_box in feed_vars:
  65. tag_bboxes.append(feed_vars[k_box])
  66. k_ctn = 'centerness{}'.format(i)
  67. if k_ctn in feed_vars:
  68. tag_centerness.append(feed_vars[k_ctn])
  69. # tag_labels, tag_bboxes, tag_centerness
  70. loss = self.fcos_head.get_loss(body_feats, tag_labels, tag_bboxes,
  71. tag_centerness)
  72. total_loss = fluid.layers.sum(list(loss.values()))
  73. loss.update({'loss': total_loss})
  74. return loss
  75. else:
  76. pred = self.fcos_head.get_prediction(body_feats, im_info)
  77. return pred
  78. def _inputs_def(self, image_shape, fields):
  79. im_shape = [None] + image_shape
  80. # yapf: disable
  81. inputs_def = {
  82. 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
  83. 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  84. 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  85. 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
  86. 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
  87. 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  88. 'gt_score': {'shape': [None, 1], 'dtype': 'float32', 'lod_level': 1},
  89. 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  90. 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}
  91. }
  92. # yapf: disable
  93. if 'fcos_target' in fields:
  94. targets_def = {
  95. 'labels0': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
  96. 'reg_target0': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
  97. 'centerness0': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
  98. 'labels1': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
  99. 'reg_target1': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
  100. 'centerness1': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
  101. 'labels2': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
  102. 'reg_target2': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
  103. 'centerness2': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
  104. 'labels3': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
  105. 'reg_target3': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
  106. 'centerness3': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
  107. 'labels4': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
  108. 'reg_target4': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
  109. 'centerness4': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
  110. }
  111. # yapf: enable
  112. # downsample = 128
  113. for k, stride in enumerate(self.fcos_head.fpn_stride):
  114. k_lbl = 'labels{}'.format(k)
  115. k_box = 'reg_target{}'.format(k)
  116. k_ctn = 'centerness{}'.format(k)
  117. grid_y = image_shape[-2] // stride if image_shape[-2] else None
  118. grid_x = image_shape[-1] // stride if image_shape[-1] else None
  119. if grid_x is not None:
  120. num_pts = grid_x * grid_y
  121. num_dim2 = 1
  122. else:
  123. num_pts = None
  124. num_dim2 = None
  125. targets_def[k_lbl]['shape'][1] = num_pts
  126. targets_def[k_box]['shape'][1] = num_pts
  127. targets_def[k_ctn]['shape'][1] = num_pts
  128. targets_def[k_lbl]['shape'][2] = num_dim2
  129. targets_def[k_box]['shape'][2] = num_dim2
  130. targets_def[k_ctn]['shape'][2] = num_dim2
  131. inputs_def.update(targets_def)
  132. return inputs_def
  133. def build_inputs(
  134. self,
  135. image_shape=[3, None, None],
  136. fields=['image', 'im_info', 'fcos_target'], # for-train
  137. use_dataloader=True,
  138. iterable=False):
  139. inputs_def = self._inputs_def(image_shape, fields)
  140. if "fcos_target" in fields:
  141. for i in range(len(self.fcos_head.fpn_stride)):
  142. fields.extend(
  143. ['labels%d' % i, 'reg_target%d' % i, 'centerness%d' % i])
  144. fields.remove('fcos_target')
  145. feed_vars = OrderedDict([(key, fluid.data(
  146. name=key,
  147. shape=inputs_def[key]['shape'],
  148. dtype=inputs_def[key]['dtype'],
  149. lod_level=inputs_def[key]['lod_level'])) for key in fields])
  150. loader = fluid.io.DataLoader.from_generator(
  151. feed_list=list(feed_vars.values()),
  152. capacity=16,
  153. use_double_buffer=True,
  154. iterable=iterable) if use_dataloader else None
  155. return feed_vars, loader
  156. def train(self, feed_vars):
  157. return self.build(feed_vars, 'train')
  158. def eval(self, feed_vars):
  159. return self.build(feed_vars, 'test')
  160. def test(self, feed_vars, exclude_nms=False):
  161. assert not exclude_nms, "exclude_nms for {} is not support currently".format(
  162. self.__class__.__name__)
  163. return self.build(feed_vars, 'test')