infer_mot.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PaddleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import paddle
  26. from ppdet.core.workspace import load_config, merge_config
  27. from ppdet.engine import Tracker
  28. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
  29. from ppdet.utils.cli import ArgsParser
  30. def parse_args():
  31. parser = ArgsParser()
  32. parser.add_argument(
  33. '--video_file', type=str, default=None, help='Video name for tracking.')
  34. parser.add_argument(
  35. '--frame_rate',
  36. type=int,
  37. default=-1,
  38. help='Video frame rate for tracking.')
  39. parser.add_argument(
  40. "--image_dir",
  41. type=str,
  42. default=None,
  43. help="Directory for images to perform inference on.")
  44. parser.add_argument(
  45. "--det_results_dir",
  46. type=str,
  47. default='',
  48. help="Directory name for detection results.")
  49. parser.add_argument(
  50. '--output_dir',
  51. type=str,
  52. default='output',
  53. help='Directory name for output tracking results.')
  54. parser.add_argument(
  55. '--save_images',
  56. action='store_true',
  57. help='Save tracking results (image).')
  58. parser.add_argument(
  59. '--save_videos',
  60. action='store_true',
  61. help='Save tracking results (video).')
  62. parser.add_argument(
  63. '--show_image',
  64. action='store_true',
  65. help='Show tracking results (image).')
  66. parser.add_argument(
  67. '--scaled',
  68. type=bool,
  69. default=False,
  70. help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
  71. "True in general detector.")
  72. parser.add_argument(
  73. "--draw_threshold",
  74. type=float,
  75. default=0.5,
  76. help="Threshold to reserve the result for visualization.")
  77. args = parser.parse_args()
  78. return args
  79. def run(FLAGS, cfg):
  80. # build Tracker
  81. tracker = Tracker(cfg, mode='test')
  82. # load weights
  83. if cfg.architecture in ['DeepSORT', 'ByteTrack']:
  84. tracker.load_weights_sde(cfg.det_weights, cfg.reid_weights)
  85. else:
  86. tracker.load_weights_jde(cfg.weights)
  87. # inference
  88. tracker.mot_predict_seq(
  89. video_file=FLAGS.video_file,
  90. frame_rate=FLAGS.frame_rate,
  91. image_dir=FLAGS.image_dir,
  92. data_type=cfg.metric.lower(),
  93. model_type=cfg.architecture,
  94. output_dir=FLAGS.output_dir,
  95. save_images=FLAGS.save_images,
  96. save_videos=FLAGS.save_videos,
  97. show_image=FLAGS.show_image,
  98. scaled=FLAGS.scaled,
  99. det_results_dir=FLAGS.det_results_dir,
  100. draw_threshold=FLAGS.draw_threshold)
  101. def main():
  102. FLAGS = parse_args()
  103. cfg = load_config(FLAGS.config)
  104. merge_config(FLAGS.opt)
  105. # disable npu in config by default
  106. if 'use_npu' not in cfg:
  107. cfg.use_npu = False
  108. # disable xpu in config by default
  109. if 'use_xpu' not in cfg:
  110. cfg.use_xpu = False
  111. if cfg.use_gpu:
  112. place = paddle.set_device('gpu')
  113. elif cfg.use_npu:
  114. place = paddle.set_device('npu')
  115. elif cfg.use_xpu:
  116. place = paddle.set_device('xpu')
  117. else:
  118. place = paddle.set_device('cpu')
  119. if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
  120. cfg['norm_type'] = 'bn'
  121. check_config(cfg)
  122. check_gpu(cfg.use_gpu)
  123. check_npu(cfg.use_npu)
  124. check_xpu(cfg.use_xpu)
  125. check_version()
  126. run(FLAGS, cfg)
  127. if __name__ == '__main__':
  128. main()