mot_jde_infer.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import yaml
  17. import cv2
  18. import numpy as np
  19. from collections import defaultdict
  20. import paddle
  21. from benchmark_utils import PaddleInferBenchmark
  22. from preprocess import decode_image
  23. from mot_utils import argsparser, Timer, get_current_memory_mb
  24. from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig
  25. # add python path
  26. import sys
  27. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  28. sys.path.insert(0, parent_path)
  29. from mot import JDETracker
  30. from mot.utils import MOTTimer, write_mot_results, flow_statistic
  31. from mot.visualize import plot_tracking, plot_tracking_dict
# Set (not a dictionary) of the JDE-family MOT architecture names this script
# is written for. NOTE(review): it is not referenced in the code shown here;
# presumably other modules check an exported model's arch against it — confirm.
MOT_JDE_SUPPORT_MODELS = {
    'JDE',
    'FairMOT',
}
  37. class JDE_Detector(Detector):
  38. """
  39. Args:
  40. model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
  41. device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
  42. run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
  43. batch_size (int): size of pre batch in inference
  44. trt_min_shape (int): min shape for dynamic shape in trt
  45. trt_max_shape (int): max shape for dynamic shape in trt
  46. trt_opt_shape (int): opt shape for dynamic shape in trt
  47. trt_calib_mode (bool): If the model is produced by TRT offline quantitative
  48. calibration, trt_calib_mode need to set True
  49. cpu_threads (int): cpu threads
  50. enable_mkldnn (bool): whether to open MKLDNN
  51. output_dir (string): The path of output, default as 'output'
  52. threshold (float): Score threshold of the detected bbox, default as 0.5
  53. save_images (bool): Whether to save visualization image results, default as False
  54. save_mot_txts (bool): Whether to save tracking results (txt), default as False
  55. draw_center_traj (bool): Whether drawing the trajectory of center, default as False
  56. secs_interval (int): The seconds interval to count after tracking, default as 10
  57. do_entrance_counting(bool): Whether counting the numbers of identifiers entering
  58. or getting out from the entrance, default as False,only support single class
  59. counting in MOT.
  60. """
  61. def __init__(
  62. self,
  63. model_dir,
  64. tracker_config=None,
  65. device='CPU',
  66. run_mode='paddle',
  67. batch_size=1,
  68. trt_min_shape=1,
  69. trt_max_shape=1088,
  70. trt_opt_shape=608,
  71. trt_calib_mode=False,
  72. cpu_threads=1,
  73. enable_mkldnn=False,
  74. output_dir='output',
  75. threshold=0.5,
  76. save_images=False,
  77. save_mot_txts=False,
  78. draw_center_traj=False,
  79. secs_interval=10,
  80. do_entrance_counting=False, ):
  81. super(JDE_Detector, self).__init__(
  82. model_dir=model_dir,
  83. device=device,
  84. run_mode=run_mode,
  85. batch_size=batch_size,
  86. trt_min_shape=trt_min_shape,
  87. trt_max_shape=trt_max_shape,
  88. trt_opt_shape=trt_opt_shape,
  89. trt_calib_mode=trt_calib_mode,
  90. cpu_threads=cpu_threads,
  91. enable_mkldnn=enable_mkldnn,
  92. output_dir=output_dir,
  93. threshold=threshold, )
  94. self.save_images = save_images
  95. self.save_mot_txts = save_mot_txts
  96. self.draw_center_traj = draw_center_traj
  97. self.secs_interval = secs_interval
  98. self.do_entrance_counting = do_entrance_counting
  99. assert batch_size == 1, "MOT model only supports batch_size=1."
  100. self.det_times = Timer(with_tracker=True)
  101. self.num_classes = len(self.pred_config.labels)
  102. # tracker config
  103. assert self.pred_config.tracker, "The exported JDE Detector model should have tracker."
  104. cfg = self.pred_config.tracker
  105. min_box_area = cfg.get('min_box_area', 0.0)
  106. vertical_ratio = cfg.get('vertical_ratio', 0.0)
  107. conf_thres = cfg.get('conf_thres', 0.0)
  108. tracked_thresh = cfg.get('tracked_thresh', 0.7)
  109. metric_type = cfg.get('metric_type', 'euclidean')
  110. self.tracker = JDETracker(
  111. num_classes=self.num_classes,
  112. min_box_area=min_box_area,
  113. vertical_ratio=vertical_ratio,
  114. conf_thres=conf_thres,
  115. tracked_thresh=tracked_thresh,
  116. metric_type=metric_type)
  117. def postprocess(self, inputs, result):
  118. # postprocess output of predictor
  119. np_boxes = result['pred_dets']
  120. if np_boxes.shape[0] <= 0:
  121. print('[WARNNING] No object detected.')
  122. result = {'pred_dets': np.zeros([0, 6]), 'pred_embs': None}
  123. result = {k: v for k, v in result.items() if v is not None}
  124. return result
  125. def tracking(self, det_results):
  126. pred_dets = det_results['pred_dets'] # cls_id, score, x0, y0, x1, y1
  127. pred_embs = det_results['pred_embs']
  128. online_targets_dict = self.tracker.update(pred_dets, pred_embs)
  129. online_tlwhs = defaultdict(list)
  130. online_scores = defaultdict(list)
  131. online_ids = defaultdict(list)
  132. for cls_id in range(self.num_classes):
  133. online_targets = online_targets_dict[cls_id]
  134. for t in online_targets:
  135. tlwh = t.tlwh
  136. tid = t.track_id
  137. tscore = t.score
  138. if tlwh[2] * tlwh[3] <= self.tracker.min_box_area: continue
  139. if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
  140. 3] > self.tracker.vertical_ratio:
  141. continue
  142. online_tlwhs[cls_id].append(tlwh)
  143. online_ids[cls_id].append(tid)
  144. online_scores[cls_id].append(tscore)
  145. return online_tlwhs, online_scores, online_ids
  146. def predict(self, repeats=1):
  147. '''
  148. Args:
  149. repeats (int): repeats number for prediction
  150. Returns:
  151. result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
  152. matix element:[class, score, x_min, y_min, x_max, y_max]
  153. FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
  154. shape: [N, 128]
  155. '''
  156. # model prediction
  157. np_pred_dets, np_pred_embs = None, None
  158. for i in range(repeats):
  159. self.predictor.run()
  160. output_names = self.predictor.get_output_names()
  161. boxes_tensor = self.predictor.get_output_handle(output_names[0])
  162. np_pred_dets = boxes_tensor.copy_to_cpu()
  163. embs_tensor = self.predictor.get_output_handle(output_names[1])
  164. np_pred_embs = embs_tensor.copy_to_cpu()
  165. result = dict(pred_dets=np_pred_dets, pred_embs=np_pred_embs)
  166. return result
  167. def predict_image(self,
  168. image_list,
  169. run_benchmark=False,
  170. repeats=1,
  171. visual=True,
  172. seq_name=None):
  173. mot_results = []
  174. num_classes = self.num_classes
  175. image_list.sort()
  176. ids2names = self.pred_config.labels
  177. data_type = 'mcmot' if num_classes > 1 else 'mot'
  178. for frame_id, img_file in enumerate(image_list):
  179. batch_image_list = [img_file] # bs=1 in MOT model
  180. if run_benchmark:
  181. # preprocess
  182. inputs = self.preprocess(batch_image_list) # warmup
  183. self.det_times.preprocess_time_s.start()
  184. inputs = self.preprocess(batch_image_list)
  185. self.det_times.preprocess_time_s.end()
  186. # model prediction
  187. result_warmup = self.predict(repeats=repeats) # warmup
  188. self.det_times.inference_time_s.start()
  189. result = self.predict(repeats=repeats)
  190. self.det_times.inference_time_s.end(repeats=repeats)
  191. # postprocess
  192. result_warmup = self.postprocess(inputs, result) # warmup
  193. self.det_times.postprocess_time_s.start()
  194. det_result = self.postprocess(inputs, result)
  195. self.det_times.postprocess_time_s.end()
  196. # tracking
  197. result_warmup = self.tracking(det_result)
  198. self.det_times.tracking_time_s.start()
  199. online_tlwhs, online_scores, online_ids = self.tracking(
  200. det_result)
  201. self.det_times.tracking_time_s.end()
  202. self.det_times.img_num += 1
  203. cm, gm, gu = get_current_memory_mb()
  204. self.cpu_mem += cm
  205. self.gpu_mem += gm
  206. self.gpu_util += gu
  207. else:
  208. self.det_times.preprocess_time_s.start()
  209. inputs = self.preprocess(batch_image_list)
  210. self.det_times.preprocess_time_s.end()
  211. self.det_times.inference_time_s.start()
  212. result = self.predict()
  213. self.det_times.inference_time_s.end()
  214. self.det_times.postprocess_time_s.start()
  215. det_result = self.postprocess(inputs, result)
  216. self.det_times.postprocess_time_s.end()
  217. # tracking process
  218. self.det_times.tracking_time_s.start()
  219. online_tlwhs, online_scores, online_ids = self.tracking(
  220. det_result)
  221. self.det_times.tracking_time_s.end()
  222. self.det_times.img_num += 1
  223. if visual:
  224. if len(image_list) > 1 and frame_id % 10 == 0:
  225. print('Tracking frame {}'.format(frame_id))
  226. frame, _ = decode_image(img_file, {})
  227. im = plot_tracking_dict(
  228. frame,
  229. num_classes,
  230. online_tlwhs,
  231. online_ids,
  232. online_scores,
  233. frame_id=frame_id,
  234. ids2names=ids2names)
  235. if seq_name is None:
  236. seq_name = image_list[0].split('/')[-2]
  237. save_dir = os.path.join(self.output_dir, seq_name)
  238. if not os.path.exists(save_dir):
  239. os.makedirs(save_dir)
  240. cv2.imwrite(
  241. os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
  242. mot_results.append([online_tlwhs, online_scores, online_ids])
  243. return mot_results
  244. def predict_video(self, video_file, camera_id):
  245. video_out_name = 'mot_output.mp4'
  246. if camera_id != -1:
  247. capture = cv2.VideoCapture(camera_id)
  248. else:
  249. capture = cv2.VideoCapture(video_file)
  250. video_out_name = os.path.split(video_file)[-1]
  251. # Get Video info : resolution, fps, frame count
  252. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  253. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  254. fps = int(capture.get(cv2.CAP_PROP_FPS))
  255. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  256. print("fps: %d, frame_count: %d" % (fps, frame_count))
  257. if not os.path.exists(self.output_dir):
  258. os.makedirs(self.output_dir)
  259. out_path = os.path.join(self.output_dir, video_out_name)
  260. video_format = 'mp4v'
  261. fourcc = cv2.VideoWriter_fourcc(*video_format)
  262. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  263. frame_id = 1
  264. timer = MOTTimer()
  265. results = defaultdict(list) # support single class and multi classes
  266. num_classes = self.num_classes
  267. data_type = 'mcmot' if num_classes > 1 else 'mot'
  268. ids2names = self.pred_config.labels
  269. center_traj = None
  270. entrance = None
  271. records = None
  272. if self.draw_center_traj:
  273. center_traj = [{} for i in range(num_classes)]
  274. if num_classes == 1:
  275. id_set = set()
  276. interval_id_set = set()
  277. in_id_list = list()
  278. out_id_list = list()
  279. prev_center = dict()
  280. records = list()
  281. entrance = [0, height / 2., width, height / 2.]
  282. video_fps = fps
  283. while (1):
  284. ret, frame = capture.read()
  285. if not ret:
  286. break
  287. if frame_id % 10 == 0:
  288. print('Tracking frame: %d' % (frame_id))
  289. frame_id += 1
  290. timer.tic()
  291. seq_name = video_out_name.split('.')[0]
  292. mot_results = self.predict_image(
  293. [frame], visual=False, seq_name=seq_name)
  294. timer.toc()
  295. online_tlwhs, online_scores, online_ids = mot_results[0]
  296. for cls_id in range(num_classes):
  297. results[cls_id].append(
  298. (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
  299. online_ids[cls_id]))
  300. # NOTE: just implement flow statistic for single class
  301. if num_classes == 1:
  302. result = (frame_id + 1, online_tlwhs[0], online_scores[0],
  303. online_ids[0])
  304. statistic = flow_statistic(
  305. result, self.secs_interval, self.do_entrance_counting,
  306. video_fps, entrance, id_set, interval_id_set, in_id_list,
  307. out_id_list, prev_center, records, data_type, num_classes)
  308. records = statistic['records']
  309. fps = 1. / timer.duration
  310. im = plot_tracking_dict(
  311. frame,
  312. num_classes,
  313. online_tlwhs,
  314. online_ids,
  315. online_scores,
  316. frame_id=frame_id,
  317. fps=fps,
  318. ids2names=ids2names,
  319. do_entrance_counting=self.do_entrance_counting,
  320. entrance=entrance,
  321. records=records,
  322. center_traj=center_traj)
  323. writer.write(im)
  324. if camera_id != -1:
  325. cv2.imshow('Mask Detection', im)
  326. if cv2.waitKey(1) & 0xFF == ord('q'):
  327. break
  328. if self.save_mot_txts:
  329. result_filename = os.path.join(
  330. self.output_dir, video_out_name.split('.')[-2] + '.txt')
  331. write_mot_results(result_filename, results, data_type, num_classes)
  332. if num_classes == 1:
  333. result_filename = os.path.join(
  334. self.output_dir,
  335. video_out_name.split('.')[-2] + '_flow_statistic.txt')
  336. f = open(result_filename, 'w')
  337. for line in records:
  338. f.write(line)
  339. print('Flow statistic save in {}'.format(result_filename))
  340. f.close()
  341. writer.release()
  342. def main():
  343. detector = JDE_Detector(
  344. FLAGS.model_dir,
  345. tracker_config=None,
  346. device=FLAGS.device,
  347. run_mode=FLAGS.run_mode,
  348. batch_size=1,
  349. trt_min_shape=FLAGS.trt_min_shape,
  350. trt_max_shape=FLAGS.trt_max_shape,
  351. trt_opt_shape=FLAGS.trt_opt_shape,
  352. trt_calib_mode=FLAGS.trt_calib_mode,
  353. cpu_threads=FLAGS.cpu_threads,
  354. enable_mkldnn=FLAGS.enable_mkldnn,
  355. output_dir=FLAGS.output_dir,
  356. threshold=FLAGS.threshold,
  357. save_images=FLAGS.save_images,
  358. save_mot_txts=FLAGS.save_mot_txts,
  359. draw_center_traj=FLAGS.draw_center_traj,
  360. secs_interval=FLAGS.secs_interval,
  361. do_entrance_counting=FLAGS.do_entrance_counting, )
  362. # predict from video file or camera video stream
  363. if FLAGS.video_file is not None or FLAGS.camera_id != -1:
  364. detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
  365. else:
  366. # predict from image
  367. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  368. detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
  369. if not FLAGS.run_benchmark:
  370. detector.det_times.info(average=True)
  371. else:
  372. mode = FLAGS.run_mode
  373. model_dir = FLAGS.model_dir
  374. model_info = {
  375. 'model_name': model_dir.strip('/').split('/')[-1],
  376. 'precision': mode.split('_')[-1]
  377. }
  378. bench_log(detector, img_list, model_info, name='MOT')
  379. if __name__ == '__main__':
  380. paddle.enable_static()
  381. parser = argsparser()
  382. FLAGS = parser.parse_args()
  383. print_arguments(FLAGS)
  384. FLAGS.device = FLAGS.device.upper()
  385. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  386. ], "device should be CPU, GPU or XPU"
  387. main()