# yolo.py (4.4 KB)

  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from ppdet.core.workspace import register, create
  18. from .meta_arch import BaseArch
  19. from ..post_process import JDEBBoxPostProcess
  20. __all__ = ['YOLOv3']
  21. @register
  22. class YOLOv3(BaseArch):
  23. __category__ = 'architecture'
  24. __shared__ = ['data_format']
  25. __inject__ = ['post_process']
  26. def __init__(self,
  27. backbone='DarkNet',
  28. neck='YOLOv3FPN',
  29. yolo_head='YOLOv3Head',
  30. post_process='BBoxPostProcess',
  31. data_format='NCHW',
  32. for_mot=False):
  33. """
  34. YOLOv3 network, see https://arxiv.org/abs/1804.02767
  35. Args:
  36. backbone (nn.Layer): backbone instance
  37. neck (nn.Layer): neck instance
  38. yolo_head (nn.Layer): anchor_head instance
  39. bbox_post_process (object): `BBoxPostProcess` instance
  40. data_format (str): data format, NCHW or NHWC
  41. for_mot (bool): whether return other features for multi-object tracking
  42. models, default False in pure object detection models.
  43. """
  44. super(YOLOv3, self).__init__(data_format=data_format)
  45. self.backbone = backbone
  46. self.neck = neck
  47. self.yolo_head = yolo_head
  48. self.post_process = post_process
  49. self.for_mot = for_mot
  50. self.return_idx = isinstance(post_process, JDEBBoxPostProcess)
  51. @classmethod
  52. def from_config(cls, cfg, *args, **kwargs):
  53. # backbone
  54. backbone = create(cfg['backbone'])
  55. # fpn
  56. kwargs = {'input_shape': backbone.out_shape}
  57. neck = create(cfg['neck'], **kwargs)
  58. # head
  59. kwargs = {'input_shape': neck.out_shape}
  60. yolo_head = create(cfg['yolo_head'], **kwargs)
  61. return {
  62. 'backbone': backbone,
  63. 'neck': neck,
  64. "yolo_head": yolo_head,
  65. }
  66. def _forward(self):
  67. body_feats = self.backbone(self.inputs)
  68. neck_feats = self.neck(body_feats, self.for_mot)
  69. if isinstance(neck_feats, dict):
  70. assert self.for_mot == True
  71. emb_feats = neck_feats['emb_feats']
  72. neck_feats = neck_feats['yolo_feats']
  73. if self.training:
  74. yolo_losses = self.yolo_head(neck_feats, self.inputs)
  75. if self.for_mot:
  76. return {'det_losses': yolo_losses, 'emb_feats': emb_feats}
  77. else:
  78. return yolo_losses
  79. else:
  80. yolo_head_outs = self.yolo_head(neck_feats)
  81. if self.for_mot:
  82. boxes_idx, bbox, bbox_num, nms_keep_idx = self.post_process(
  83. yolo_head_outs, self.yolo_head.mask_anchors)
  84. output = {
  85. 'bbox': bbox,
  86. 'bbox_num': bbox_num,
  87. 'boxes_idx': boxes_idx,
  88. 'nms_keep_idx': nms_keep_idx,
  89. 'emb_feats': emb_feats,
  90. }
  91. else:
  92. if self.return_idx:
  93. _, bbox, bbox_num, _ = self.post_process(
  94. yolo_head_outs, self.yolo_head.mask_anchors)
  95. elif self.post_process is not None:
  96. bbox, bbox_num = self.post_process(
  97. yolo_head_outs, self.yolo_head.mask_anchors,
  98. self.inputs['im_shape'], self.inputs['scale_factor'])
  99. else:
  100. bbox, bbox_num = self.yolo_head.post_process(
  101. yolo_head_outs, self.inputs['im_shape'],
  102. self.inputs['scale_factor'])
  103. output = {'bbox': bbox, 'bbox_num': bbox_num}
  104. return output
  105. def get_loss(self):
  106. return self._forward()
  107. def get_pred(self):
  108. return self._forward()