efficientdet.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from collections import OrderedDict
  17. import paddle.fluid as fluid
  18. from ppdet.experimental import mixed_precision_global_state
  19. from ppdet.core.workspace import register
  20. __all__ = ['EfficientDet']
  21. @register
  22. class EfficientDet(object):
  23. """
  24. EfficientDet architecture, see https://arxiv.org/abs/1911.09070
  25. Args:
  26. backbone (object): backbone instance
  27. fpn (object): feature pyramid network instance
  28. retina_head (object): `RetinaHead` instance
  29. """
  30. __category__ = 'architecture'
  31. __inject__ = ['backbone', 'fpn', 'efficient_head', 'anchor_grid']
  32. def __init__(self,
  33. backbone,
  34. fpn,
  35. efficient_head,
  36. anchor_grid,
  37. box_loss_weight=50.):
  38. super(EfficientDet, self).__init__()
  39. self.backbone = backbone
  40. self.fpn = fpn
  41. self.efficient_head = efficient_head
  42. self.anchor_grid = anchor_grid
  43. self.box_loss_weight = box_loss_weight
  44. def build(self, feed_vars, mode='train'):
  45. im = feed_vars['image']
  46. if mode == 'train':
  47. gt_labels = feed_vars['gt_label']
  48. gt_targets = feed_vars['gt_target']
  49. fg_num = feed_vars['fg_num']
  50. else:
  51. im_info = feed_vars['im_info']
  52. mixed_precision_enabled = mixed_precision_global_state() is not None
  53. if mixed_precision_enabled:
  54. im = fluid.layers.cast(im, 'float16')
  55. body_feats = self.backbone(im)
  56. if mixed_precision_enabled:
  57. body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats]
  58. body_feats = self.fpn(body_feats)
  59. # XXX not used for training, but the parameters are needed when
  60. # exporting inference model
  61. anchors = self.anchor_grid()
  62. if mode == 'train':
  63. loss = self.efficient_head.get_loss(body_feats, gt_labels,
  64. gt_targets, fg_num)
  65. loss_cls = loss['loss_cls']
  66. loss_bbox = loss['loss_bbox']
  67. total_loss = loss_cls + self.box_loss_weight * loss_bbox
  68. loss.update({'loss': total_loss})
  69. return loss
  70. else:
  71. pred = self.efficient_head.get_prediction(body_feats, anchors,
  72. im_info)
  73. return pred
  74. def _inputs_def(self, image_shape):
  75. im_shape = [None] + image_shape
  76. inputs_def = {
  77. 'image': {
  78. 'shape': im_shape,
  79. 'dtype': 'float32'
  80. },
  81. 'im_info': {
  82. 'shape': [None, 3],
  83. 'dtype': 'float32'
  84. },
  85. 'im_id': {
  86. 'shape': [None, 1],
  87. 'dtype': 'int64'
  88. },
  89. 'im_shape': {
  90. 'shape': [None, 3],
  91. 'dtype': 'float32'
  92. },
  93. 'fg_num': {
  94. 'shape': [None, 1],
  95. 'dtype': 'int32'
  96. },
  97. 'gt_label': {
  98. 'shape': [None, None, 1],
  99. 'dtype': 'int32'
  100. },
  101. 'gt_target': {
  102. 'shape': [None, None, 4],
  103. 'dtype': 'float32'
  104. },
  105. }
  106. return inputs_def
  107. def build_inputs(self,
  108. image_shape=[3, None, None],
  109. fields=[
  110. 'image', 'im_info', 'im_id', 'fg_num', 'gt_label',
  111. 'gt_target'
  112. ],
  113. use_dataloader=True,
  114. iterable=False):
  115. inputs_def = self._inputs_def(image_shape)
  116. feed_vars = OrderedDict([(key, fluid.data(
  117. name=key,
  118. shape=inputs_def[key]['shape'],
  119. dtype=inputs_def[key]['dtype'])) for key in fields])
  120. loader = fluid.io.DataLoader.from_generator(
  121. feed_list=list(feed_vars.values()),
  122. capacity=16,
  123. use_double_buffer=True,
  124. iterable=iterable) if use_dataloader else None
  125. return feed_vars, loader
  126. def train(self, feed_vars):
  127. return self.build(feed_vars, 'train')
  128. def eval(self, feed_vars):
  129. return self.build(feed_vars, 'test')
  130. def test(self, feed_vars, exclude_nms=False):
  131. assert not exclude_nms, "exclude_nms for {} is not support currently".format(
  132. self.__class__.__name__)
  133. return self.build(feed_vars, 'test')