mot_sde_infer.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import cv2
import numpy as np
from collections import defaultdict

import paddle

from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image
from utils import argsparser, Timer, get_current_memory_mb
from infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor

# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

from pptracking.python.mot import JDETracker, DeepSORTTracker
from pptracking.python.mot.utils import MOTTimer, write_mot_results, get_crops, clip_box
from pptracking.python.mot.visualize import plot_tracking, plot_tracking_dict
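
# Note: the sys.path insertion above points two directory levels up from this
# file, which assumes the sibling pptracking package (pptracking/python/mot)
# lives there; running the script from another location would break these
# imports.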


class SDE_Detector(Detector):
    """
    Detector with an attached tracker for MOT: ByteTrack by default, or
    DeepSORT when a reid model is provided.

    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): The path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
    """

    def __init__(self,
                 model_dir,
                 tracker_config,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False,
                 reid_model_dir=None):
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # reid config: a separate reid predictor is only loaded for DeepSORT
        self.use_reid = reid_model_dir is not None
        if self.use_reid:
            self.reid_pred_config = self.set_config(reid_model_dir)
            self.reid_predictor, self.config = load_predictor(
                reid_model_dir,
                run_mode=run_mode,
                batch_size=50,  # reid_batch_size
                min_subgraph_size=self.reid_pred_config.min_subgraph_size,
                device=device,
                use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
                trt_min_shape=trt_min_shape,
                trt_max_shape=trt_max_shape,
                trt_opt_shape=trt_opt_shape,
                trt_calib_mode=trt_calib_mode,
                cpu_threads=cpu_threads,
                enable_mkldnn=enable_mkldnn)
        else:
            self.reid_pred_config = None
            self.reid_predictor = None

        assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
        with open(self.tracker_config) as f:
            tracker_cfg = yaml.safe_load(f)
        cfg = tracker_cfg[tracker_cfg['type']]

        # tracker config
        self.use_deepsort_tracker = tracker_cfg['type'] == 'DeepSORTTracker'
        if self.use_deepsort_tracker:
            # use DeepSORTTracker
            if self.reid_pred_config is not None and hasattr(
                    self.reid_pred_config, 'tracker'):
                cfg = self.reid_pred_config.tracker
            budget = cfg.get('budget', 100)
            max_age = cfg.get('max_age', 30)
            max_iou_distance = cfg.get('max_iou_distance', 0.7)
            matching_threshold = cfg.get('matching_threshold', 0.2)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)

            self.tracker = DeepSORTTracker(
                budget=budget,
                max_age=max_age,
                max_iou_distance=max_iou_distance,
                matching_threshold=matching_threshold,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio, )
        else:
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)

            self.tracker = JDETracker(
                use_byte=use_byte,
                det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
                low_conf_thres=low_conf_thres, )
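
    # An illustrative tracker config YAML for the JDETracker/ByteTrack branch
    # above, reconstructed from the cfg.get() calls; values shown are the
    # coded defaults except use_byte. A DeepSORT setup would use
    # 'type: DeepSORTTracker' with its own key set instead.
    #
    #   type: JDETracker
    #   JDETracker:
    #     use_byte: True        # cfg.get() default is False
    #     det_thresh: 0.3
    #     min_box_area: 0
    #     vertical_ratio: 0
    #     match_thres: 0.9
    #     conf_thres: 0.6
    #     low_conf_thres: 0.1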

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            print('[WARNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result
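
    # Each row of the 'boxes' array produced above is laid out as
    # [cls_id, score, x0, y0, x1, y1]; reidprocess() and tracking() below
    # both rely on this column order.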

    def reidprocess(self, det_results, repeats=1):
        pred_dets = det_results['boxes']
        pred_xyxys = pred_dets[:, 2:6]

        ori_image = det_results['ori_image']
        ori_image_shape = ori_image.shape[:2]
        pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)

        if len(keep_idx[0]) == 0:
            det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
            det_results['embeddings'] = None
            return det_results

        pred_dets = pred_dets[keep_idx[0]]
        pred_xyxys = pred_dets[:, 2:6]

        w, h = self.tracker.input_size
        crops = get_crops(pred_xyxys, ori_image, w, h)

        # to keep fast speed, only use topk crops
        crops = crops[:50]  # reid_batch_size
        det_results['crops'] = np.array(crops).astype('float32')
        det_results['boxes'] = pred_dets[:50]

        input_names = self.reid_predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.reid_predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(det_results[input_names[i]])

        # model prediction
        for i in range(repeats):
            self.reid_predictor.run()
            output_names = self.reid_predictor.get_output_names()
            feature_tensor = self.reid_predictor.get_output_handle(
                output_names[0])
            pred_embs = feature_tensor.copy_to_cpu()

        det_results['embeddings'] = pred_embs
        return det_results
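
    # Only the first 50 detections receive embeddings because the reid
    # predictor was loaded with batch_size=50 in __init__; detections beyond
    # that are also dropped from 'boxes' so the two arrays stay aligned.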

    def tracking(self, det_results):
        pred_dets = det_results['boxes']  # 'cls_id, score, x0, y0, x1, y1'
        pred_embs = det_results.get('embeddings', None)

        if self.use_deepsort_tracker:
            # use DeepSORTTracker, which only supports a single class
            self.tracker.predict()
            online_targets = self.tracker.update(pred_dets, pred_embs)
            online_tlwhs, online_scores, online_ids = [], [], []
            for t in online_targets:
                if not t.is_confirmed() or t.time_since_update > 1:
                    continue
                tlwh = t.to_tlwh()
                tscore = t.score
                tid = t.track_id
                if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                        3] > self.tracker.vertical_ratio:
                    continue
                online_tlwhs.append(tlwh)
                online_scores.append(tscore)
                online_ids.append(tid)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs
        else:
            # use ByteTracker, which supports multiple classes
            online_tlwhs = defaultdict(list)
            online_scores = defaultdict(list)
            online_ids = defaultdict(list)
            online_targets_dict = self.tracker.update(pred_dets, pred_embs)
            for cls_id in range(self.num_classes):
                online_targets = online_targets_dict[cls_id]
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    tscore = t.score
                    if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
                        continue
                    if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                            3] > self.tracker.vertical_ratio:
                        continue
                    online_tlwhs[cls_id].append(tlwh)
                    online_ids[cls_id].append(tid)
                    online_scores[cls_id].append(tscore)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }
            return tracking_outs
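
    # Note the asymmetric outputs: the DeepSORT branch returns plain lists
    # (single class), while the ByteTrack branch returns per-class
    # defaultdicts keyed by cls_id. predict_image() below dispatches on this
    # with isinstance(online_tlwhs, defaultdict).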

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        mot_results = []
        for frame_id, img_file in enumerate(image_list):
            batch_image_list = [img_file]  # bs=1 in MOT model
            frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                result_warmup = self.tracking(det_result)  # warmup
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            mot_results.append([online_tlwhs, online_scores, online_ids])

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=[])
                else:
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id)
                # fall back to output_dir when no seq_name is given
                save_dir = os.path.join(self.output_dir, seq_name or '')
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
        return mot_results
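
    # mot_results holds one [online_tlwhs, online_scores, online_ids] triple
    # per input frame, in the same order as the sorted image_list.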

    def predict_video(self, video_file, camera_id):
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        frame_id = 1
        timer = MOTTimer()
        results = defaultdict(list)
        num_classes = self.num_classes
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        ids2names = self.pred_config.labels
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_id % 10 == 0:
                print('Tracking frame: %d' % (frame_id))
            frame_id += 1

            timer.tic()
            seq_name = video_out_name.split('.')[0]
            mot_results = self.predict_image(
                [frame[:, :, ::-1]], visual=False, seq_name=seq_name)
            timer.toc()

            # bs=1 in MOT model
            online_tlwhs, online_scores, online_ids = mot_results[0]
            fps = 1. / timer.duration
            if self.use_deepsort_tracker:
                # use DeepSORTTracker, which only supports a single class
                results[0].append(
                    (frame_id + 1, online_tlwhs, online_scores, online_ids))
                im = plot_tracking(
                    frame,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps)
            else:
                # use ByteTracker, which supports multiple classes
                for cls_id in range(num_classes):
                    results[cls_id].append(
                        (frame_id + 1, online_tlwhs[cls_id],
                         online_scores[cls_id], online_ids[cls_id]))
                im = plot_tracking_dict(
                    frame,
                    num_classes,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names)

            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Tracking', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if self.save_mot_txts:
            result_filename = os.path.join(
                self.output_dir, video_out_name.split('.')[-2] + '.txt')
            write_mot_results(result_filename, results)

        writer.release()
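

# Minimal usage sketch of SDE_Detector outside the CLI entry point below
# (paths are hypothetical; a real run needs an exported detection model plus
# a tracker config YAML):
#
#   detector = SDE_Detector(
#       model_dir='output_inference/my_det_model',      # hypothetical path
#       tracker_config='tracker_config.yml',            # hypothetical path
#       device='GPU')
#   detector.predict_video('input.mp4', camera_id=-1)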


def main():
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    detector = SDE_Detector(
        FLAGS.model_dir,
        tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts, )

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        # guard against a None image_dir when only --image_file is given
        seq_name = FLAGS.image_dir.split('/')[-1] if FLAGS.image_dir else None
        detector.predict_image(
            img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='MOT')
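
# Example invocation, assuming the argsparser() flags referenced above (flag
# names are taken from the FLAGS.* attributes used in this file):
#
#   python mot_sde_infer.py --model_dir=<det_model_dir> \
#       --tracker_config=tracker_config.yml --device=GPU \
#       --video_file=test.mp4 --save_mot_txts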


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], \
        "device should be CPU, GPU or XPU"

    main()