eval_mot.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PaddleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import paddle
  26. from ppdet.core.workspace import load_config, merge_config
  27. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
  28. from ppdet.utils.cli import ArgsParser
  29. from ppdet.engine import Tracker
  30. def parse_args():
  31. parser = ArgsParser()
  32. parser.add_argument(
  33. "--det_results_dir",
  34. type=str,
  35. default='',
  36. help="Directory name for detection results.")
  37. parser.add_argument(
  38. '--output_dir',
  39. type=str,
  40. default='output',
  41. help='Directory name for output tracking results.')
  42. parser.add_argument(
  43. '--save_images',
  44. action='store_true',
  45. help='Save tracking results (image).')
  46. parser.add_argument(
  47. '--save_videos',
  48. action='store_true',
  49. help='Save tracking results (video).')
  50. parser.add_argument(
  51. '--show_image',
  52. action='store_true',
  53. help='Show tracking results (image).')
  54. parser.add_argument(
  55. '--scaled',
  56. type=bool,
  57. default=False,
  58. help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
  59. "True in general detector.")
  60. args = parser.parse_args()
  61. return args
  62. def run(FLAGS, cfg):
  63. dataset_dir = cfg['EvalMOTDataset'].dataset_dir
  64. data_root = cfg['EvalMOTDataset'].data_root
  65. data_root = '{}/{}'.format(dataset_dir, data_root)
  66. seqs = os.listdir(data_root)
  67. seqs.sort()
  68. # build Tracker
  69. tracker = Tracker(cfg, mode='eval')
  70. # load weights
  71. if cfg.architecture in ['DeepSORT', 'ByteTrack']:
  72. tracker.load_weights_sde(cfg.det_weights, cfg.reid_weights)
  73. else:
  74. tracker.load_weights_jde(cfg.weights)
  75. # inference
  76. tracker.mot_evaluate(
  77. data_root=data_root,
  78. seqs=seqs,
  79. data_type=cfg.metric.lower(),
  80. model_type=cfg.architecture,
  81. output_dir=FLAGS.output_dir,
  82. save_images=FLAGS.save_images,
  83. save_videos=FLAGS.save_videos,
  84. show_image=FLAGS.show_image,
  85. scaled=FLAGS.scaled,
  86. det_results_dir=FLAGS.det_results_dir)
  87. def main():
  88. FLAGS = parse_args()
  89. cfg = load_config(FLAGS.config)
  90. merge_config(FLAGS.opt)
  91. # disable npu in config by default
  92. if 'use_npu' not in cfg:
  93. cfg.use_npu = False
  94. # disable xpu in config by default
  95. if 'use_xpu' not in cfg:
  96. cfg.use_xpu = False
  97. if cfg.use_gpu:
  98. place = paddle.set_device('gpu')
  99. elif cfg.use_npu:
  100. place = paddle.set_device('npu')
  101. elif cfg.use_xpu:
  102. place = paddle.set_device('xpu')
  103. else:
  104. place = paddle.set_device('cpu')
  105. if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
  106. cfg['norm_type'] = 'bn'
  107. check_config(cfg)
  108. check_gpu(cfg.use_gpu)
  109. check_npu(cfg.use_npu)
  110. check_xpu(cfg.use_xpu)
  111. check_version()
  112. run(FLAGS, cfg)
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()