# mot_keypoint_unite_infer.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import cv2
import math
import numpy as np
import paddle
import yaml
import copy
from collections import defaultdict

from mot_keypoint_unite_utils import argsparser
from preprocess import decode_image
from infer import print_arguments, get_test_images, bench_log
from mot_sde_infer import SDE_Detector
from mot_jde_infer import JDE_Detector, MOT_JDE_SUPPORT_MODELS
from keypoint_infer import KeyPointDetector, KEYPOINT_SUPPORT_MODELS
from det_keypoint_unite_infer import predict_with_given_det
from visualize import visualize_pose
from benchmark_utils import PaddleInferBenchmark
from utils import get_current_memory_mb
from keypoint_postprocess import translate_to_ori_images

# add the grandparent directory to the Python path so the pptracking
# package can be imported when this script is run directly
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
from pptracking.python.mot.utils import MOTTimer as FPSTimer


def convert_mot_to_det(tlwhs, scores):
    """Convert MOT results (a list of tlwh boxes plus their scores) into the
    detector-style results dict expected by predict_with_given_det."""
    results = {}
    num_mot = len(tlwhs)
    xyxys = copy.deepcopy(tlwhs)
    for xyxy in xyxys.copy():
        # convert [x, y, w, h] to [x1, y1, x2, y2] in place
        xyxy[2:] = xyxy[2:] + xyxy[:2]
    # only a single class is supported now; the class id is fixed to 0
    results['boxes'] = np.vstack(
        [np.hstack([0, scores[i], xyxys[i]]) for i in range(num_mot)])
    results['boxes_num'] = np.array([num_mot])
    return results
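
# A minimal usage sketch for convert_mot_to_det (illustrative values, not
# from a real tracker run):
#
#   tlwhs = [np.array([10., 20., 30., 40.])]  # one box as [x, y, w, h]
#   scores = [0.9]
#   det = convert_mot_to_det(tlwhs, scores)
#   det['boxes']      # [[0., 0.9, 10., 20., 40., 60.]] as [class, score, x1, y1, x2, y2]
#   det['boxes_num']  # [1]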


def mot_topdown_unite_predict(mot_detector,
                              topdown_keypoint_detector,
                              image_list,
                              keypoint_batch_size=1,
                              save_res=False):
    det_timer = mot_detector.get_timer()
    store_res = []
    image_list.sort()
    num_classes = mot_detector.num_classes
    for i, img_file in enumerate(image_list):
        # decode the image once so MOT and pose prediction share it
        det_timer.preprocess_time_s.start()
        image, _ = decode_image(img_file, {})
        det_timer.preprocess_time_s.end()

        if FLAGS.run_benchmark:
            mot_results = mot_detector.predict_image(
                [image], run_benchmark=True, repeats=10)

            cm, gm, gu = get_current_memory_mb()
            mot_detector.cpu_mem += cm
            mot_detector.gpu_mem += gm
            mot_detector.gpu_util += gu
        else:
            mot_results = mot_detector.predict_image([image], visual=False)

        online_tlwhs, online_scores, online_ids = mot_results[
            0]  # only bs=1 is supported in the MOT model
        results = convert_mot_to_det(
            online_tlwhs[0],
            online_scores[0])  # only a single class is supported for MOT + pose
        if results['boxes_num'] == 0:
            continue

        keypoint_res = predict_with_given_det(
            image, results, topdown_keypoint_detector, keypoint_batch_size,
            FLAGS.run_benchmark)

        if save_res:
            save_name = img_file if isinstance(img_file, str) else i
            store_res.append([
                save_name, keypoint_res['bbox'],
                [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
            ])

        if FLAGS.run_benchmark:
            cm, gm, gu = get_current_memory_mb()
            topdown_keypoint_detector.cpu_mem += cm
            topdown_keypoint_detector.gpu_mem += gm
            topdown_keypoint_detector.gpu_util += gu
        else:
            if not os.path.exists(FLAGS.output_dir):
                os.makedirs(FLAGS.output_dir)
            visualize_pose(
                img_file,
                keypoint_res,
                visual_thresh=FLAGS.keypoint_threshold,
                save_dir=FLAGS.output_dir)

    if save_res:
        # store_res: a list of image_data entries
        #   image_data: [imageid, rects, [keypoints, scores]]
        #   rects: list of rect [xmin, ymin, xmax, ymax]
        #   keypoints: 17 joints * [x, y, conf], 51 values in total per person
        #   scores: mean of all joint confs
        with open("det_keypoint_unite_image_results.json", 'w') as wf:
            json.dump(store_res, wf, indent=4)
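
# An illustrative entry of det_keypoint_unite_image_results.json (the file
# name and values below are made up, and the layout is schematic):
#
#   ["demo.jpg",
#    [[xmin, ymin, xmax, ymax], ...],                        # one rect per person
#    [[[[x1, y1, conf1], ..., [x17, y17, conf17]], ...],     # 17 joints per person
#     [mean_conf, ...]]]                                     # mean joint conf per person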


def mot_topdown_unite_predict_video(mot_detector,
                                    topdown_keypoint_detector,
                                    camera_id,
                                    keypoint_batch_size=1,
                                    save_res=False):
    video_name = 'output.mp4'
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(FLAGS.video_file)
        video_name = os.path.split(FLAGS.video_file)[-1]

    # get video info: resolution, fps, frame count
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(capture.get(cv2.CAP_PROP_FPS))
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print("fps: %d, frame_count: %d" % (fps, frame_count))

    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    out_path = os.path.join(FLAGS.output_dir, video_name)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    frame_id = 0
    timer_mot, timer_kp, timer_mot_kp = FPSTimer(), FPSTimer(), FPSTimer()

    num_classes = mot_detector.num_classes
    assert num_classes == 1, 'Only single-class MOT models are supported for united keypoint deploy.'
    data_type = 'mot'

    while True:
        ret, frame = capture.read()
        if not ret:
            break
        if frame_id % 10 == 0:
            print('Tracking frame: %d' % frame_id)
        frame_id += 1
        timer_mot_kp.tic()

        # MOT model
        timer_mot.tic()
        frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        mot_results = mot_detector.predict_image([frame2], visual=False)
        timer_mot.toc()

        online_tlwhs, online_scores, online_ids = mot_results[0]
        results = convert_mot_to_det(
            online_tlwhs[0],
            online_scores[0])  # only a single class is supported for MOT + pose
        if results['boxes_num'] == 0:
            continue

        # keypoint model
        timer_kp.tic()
        keypoint_res = predict_with_given_det(
            frame2, results, topdown_keypoint_detector, keypoint_batch_size,
            FLAGS.run_benchmark)
        timer_kp.toc()
        timer_mot_kp.toc()

        kp_fps = 1. / timer_kp.duration
        mot_kp_fps = 1. / timer_mot_kp.duration

        im = visualize_pose(
            frame,
            keypoint_res,
            visual_thresh=FLAGS.keypoint_threshold,
            returnimg=True,
            ids=online_ids[0])
        im = plot_tracking_dict(
            im,
            num_classes,
            online_tlwhs,
            online_ids,
            online_scores,
            frame_id=frame_id,
            fps=mot_kp_fps)

        writer.write(im)
        if camera_id != -1:
            cv2.imshow('Tracking and keypoint results', im)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    writer.release()
    print('output_video saved to: {}'.format(out_path))
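
# Note: the FPS overlaid by plot_tracking_dict is the combined MOT + keypoint
# throughput of the last frame (timer_mot_kp); timer_mot and timer_kp time the
# two stages separately, and kp_fps is computed but not displayed.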


def main():
    deploy_file = os.path.join(FLAGS.mot_model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    # select the tracker class (JDE vs. SDE) based on the exported model arch
    mot_detector_func = 'SDE_Detector'
    if arch in MOT_JDE_SUPPORT_MODELS:
        mot_detector_func = 'JDE_Detector'

    mot_detector = eval(mot_detector_func)(FLAGS.mot_model_dir,
                                           FLAGS.tracker_config,
                                           device=FLAGS.device,
                                           run_mode=FLAGS.run_mode,
                                           batch_size=1,
                                           trt_min_shape=FLAGS.trt_min_shape,
                                           trt_max_shape=FLAGS.trt_max_shape,
                                           trt_opt_shape=FLAGS.trt_opt_shape,
                                           trt_calib_mode=FLAGS.trt_calib_mode,
                                           cpu_threads=FLAGS.cpu_threads,
                                           enable_mkldnn=FLAGS.enable_mkldnn,
                                           threshold=FLAGS.mot_threshold,
                                           output_dir=FLAGS.output_dir)

    topdown_keypoint_detector = KeyPointDetector(
        FLAGS.keypoint_model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.keypoint_batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        threshold=FLAGS.keypoint_threshold,
        output_dir=FLAGS.output_dir,
        use_dark=FLAGS.use_dark)
    keypoint_arch = topdown_keypoint_detector.pred_config.arch
    assert KEYPOINT_SUPPORT_MODELS[
        keypoint_arch] == 'keypoint_topdown', 'MOT-Keypoint unite inference only supports topdown models.'

    # predict from a video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        mot_topdown_unite_predict_video(
            mot_detector, topdown_keypoint_detector, FLAGS.camera_id,
            FLAGS.keypoint_batch_size, FLAGS.save_res)
    else:
        # predict from images
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        mot_topdown_unite_predict(mot_detector, topdown_keypoint_detector,
                                  img_list, FLAGS.keypoint_batch_size,
                                  FLAGS.save_res)
        if not FLAGS.run_benchmark:
            mot_detector.det_times.info(average=True)
            topdown_keypoint_detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            mot_model_dir = FLAGS.mot_model_dir
            mot_model_info = {
                'model_name': mot_model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(mot_detector, img_list, mot_model_info, name='MOT')

            keypoint_model_dir = FLAGS.keypoint_model_dir
            keypoint_model_info = {
                'model_name': keypoint_model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(topdown_keypoint_detector, img_list, keypoint_model_info,
                      FLAGS.keypoint_batch_size, 'KeyPoint')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], "device should be CPU, GPU or XPU"

    main()
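
# Example invocation (a sketch; the model directory names below are
# placeholders for your own exported inference models):
#
#   python mot_keypoint_unite_infer.py \
#       --mot_model_dir=output_inference/fairmot_dla34_30e_1088x608/ \
#       --keypoint_model_dir=output_inference/dark_hrnet_w32_256x192/ \
#       --video_file=test_video.mp4 \
#       --device=GPU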