retinanet.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. import paddle.fluid as fluid
  19. from ppdet.experimental import mixed_precision_global_state
  20. from ppdet.core.workspace import register
  21. __all__ = ['RetinaNet']
  22. @register
  23. class RetinaNet(object):
  24. """
  25. RetinaNet architecture, see https://arxiv.org/abs/1708.02002
  26. Args:
  27. backbone (object): backbone instance
  28. fpn (object): feature pyramid network instance
  29. retina_head (object): `RetinaHead` instance
  30. """
  31. __category__ = 'architecture'
  32. __inject__ = ['backbone', 'fpn', 'retina_head']
  33. def __init__(self, backbone, fpn, retina_head):
  34. super(RetinaNet, self).__init__()
  35. self.backbone = backbone
  36. self.fpn = fpn
  37. self.retina_head = retina_head
  38. def build(self, feed_vars, mode='train'):
  39. im = feed_vars['image']
  40. im_info = feed_vars['im_info']
  41. if mode == 'train':
  42. gt_bbox = feed_vars['gt_bbox']
  43. gt_class = feed_vars['gt_class']
  44. is_crowd = feed_vars['is_crowd']
  45. mixed_precision_enabled = mixed_precision_global_state() is not None
  46. # cast inputs to FP16
  47. if mixed_precision_enabled:
  48. im = fluid.layers.cast(im, 'float16')
  49. # backbone
  50. body_feats = self.backbone(im)
  51. # cast features back to FP32
  52. if mixed_precision_enabled:
  53. body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32'))
  54. for k, v in body_feats.items())
  55. # FPN
  56. body_feats, spatial_scale = self.fpn.get_output(body_feats)
  57. # retinanet head
  58. if mode == 'train':
  59. loss = self.retina_head.get_loss(body_feats, spatial_scale, im_info,
  60. gt_bbox, gt_class, is_crowd)
  61. total_loss = fluid.layers.sum(list(loss.values()))
  62. loss.update({'loss': total_loss})
  63. return loss
  64. else:
  65. pred = self.retina_head.get_prediction(body_feats, spatial_scale,
  66. im_info)
  67. return pred
  68. def _inputs_def(self, image_shape):
  69. im_shape = [None] + image_shape
  70. # yapf: disable
  71. inputs_def = {
  72. 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
  73. 'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  74. 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
  75. 'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
  76. 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
  77. 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  78. 'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  79. 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
  80. }
  81. # yapf: enable
  82. return inputs_def
  83. def build_inputs(
  84. self,
  85. image_shape=[3, None, None],
  86. fields=[
  87. 'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'
  88. ], # for-train
  89. use_dataloader=True,
  90. iterable=False):
  91. inputs_def = self._inputs_def(image_shape)
  92. feed_vars = OrderedDict([(key, fluid.data(
  93. name=key,
  94. shape=inputs_def[key]['shape'],
  95. dtype=inputs_def[key]['dtype'],
  96. lod_level=inputs_def[key]['lod_level'])) for key in fields])
  97. loader = fluid.io.DataLoader.from_generator(
  98. feed_list=list(feed_vars.values()),
  99. capacity=16,
  100. use_double_buffer=True,
  101. iterable=iterable) if use_dataloader else None
  102. return feed_vars, loader
  103. def train(self, feed_vars):
  104. return self.build(feed_vars, 'train')
  105. def eval(self, feed_vars):
  106. return self.build(feed_vars, 'test')
  107. def test(self, feed_vars, exclude_nms=False):
  108. assert not exclude_nms, "exclude_nms for {} is not support currently".format(
  109. self.__class__.__name__)
  110. return self.build(feed_vars, 'test')