jde.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from ppdet.core.workspace import register, create
  18. from .meta_arch import BaseArch
  19. __all__ = ['JDE']
  20. @register
  21. class JDE(BaseArch):
  22. __category__ = 'architecture'
  23. __shared__ = ['metric']
  24. """
  25. JDE network, see https://arxiv.org/abs/1909.12605v1
  26. Args:
  27. detector (object): detector model instance
  28. reid (object): reid model instance
  29. tracker (object): tracker instance
  30. metric (str): 'MOTDet' for training and detection evaluation, 'ReID'
  31. for ReID embedding evaluation, or 'MOT' for multi object tracking
  32. evaluation.
  33. """
  34. def __init__(self,
  35. detector='YOLOv3',
  36. reid='JDEEmbeddingHead',
  37. tracker='JDETracker',
  38. metric='MOT'):
  39. super(JDE, self).__init__()
  40. self.detector = detector
  41. self.reid = reid
  42. self.tracker = tracker
  43. self.metric = metric
  44. @classmethod
  45. def from_config(cls, cfg, *args, **kwargs):
  46. detector = create(cfg['detector'])
  47. kwargs = {'input_shape': detector.neck.out_shape}
  48. reid = create(cfg['reid'], **kwargs)
  49. tracker = create(cfg['tracker'])
  50. return {
  51. "detector": detector,
  52. "reid": reid,
  53. "tracker": tracker,
  54. }
  55. def _forward(self):
  56. det_outs = self.detector(self.inputs)
  57. if self.training:
  58. emb_feats = det_outs['emb_feats']
  59. loss_confs = det_outs['det_losses']['loss_confs']
  60. loss_boxes = det_outs['det_losses']['loss_boxes']
  61. jde_losses = self.reid(
  62. emb_feats,
  63. self.inputs,
  64. loss_confs=loss_confs,
  65. loss_boxes=loss_boxes)
  66. return jde_losses
  67. else:
  68. if self.metric == 'MOTDet':
  69. det_results = {
  70. 'bbox': det_outs['bbox'],
  71. 'bbox_num': det_outs['bbox_num'],
  72. }
  73. return det_results
  74. elif self.metric == 'MOT':
  75. emb_feats = det_outs['emb_feats']
  76. bboxes = det_outs['bbox']
  77. boxes_idx = det_outs['boxes_idx']
  78. nms_keep_idx = det_outs['nms_keep_idx']
  79. pred_dets, pred_embs = self.reid(
  80. emb_feats,
  81. self.inputs,
  82. bboxes=bboxes,
  83. boxes_idx=boxes_idx,
  84. nms_keep_idx=nms_keep_idx)
  85. return pred_dets, pred_embs
  86. else:
  87. raise ValueError("Unknown metric {} for multi object tracking.".
  88. format(self.metric))
  89. def get_loss(self):
  90. return self._forward()
  91. def get_pred(self):
  92. return self._forward()