centernet.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from ppdet.core.workspace import register, create
  18. from .meta_arch import BaseArch
  19. __all__ = ['CenterNet']
  20. @register
  21. class CenterNet(BaseArch):
  22. """
  23. CenterNet network, see http://arxiv.org/abs/1904.07850
  24. Args:
  25. backbone (object): backbone instance
  26. neck (object): FPN instance, default use 'CenterNetDLAFPN'
  27. head (object): 'CenterNetHead' instance
  28. post_process (object): 'CenterNetPostProcess' instance
  29. for_mot (bool): whether return other features used in tracking model
  30. """
  31. __category__ = 'architecture'
  32. __inject__ = ['post_process']
  33. __shared__ = ['for_mot']
  34. def __init__(self,
  35. backbone,
  36. neck='CenterNetDLAFPN',
  37. head='CenterNetHead',
  38. post_process='CenterNetPostProcess',
  39. for_mot=False):
  40. super(CenterNet, self).__init__()
  41. self.backbone = backbone
  42. self.neck = neck
  43. self.head = head
  44. self.post_process = post_process
  45. self.for_mot = for_mot
  46. @classmethod
  47. def from_config(cls, cfg, *args, **kwargs):
  48. backbone = create(cfg['backbone'])
  49. kwargs = {'input_shape': backbone.out_shape}
  50. neck = cfg['neck'] and create(cfg['neck'], **kwargs)
  51. out_shape = neck and neck.out_shape or backbone.out_shape
  52. kwargs = {'input_shape': out_shape}
  53. head = create(cfg['head'], **kwargs)
  54. return {'backbone': backbone, 'neck': neck, "head": head}
  55. def _forward(self):
  56. neck_feat = self.backbone(self.inputs)
  57. if self.neck is not None:
  58. neck_feat = self.neck(neck_feat)
  59. head_out = self.head(neck_feat, self.inputs)
  60. if self.for_mot:
  61. head_out.update({'neck_feat': neck_feat})
  62. elif self.training:
  63. head_out['loss'] = head_out.pop('det_loss')
  64. return head_out
  65. def get_pred(self):
  66. head_out = self._forward()
  67. if self.for_mot:
  68. bbox, bbox_inds, topk_clses = self.post_process(
  69. head_out['heatmap'],
  70. head_out['size'],
  71. head_out['offset'],
  72. im_shape=self.inputs['im_shape'],
  73. scale_factor=self.inputs['scale_factor'])
  74. output = {
  75. "bbox": bbox,
  76. "bbox_inds": bbox_inds,
  77. "topk_clses": topk_clses,
  78. "neck_feat": head_out['neck_feat']
  79. }
  80. else:
  81. bbox, bbox_num, _ = self.post_process(
  82. head_out['heatmap'],
  83. head_out['size'],
  84. head_out['offset'],
  85. im_shape=self.inputs['im_shape'],
  86. scale_factor=self.inputs['scale_factor'])
  87. output = {
  88. "bbox": bbox,
  89. "bbox_num": bbox_num,
  90. }
  91. return output
  92. def get_loss(self):
  93. return self._forward()