det_keypoint_unite_infer.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import cv2
  17. import math
  18. import numpy as np
  19. import paddle
  20. import yaml
  21. from det_keypoint_unite_utils import argsparser
  22. from preprocess import decode_image
  23. from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
  24. from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
  25. from visualize import visualize_pose
  26. from benchmark_utils import PaddleInferBenchmark
  27. from dependence.PaddleDetection.deploy.python.utils import get_current_memory_mb
  28. from keypoint_postprocess import translate_to_ori_images
  29. KEYPOINT_SUPPORT_MODELS = {
  30. 'HigherHRNet': 'keypoint_bottomup',
  31. 'HRNet': 'keypoint_topdown'
  32. }
  33. def predict_with_given_det(image, det_res, keypoint_detector,
  34. keypoint_batch_size, run_benchmark):
  35. rec_images, records, det_rects = keypoint_detector.get_person_from_rect(
  36. image, det_res)
  37. keypoint_vector = []
  38. score_vector = []
  39. rect_vector = det_rects
  40. keypoint_results = keypoint_detector.predict_image(
  41. rec_images, run_benchmark, repeats=10, visual=False)
  42. keypoint_vector, score_vector = translate_to_ori_images(keypoint_results,
  43. np.array(records))
  44. keypoint_res = {}
  45. keypoint_res['keypoint'] = [
  46. keypoint_vector.tolist(), score_vector.tolist()
  47. ] if len(keypoint_vector) > 0 else [[], []]
  48. keypoint_res['bbox'] = rect_vector
  49. return keypoint_res
  50. def topdown_unite_predict(detector,
  51. topdown_keypoint_detector,
  52. image_list,
  53. keypoint_batch_size=1,
  54. save_res=False):
  55. det_timer = detector.get_timer()
  56. store_res = []
  57. for i, img_file in enumerate(image_list):
  58. # Decode image in advance in det + pose prediction
  59. det_timer.preprocess_time_s.start()
  60. image, _ = decode_image(img_file, {})
  61. det_timer.preprocess_time_s.end()
  62. if FLAGS.run_benchmark:
  63. results = detector.predict_image(
  64. [image], run_benchmark=True, repeats=10)
  65. cm, gm, gu = get_current_memory_mb()
  66. detector.cpu_mem += cm
  67. detector.gpu_mem += gm
  68. detector.gpu_util += gu
  69. else:
  70. results = detector.predict_image([image], visual=False)
  71. results = detector.filter_box(results, FLAGS.det_threshold)
  72. if results['boxes_num'] > 0:
  73. keypoint_res = predict_with_given_det(
  74. image, results, topdown_keypoint_detector, keypoint_batch_size,
  75. FLAGS.run_benchmark)
  76. if save_res:
  77. save_name = img_file if isinstance(img_file, str) else i
  78. store_res.append([
  79. save_name, keypoint_res['bbox'],
  80. [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
  81. ])
  82. else:
  83. results["keypoint"] = [[], []]
  84. keypoint_res = results
  85. if FLAGS.run_benchmark:
  86. cm, gm, gu = get_current_memory_mb()
  87. topdown_keypoint_detector.cpu_mem += cm
  88. topdown_keypoint_detector.gpu_mem += gm
  89. topdown_keypoint_detector.gpu_util += gu
  90. else:
  91. if not os.path.exists(FLAGS.output_dir):
  92. os.makedirs(FLAGS.output_dir)
  93. visualize_pose(
  94. img_file,
  95. keypoint_res,
  96. visual_thresh=FLAGS.keypoint_threshold,
  97. save_dir=FLAGS.output_dir)
  98. if save_res:
  99. """
  100. 1) store_res: a list of image_data
  101. 2) image_data: [imageid, rects, [keypoints, scores]]
  102. 3) rects: list of rect [xmin, ymin, xmax, ymax]
  103. 4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
  104. 5) scores: mean of all joint conf
  105. """
  106. with open("det_keypoint_unite_image_results.json", 'w') as wf:
  107. json.dump(store_res, wf, indent=4)
  108. def topdown_unite_predict_video(detector,
  109. topdown_keypoint_detector,
  110. camera_id,
  111. keypoint_batch_size=1,
  112. save_res=False):
  113. video_name = 'output.mp4'
  114. if camera_id != -1:
  115. capture = cv2.VideoCapture(camera_id)
  116. else:
  117. capture = cv2.VideoCapture(FLAGS.video_file)
  118. video_name = os.path.split(FLAGS.video_file)[-1]
  119. # Get Video info : resolution, fps, frame count
  120. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  121. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  122. fps = int(capture.get(cv2.CAP_PROP_FPS))
  123. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  124. print("fps: %d, frame_count: %d" % (fps, frame_count))
  125. if not os.path.exists(FLAGS.output_dir):
  126. os.makedirs(FLAGS.output_dir)
  127. out_path = os.path.join(FLAGS.output_dir, video_name)
  128. fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
  129. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  130. index = 0
  131. store_res = []
  132. while (1):
  133. ret, frame = capture.read()
  134. if not ret:
  135. break
  136. index += 1
  137. print('detect frame: %d' % (index))
  138. frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  139. results = detector.predict_image([frame2], visual=False)
  140. results = detector.filter_box(results, FLAGS.det_threshold)
  141. if results['boxes_num'] == 0:
  142. writer.write(frame)
  143. continue
  144. keypoint_res = predict_with_given_det(
  145. frame2, results, topdown_keypoint_detector, keypoint_batch_size,
  146. FLAGS.run_benchmark)
  147. im = visualize_pose(
  148. frame,
  149. keypoint_res,
  150. visual_thresh=FLAGS.keypoint_threshold,
  151. returnimg=True)
  152. if save_res:
  153. store_res.append([
  154. index, keypoint_res['bbox'],
  155. [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
  156. ])
  157. writer.write(im)
  158. if camera_id != -1:
  159. cv2.imshow('Mask Detection', im)
  160. if cv2.waitKey(1) & 0xFF == ord('q'):
  161. break
  162. writer.release()
  163. print('output_video saved to: {}'.format(out_path))
  164. if save_res:
  165. """
  166. 1) store_res: a list of frame_data
  167. 2) frame_data: [frameid, rects, [keypoints, scores]]
  168. 3) rects: list of rect [xmin, ymin, xmax, ymax]
  169. 4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
  170. 5) scores: mean of all joint conf
  171. """
  172. with open("det_keypoint_unite_video_results.json", 'w') as wf:
  173. json.dump(store_res, wf, indent=4)
  174. def main():
  175. deploy_file = os.path.join(FLAGS.det_model_dir, 'infer_cfg.yml')
  176. with open(deploy_file) as f:
  177. yml_conf = yaml.safe_load(f)
  178. arch = yml_conf['arch']
  179. detector_func = 'Detector'
  180. if arch == 'PicoDet':
  181. detector_func = 'DetectorPicoDet'
  182. detector = eval(detector_func)(FLAGS.det_model_dir,
  183. device=FLAGS.device,
  184. run_mode=FLAGS.run_mode,
  185. trt_min_shape=FLAGS.trt_min_shape,
  186. trt_max_shape=FLAGS.trt_max_shape,
  187. trt_opt_shape=FLAGS.trt_opt_shape,
  188. trt_calib_mode=FLAGS.trt_calib_mode,
  189. cpu_threads=FLAGS.cpu_threads,
  190. enable_mkldnn=FLAGS.enable_mkldnn,
  191. threshold=FLAGS.det_threshold)
  192. topdown_keypoint_detector = KeyPointDetector(
  193. FLAGS.keypoint_model_dir,
  194. device=FLAGS.device,
  195. run_mode=FLAGS.run_mode,
  196. batch_size=FLAGS.keypoint_batch_size,
  197. trt_min_shape=FLAGS.trt_min_shape,
  198. trt_max_shape=FLAGS.trt_max_shape,
  199. trt_opt_shape=FLAGS.trt_opt_shape,
  200. trt_calib_mode=FLAGS.trt_calib_mode,
  201. cpu_threads=FLAGS.cpu_threads,
  202. enable_mkldnn=FLAGS.enable_mkldnn,
  203. use_dark=FLAGS.use_dark)
  204. keypoint_arch = topdown_keypoint_detector.pred_config.arch
  205. assert KEYPOINT_SUPPORT_MODELS[
  206. keypoint_arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
  207. # predict from video file or camera video stream
  208. if FLAGS.video_file is not None or FLAGS.camera_id != -1:
  209. topdown_unite_predict_video(detector, topdown_keypoint_detector,
  210. FLAGS.camera_id, FLAGS.keypoint_batch_size,
  211. FLAGS.save_res)
  212. else:
  213. # predict from image
  214. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  215. topdown_unite_predict(detector, topdown_keypoint_detector, img_list,
  216. FLAGS.keypoint_batch_size, FLAGS.save_res)
  217. if not FLAGS.run_benchmark:
  218. detector.det_times.info(average=True)
  219. topdown_keypoint_detector.det_times.info(average=True)
  220. else:
  221. mode = FLAGS.run_mode
  222. det_model_dir = FLAGS.det_model_dir
  223. det_model_info = {
  224. 'model_name': det_model_dir.strip('/').split('/')[-1],
  225. 'precision': mode.split('_')[-1]
  226. }
  227. bench_log(detector, img_list, det_model_info, name='Det')
  228. keypoint_model_dir = FLAGS.keypoint_model_dir
  229. keypoint_model_info = {
  230. 'model_name': keypoint_model_dir.strip('/').split('/')[-1],
  231. 'precision': mode.split('_')[-1]
  232. }
  233. bench_log(topdown_keypoint_detector, img_list, keypoint_model_info,
  234. FLAGS.keypoint_batch_size, 'KeyPoint')
  235. if __name__ == '__main__':
  236. paddle.enable_static()
  237. parser = argsparser()
  238. FLAGS = parser.parse_args()
  239. print_arguments(FLAGS)
  240. FLAGS.device = FLAGS.device.upper()
  241. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  242. ], "device should be CPU, GPU or XPU"
  243. main()