mot_jde_infer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import yaml
  17. import cv2
  18. import numpy as np
  19. from collections import defaultdict
  20. import paddle
  21. from benchmark_utils import PaddleInferBenchmark
  22. from preprocess import decode_image
  23. from utils import argsparser, Timer, get_current_memory_mb
  24. from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
  25. # add python path
  26. import sys
  27. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  28. sys.path.insert(0, parent_path)
  29. from pptracking.python.mot import JDETracker
  30. from pptracking.python.mot.utils import MOTTimer, write_mot_results
  31. from pptracking.python.mot.visualize import plot_tracking_dict
  32. # Global dictionary
  33. MOT_JDE_SUPPORT_MODELS = {
  34. 'JDE',
  35. 'FairMOT',
  36. }
  37. class JDE_Detector(Detector):
  38. """
  39. Args:
  40. model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
  41. device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
  42. run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
  43. batch_size (int): size of pre batch in inference
  44. trt_min_shape (int): min shape for dynamic shape in trt
  45. trt_max_shape (int): max shape for dynamic shape in trt
  46. trt_opt_shape (int): opt shape for dynamic shape in trt
  47. trt_calib_mode (bool): If the model is produced by TRT offline quantitative
  48. calibration, trt_calib_mode need to set True
  49. cpu_threads (int): cpu threads
  50. enable_mkldnn (bool): whether to open MKLDNN
  51. output_dir (string): The path of output, default as 'output'
  52. threshold (float): Score threshold of the detected bbox, default as 0.5
  53. save_images (bool): Whether to save visualization image results, default as False
  54. save_mot_txts (bool): Whether to save tracking results (txt), default as False
  55. """
  56. def __init__(
  57. self,
  58. model_dir,
  59. tracker_config=None,
  60. device='CPU',
  61. run_mode='paddle',
  62. batch_size=1,
  63. trt_min_shape=1,
  64. trt_max_shape=1088,
  65. trt_opt_shape=608,
  66. trt_calib_mode=False,
  67. cpu_threads=1,
  68. enable_mkldnn=False,
  69. output_dir='output',
  70. threshold=0.5,
  71. save_images=False,
  72. save_mot_txts=False, ):
  73. super(JDE_Detector, self).__init__(
  74. model_dir=model_dir,
  75. device=device,
  76. run_mode=run_mode,
  77. batch_size=batch_size,
  78. trt_min_shape=trt_min_shape,
  79. trt_max_shape=trt_max_shape,
  80. trt_opt_shape=trt_opt_shape,
  81. trt_calib_mode=trt_calib_mode,
  82. cpu_threads=cpu_threads,
  83. enable_mkldnn=enable_mkldnn,
  84. output_dir=output_dir,
  85. threshold=threshold, )
  86. self.save_images = save_images
  87. self.save_mot_txts = save_mot_txts
  88. assert batch_size == 1, "MOT model only supports batch_size=1."
  89. self.det_times = Timer(with_tracker=True)
  90. self.num_classes = len(self.pred_config.labels)
  91. # tracker config
  92. assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
  93. cfg = self.pred_config.tracker
  94. min_box_area = cfg.get('min_box_area', 0.0)
  95. vertical_ratio = cfg.get('vertical_ratio', 0.0)
  96. conf_thres = cfg.get('conf_thres', 0.0)
  97. tracked_thresh = cfg.get('tracked_thresh', 0.7)
  98. metric_type = cfg.get('metric_type', 'euclidean')
  99. self.tracker = JDETracker(
  100. num_classes=self.num_classes,
  101. min_box_area=min_box_area,
  102. vertical_ratio=vertical_ratio,
  103. conf_thres=conf_thres,
  104. tracked_thresh=tracked_thresh,
  105. metric_type=metric_type)
  106. def postprocess(self, inputs, result):
  107. # postprocess output of predictor
  108. np_boxes = result['pred_dets']
  109. if np_boxes.shape[0] <= 0:
  110. print('[WARNNING] No object detected.')
  111. result = {'pred_dets': np.zeros([0, 6]), 'pred_embs': None}
  112. result = {k: v for k, v in result.items() if v is not None}
  113. return result
  114. def tracking(self, det_results):
  115. pred_dets = det_results['pred_dets'] # cls_id, score, x0, y0, x1, y1
  116. pred_embs = det_results['pred_embs']
  117. online_targets_dict = self.tracker.update(pred_dets, pred_embs)
  118. online_tlwhs = defaultdict(list)
  119. online_scores = defaultdict(list)
  120. online_ids = defaultdict(list)
  121. for cls_id in range(self.num_classes):
  122. online_targets = online_targets_dict[cls_id]
  123. for t in online_targets:
  124. tlwh = t.tlwh
  125. tid = t.track_id
  126. tscore = t.score
  127. if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
  128. if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
  129. 3] > self.tracker.vertical_ratio:
  130. continue
  131. online_tlwhs[cls_id].append(tlwh)
  132. online_ids[cls_id].append(tid)
  133. online_scores[cls_id].append(tscore)
  134. return online_tlwhs, online_scores, online_ids
  135. def predict(self, repeats=1):
  136. '''
  137. Args:
  138. repeats (int): repeats number for prediction
  139. Returns:
  140. result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
  141. matix element:[class, score, x_min, y_min, x_max, y_max]
  142. FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
  143. shape: [N, 128]
  144. '''
  145. # model prediction
  146. np_pred_dets, np_pred_embs = None, None
  147. for i in range(repeats):
  148. self.predictor.run()
  149. output_names = self.predictor.get_output_names()
  150. boxes_tensor = self.predictor.get_output_handle(output_names[0])
  151. np_pred_dets = boxes_tensor.copy_to_cpu()
  152. embs_tensor = self.predictor.get_output_handle(output_names[1])
  153. np_pred_embs = embs_tensor.copy_to_cpu()
  154. result = dict(pred_dets=np_pred_dets, pred_embs=np_pred_embs)
  155. return result
  156. def predict_image(self,
  157. image_list,
  158. run_benchmark=False,
  159. repeats=1,
  160. visual=True,
  161. seq_name=None):
  162. mot_results = []
  163. num_classes = self.num_classes
  164. image_list.sort()
  165. ids2names = self.pred_config.labels
  166. data_type = 'mcmot' if num_classes > 1 else 'mot'
  167. for frame_id, img_file in enumerate(image_list):
  168. batch_image_list = [img_file] # bs=1 in MOT model
  169. if run_benchmark:
  170. # preprocess
  171. inputs = self.preprocess(batch_image_list) # warmup
  172. self.det_times.preprocess_time_s.start()
  173. inputs = self.preprocess(batch_image_list)
  174. self.det_times.preprocess_time_s.end()
  175. # model prediction
  176. result_warmup = self.predict(repeats=repeats) # warmup
  177. self.det_times.inference_time_s.start()
  178. result = self.predict(repeats=repeats)
  179. self.det_times.inference_time_s.end(repeats=repeats)
  180. # postprocess
  181. result_warmup = self.postprocess(inputs, result) # warmup
  182. self.det_times.postprocess_time_s.start()
  183. det_result = self.postprocess(inputs, result)
  184. self.det_times.postprocess_time_s.end()
  185. # tracking
  186. result_warmup = self.tracking(det_result)
  187. self.det_times.tracking_time_s.start()
  188. online_tlwhs, online_scores, online_ids = self.tracking(
  189. det_result)
  190. self.det_times.tracking_time_s.end()
  191. self.det_times.img_num += 1
  192. cm, gm, gu = get_current_memory_mb()
  193. self.cpu_mem += cm
  194. self.gpu_mem += gm
  195. self.gpu_util += gu
  196. else:
  197. self.det_times.preprocess_time_s.start()
  198. inputs = self.preprocess(batch_image_list)
  199. self.det_times.preprocess_time_s.end()
  200. self.det_times.inference_time_s.start()
  201. result = self.predict()
  202. self.det_times.inference_time_s.end()
  203. self.det_times.postprocess_time_s.start()
  204. det_result = self.postprocess(inputs, result)
  205. self.det_times.postprocess_time_s.end()
  206. # tracking process
  207. self.det_times.tracking_time_s.start()
  208. online_tlwhs, online_scores, online_ids = self.tracking(
  209. det_result)
  210. self.det_times.tracking_time_s.end()
  211. self.det_times.img_num += 1
  212. if visual:
  213. if len(image_list) > 1 and frame_id % 10 == 0:
  214. print('Tracking frame {}'.format(frame_id))
  215. frame, _ = decode_image(img_file, {})
  216. im = plot_tracking_dict(
  217. frame,
  218. num_classes,
  219. online_tlwhs,
  220. online_ids,
  221. online_scores,
  222. frame_id=frame_id,
  223. ids2names=ids2names)
  224. if seq_name is None:
  225. seq_name = image_list[0].split('/')[-2]
  226. save_dir = os.path.join(self.output_dir, seq_name)
  227. if not os.path.exists(save_dir):
  228. os.makedirs(save_dir)
  229. cv2.imwrite(
  230. os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
  231. mot_results.append([online_tlwhs, online_scores, online_ids])
  232. return mot_results
  233. def predict_video(self, video_file, camera_id):
  234. video_out_name = 'mot_output.mp4'
  235. if camera_id != -1:
  236. capture = cv2.VideoCapture(camera_id)
  237. else:
  238. capture = cv2.VideoCapture(video_file)
  239. video_out_name = os.path.split(video_file)[-1]
  240. # Get Video info : resolution, fps, frame count
  241. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  242. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  243. fps = int(capture.get(cv2.CAP_PROP_FPS))
  244. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  245. print("fps: %d, frame_count: %d" % (fps, frame_count))
  246. if not os.path.exists(self.output_dir):
  247. os.makedirs(self.output_dir)
  248. out_path = os.path.join(self.output_dir, video_out_name)
  249. video_format = 'mp4v'
  250. fourcc = cv2.VideoWriter_fourcc(*video_format)
  251. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  252. frame_id = 1
  253. timer = MOTTimer()
  254. results = defaultdict(list) # support single class and multi classes
  255. num_classes = self.num_classes
  256. data_type = 'mcmot' if num_classes > 1 else 'mot'
  257. ids2names = self.pred_config.labels
  258. while (1):
  259. ret, frame = capture.read()
  260. if not ret:
  261. break
  262. if frame_id % 10 == 0:
  263. print('Tracking frame: %d' % (frame_id))
  264. frame_id += 1
  265. timer.tic()
  266. seq_name = video_out_name.split('.')[0]
  267. mot_results = self.predict_image(
  268. [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
  269. timer.toc()
  270. online_tlwhs, online_scores, online_ids = mot_results[0]
  271. for cls_id in range(num_classes):
  272. results[cls_id].append(
  273. (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
  274. online_ids[cls_id]))
  275. fps = 1. / timer.duration
  276. im = plot_tracking_dict(
  277. frame,
  278. num_classes,
  279. online_tlwhs,
  280. online_ids,
  281. online_scores,
  282. frame_id=frame_id,
  283. fps=fps,
  284. ids2names=ids2names)
  285. writer.write(im)
  286. if camera_id != -1:
  287. cv2.imshow('Mask Detection', im)
  288. if cv2.waitKey(1) & 0xFF == ord('q'):
  289. break
  290. if self.save_mot_txts:
  291. result_filename = os.path.join(
  292. self.output_dir, video_out_name.split('.')[-2] + '.txt')
  293. write_mot_results(result_filename, results, data_type, num_classes)
  294. writer.release()
  295. def main():
  296. detector = JDE_Detector(
  297. FLAGS.model_dir,
  298. tracker_config=None,
  299. device=FLAGS.device,
  300. run_mode=FLAGS.run_mode,
  301. batch_size=1,
  302. trt_min_shape=FLAGS.trt_min_shape,
  303. trt_max_shape=FLAGS.trt_max_shape,
  304. trt_opt_shape=FLAGS.trt_opt_shape,
  305. trt_calib_mode=FLAGS.trt_calib_mode,
  306. cpu_threads=FLAGS.cpu_threads,
  307. enable_mkldnn=FLAGS.enable_mkldnn,
  308. output_dir=FLAGS.output_dir,
  309. threshold=FLAGS.threshold,
  310. save_images=FLAGS.save_images,
  311. save_mot_txts=FLAGS.save_mot_txts)
  312. # predict from video file or camera video stream
  313. if FLAGS.video_file is not None or FLAGS.camera_id != -1:
  314. detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
  315. else:
  316. # predict from image
  317. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  318. detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
  319. if not FLAGS.run_benchmark:
  320. detector.det_times.info(average=True)
  321. else:
  322. mode = FLAGS.run_mode
  323. model_dir = FLAGS.model_dir
  324. model_info = {
  325. 'model_name': model_dir.strip('/').split('/')[-1],
  326. 'precision': mode.split('_')[-1]
  327. }
  328. bench_log(detector, img_list, model_info, name='MOT')
  329. if __name__ == '__main__':
  330. paddle.enable_static()
  331. parser = argsparser()
  332. FLAGS = parser.parse_args()
  333. print_arguments(FLAGS)
  334. FLAGS.device = FLAGS.device.upper()
  335. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  336. ], "device should be CPU, GPU or XPU"
  337. main()