# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import yaml
import cv2
import re
import glob
import numpy as np
from collections import defaultdict

import paddle

from benchmark_utils import PaddleInferBenchmark
from preprocess import decode_image

# add python path
import sys
parent_path = os.path.abspath(os.path.join(__file__, '..'))
sys.path.insert(0, parent_path)

from det_infer import Detector, get_test_images, print_arguments, bench_log, PredictConfig, load_predictor
from mot_utils import argsparser, Timer, get_current_memory_mb, video2frames, _is_valid_video
from mot.tracker import JDETracker, DeepSORTTracker
from mot.utils import MOTTimer, write_mot_results, get_crops, clip_box, flow_statistic
from mot.visualize import plot_tracking, plot_tracking_dict

from mot.mtmct.utils import parse_bias
from mot.mtmct.postprocess import trajectory_fusion, sub_cluster, gen_res, print_mtmct_result
from mot.mtmct.postprocess import get_mtmct_matching_results, save_mtmct_crops, save_mtmct_vis_results


class SDE_Detector(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        tracker_config (str): tracker config path
        device (str): Choose the device you want to run on, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size in inference, MOT models only support 1
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): The path of output, default as 'output'
        threshold (float): Score threshold of the detected bbox, default as 0.5
        save_images (bool): Whether to save visualization image results, default as False
        save_mot_txts (bool): Whether to save tracking results (txt), default as False
        draw_center_traj (bool): Whether to draw the trajectory of box centers, default as False
        secs_interval (int): The seconds interval to count after tracking, default as 10
        do_entrance_counting (bool): Whether to count the number of identities entering
            or leaving through the entrance, default as False; only supports single-class
            counting in MOT.
        reid_model_dir (str): reid model dir, default None for ByteTrack, but set for DeepSORT
        mtmct_dir (str): MTMCT dir, default None, set for doing MTMCT
    """

    def __init__(self,
                 model_dir,
                 tracker_config,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 save_images=False,
                 save_mot_txts=False,
                 draw_center_traj=False,
                 secs_interval=10,
                 do_entrance_counting=False,
                 reid_model_dir=None,
                 mtmct_dir=None):
        super(SDE_Detector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.save_images = save_images
        self.save_mot_txts = save_mot_txts
        self.draw_center_traj = draw_center_traj
        self.secs_interval = secs_interval
        self.do_entrance_counting = do_entrance_counting

        assert batch_size == 1, "MOT model only supports batch_size=1."
        self.det_times = Timer(with_tracker=True)
        self.num_classes = len(self.pred_config.labels)

        # reid config
        self.use_reid = reid_model_dir is not None
        if self.use_reid:
            self.reid_pred_config = self.set_config(reid_model_dir)
            self.reid_predictor, self.config = load_predictor(
                reid_model_dir,
                run_mode=run_mode,
                batch_size=50,  # reid_batch_size
                min_subgraph_size=self.reid_pred_config.min_subgraph_size,
                device=device,
                use_dynamic_shape=self.reid_pred_config.use_dynamic_shape,
                trt_min_shape=trt_min_shape,
                trt_max_shape=trt_max_shape,
                trt_opt_shape=trt_opt_shape,
                trt_calib_mode=trt_calib_mode,
                cpu_threads=cpu_threads,
                enable_mkldnn=enable_mkldnn)
        else:
            self.reid_pred_config = None
            self.reid_predictor = None

        assert tracker_config is not None, 'Note that tracker_config should be set.'
        self.tracker_config = tracker_config
        with open(self.tracker_config) as f:
            tracker_cfg = yaml.safe_load(f)
        cfg = tracker_cfg[tracker_cfg['type']]
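        # The tracker config file is read as a top-level 'type' key naming the
        # tracker class, plus a section of that same name holding its
        # parameters. A rough sketch (key names taken from the cfg.get()
        # defaults below; the YAML shipped with a given model may differ):
        #
        #   type: JDETracker
        #   JDETracker:
        #     use_byte: True
        #     conf_thres: 0.6
        #     low_conf_thres: 0.1
        #     match_thres: 0.9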

        # tracker config
        self.use_deepsort_tracker = tracker_cfg['type'] == 'DeepSORTTracker'
        if self.use_deepsort_tracker:
            # use DeepSORTTracker
            if self.reid_pred_config is not None and hasattr(
                    self.reid_pred_config, 'tracker'):
                cfg = self.reid_pred_config.tracker
            budget = cfg.get('budget', 100)
            max_age = cfg.get('max_age', 30)
            max_iou_distance = cfg.get('max_iou_distance', 0.7)
            matching_threshold = cfg.get('matching_threshold', 0.2)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)

            self.tracker = DeepSORTTracker(
                budget=budget,
                max_age=max_age,
                max_iou_distance=max_iou_distance,
                matching_threshold=matching_threshold,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio, )
        else:
            # use ByteTracker
            use_byte = cfg.get('use_byte', False)
            det_thresh = cfg.get('det_thresh', 0.3)
            min_box_area = cfg.get('min_box_area', 0)
            vertical_ratio = cfg.get('vertical_ratio', 0)
            match_thres = cfg.get('match_thres', 0.9)
            conf_thres = cfg.get('conf_thres', 0.6)
            low_conf_thres = cfg.get('low_conf_thres', 0.1)

            self.tracker = JDETracker(
                use_byte=use_byte,
                det_thresh=det_thresh,
                num_classes=self.num_classes,
                min_box_area=min_box_area,
                vertical_ratio=vertical_ratio,
                match_thres=match_thres,
                conf_thres=conf_thres,
                low_conf_thres=low_conf_thres, )

        self.do_mtmct = mtmct_dir is not None
        self.mtmct_dir = mtmct_dir

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            print('[WARNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def reidprocess(self, det_results, repeats=1):
        pred_dets = det_results['boxes']  # cls_id, score, x0, y0, x1, y1
        pred_xyxys = pred_dets[:, 2:6]

        ori_image = det_results['ori_image']
        ori_image_shape = ori_image.shape[:2]
        pred_xyxys, keep_idx = clip_box(pred_xyxys, ori_image_shape)

        if len(keep_idx[0]) == 0:
            det_results['boxes'] = np.zeros((1, 6), dtype=np.float32)
            det_results['embeddings'] = None
            return det_results

        pred_dets = pred_dets[keep_idx[0]]
        pred_xyxys = pred_dets[:, 2:6]

        w, h = self.tracker.input_size
        crops = get_crops(pred_xyxys, ori_image, w, h)

        # to keep fast speed, only use topk crops
        crops = crops[:50]  # reid_batch_size
        det_results['crops'] = np.array(crops).astype('float32')
        det_results['boxes'] = pred_dets[:50]

        input_names = self.reid_predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.reid_predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(det_results[input_names[i]])

        # model prediction
        for i in range(repeats):
            self.reid_predictor.run()
            output_names = self.reid_predictor.get_output_names()
            feature_tensor = self.reid_predictor.get_output_handle(
                output_names[0])
            pred_embs = feature_tensor.copy_to_cpu()

        det_results['embeddings'] = pred_embs
        return det_results

    def tracking(self, det_results):
        pred_dets = det_results['boxes']  # cls_id, score, x0, y0, x1, y1
        pred_embs = det_results.get('embeddings', None)

        if self.use_deepsort_tracker:
            # use DeepSORTTracker, only supports a single class
            self.tracker.predict()
            online_targets = self.tracker.update(pred_dets, pred_embs)

            online_tlwhs, online_scores, online_ids = [], [], []
            if self.do_mtmct:
                online_tlbrs, online_feats = [], []
            for t in online_targets:
                if not t.is_confirmed() or t.time_since_update > 1:
                    continue
                tlwh = t.to_tlwh()
                tscore = t.score
                tid = t.track_id
                if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                        3] > self.tracker.vertical_ratio:
                    continue
                online_tlwhs.append(tlwh)
                online_scores.append(tscore)
                online_ids.append(tid)
                if self.do_mtmct:
                    online_tlbrs.append(t.to_tlbr())
                    online_feats.append(t.feat)

            tracking_outs = {
                'online_tlwhs': online_tlwhs,
                'online_scores': online_scores,
                'online_ids': online_ids,
            }

            if self.do_mtmct:
                seq_name = det_results['seq_name']
                frame_id = det_results['frame_id']
                tracking_outs['feat_data'] = {}
                for _tlbr, _id, _feat in zip(online_tlbrs, online_ids,
                                             online_feats):
                    feat_data = {}
                    feat_data['bbox'] = _tlbr
                    feat_data['frame'] = f"{frame_id:06d}"
                    feat_data['id'] = _id
                    _imgname = f'{seq_name}_{_id}_{frame_id}.jpg'
                    feat_data['imgname'] = _imgname
                    feat_data['feat'] = _feat
                    tracking_outs['feat_data'].update({_imgname: feat_data})
            return tracking_outs
        else:
            # use ByteTracker, supports multiple classes
            online_tlwhs = defaultdict(list)
            online_scores = defaultdict(list)
            online_ids = defaultdict(list)
            if self.do_mtmct:
                online_tlbrs, online_feats = defaultdict(list), defaultdict(
                    list)

            online_targets_dict = self.tracker.update(pred_dets, pred_embs)
            for cls_id in range(self.num_classes):
                online_targets = online_targets_dict[cls_id]
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    tscore = t.score
                    if tlwh[2] * tlwh[3] <= self.tracker.min_box_area:
                        continue
                    if self.tracker.vertical_ratio > 0 and tlwh[2] / tlwh[
                            3] > self.tracker.vertical_ratio:
                        continue
                    online_tlwhs[cls_id].append(tlwh)
                    online_ids[cls_id].append(tid)
                    online_scores[cls_id].append(tscore)
                    if self.do_mtmct:
                        online_tlbrs[cls_id].append(t.tlbr)
                        online_feats[cls_id].append(t.curr_feat)

            if self.do_mtmct:
                assert self.num_classes == 1, 'MTMCT only supports a single class.'
                tracking_outs = {
                    'online_tlwhs': online_tlwhs[0],
                    'online_scores': online_scores[0],
                    'online_ids': online_ids[0],
                }
                seq_name = det_results['seq_name']
                frame_id = det_results['frame_id']
                tracking_outs['feat_data'] = {}
                for _tlbr, _id, _feat in zip(online_tlbrs[0], online_ids[0],
                                             online_feats[0]):
                    feat_data = {}
                    feat_data['bbox'] = _tlbr
                    feat_data['frame'] = f"{frame_id:06d}"
                    feat_data['id'] = _id
                    _imgname = f'{seq_name}_{_id}_{frame_id}.jpg'
                    feat_data['imgname'] = _imgname
                    feat_data['feat'] = _feat
                    tracking_outs['feat_data'].update({_imgname: feat_data})
                return tracking_outs
            else:
                tracking_outs = {
                    'online_tlwhs': online_tlwhs,
                    'online_scores': online_scores,
                    'online_ids': online_ids,
                }
                return tracking_outs
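
    # Sketch of consuming a tracking() result (ByteTracker branch, where each
    # value is a per-class dict keyed by cls_id; the DeepSORT branch returns
    # plain lists instead):
    #
    #   tracking_outs = detector.tracking(det_result)
    #   for cls_id, tlwhs in tracking_outs['online_tlwhs'].items():
    #       for tlwh, tid in zip(tlwhs, tracking_outs['online_ids'][cls_id]):
    #           x, y, w, h = tlwh  # top-left corner plus box width/height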

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      seq_name=None):
        num_classes = self.num_classes
        image_list.sort()
        ids2names = self.pred_config.labels
        if self.do_mtmct:
            mot_features_dict = {}  # cid_tid_fid feats
        else:
            mot_results = []
        for frame_id, img_file in enumerate(image_list):
            if self.do_mtmct:
                if frame_id % 10 == 0:
                    print('Tracking frame: %d' % (frame_id))
            batch_image_list = [img_file]  # bs=1 in MOT model
            frame, _ = decode_image(img_file, {})
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                result_warmup = self.tracking(det_result)
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                self.det_times.postprocess_time_s.start()
                det_result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()

                # tracking process
                self.det_times.tracking_time_s.start()
                if self.use_reid:
                    det_result['frame_id'] = frame_id
                    det_result['seq_name'] = seq_name
                    det_result['ori_image'] = frame
                    det_result = self.reidprocess(det_result)
                tracking_outs = self.tracking(det_result)
                self.det_times.tracking_time_s.end()
                self.det_times.img_num += 1

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            if self.do_mtmct:
                feat_data_dict = tracking_outs['feat_data']
                mot_features_dict = dict(mot_features_dict, **feat_data_dict)
            else:
                mot_results.append([online_tlwhs, online_scores, online_ids])

            if visual:
                if len(image_list) > 1 and frame_id % 10 == 0:
                    print('Tracking frame {}'.format(frame_id))
                frame, _ = decode_image(img_file, {})
                if isinstance(online_tlwhs, defaultdict):
                    im = plot_tracking_dict(
                        frame,
                        num_classes,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id,
                        ids2names=[])
                else:
                    im = plot_tracking(
                        frame,
                        online_tlwhs,
                        online_ids,
                        online_scores,
                        frame_id=frame_id)
                save_dir = os.path.join(self.output_dir, seq_name)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

        if self.do_mtmct:
            return mot_features_dict
        else:
            return mot_results

    def predict_video(self, video_file, camera_id):
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]

        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        video_format = 'mp4v'
        fourcc = cv2.VideoWriter_fourcc(*video_format)
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

        frame_id = 1
        timer = MOTTimer()
        results = defaultdict(list)
        num_classes = self.num_classes
        data_type = 'mcmot' if num_classes > 1 else 'mot'
        ids2names = self.pred_config.labels

        center_traj = None
        entrance = None
        records = None
        if self.draw_center_traj:
            center_traj = [{} for i in range(num_classes)]
        if num_classes == 1:
            id_set = set()
            interval_id_set = set()
            in_id_list = list()
            out_id_list = list()
            prev_center = dict()
            records = list()
            # the entrance is a horizontal line at half the frame height
            entrance = [0, height / 2., width, height / 2.]
        video_fps = fps

        while True:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_id % 10 == 0:
                print('Tracking frame: %d' % (frame_id))
            frame_id += 1

            timer.tic()
            seq_name = video_out_name.split('.')[0]
            mot_results = self.predict_image(
                [frame], visual=False, seq_name=seq_name)
            timer.toc()

            # bs=1 in MOT model
            online_tlwhs, online_scores, online_ids = mot_results[0]

            # NOTE: just implement flow statistic for one class
            if num_classes == 1:
                result = (frame_id + 1, online_tlwhs[0], online_scores[0],
                          online_ids[0])
                statistic = flow_statistic(
                    result, self.secs_interval, self.do_entrance_counting,
                    video_fps, entrance, id_set, interval_id_set, in_id_list,
                    out_id_list, prev_center, records, data_type, num_classes)
                records = statistic['records']

            fps = 1. / timer.duration
            if self.use_deepsort_tracker:
                # use DeepSORTTracker, only supports a single class
                results[0].append(
                    (frame_id + 1, online_tlwhs, online_scores, online_ids))
                im = plot_tracking(
                    frame,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    do_entrance_counting=self.do_entrance_counting,
                    entrance=entrance)
            else:
                # use ByteTracker, supports multiple classes
                for cls_id in range(num_classes):
                    results[cls_id].append(
                        (frame_id + 1, online_tlwhs[cls_id],
                         online_scores[cls_id], online_ids[cls_id]))
                im = plot_tracking_dict(
                    frame,
                    num_classes,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names,
                    do_entrance_counting=self.do_entrance_counting,
                    entrance=entrance,
                    records=records,
                    center_traj=center_traj)

            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        if self.save_mot_txts:
            result_filename = os.path.join(
                self.output_dir, video_out_name.split('.')[-2] + '.txt')
            write_mot_results(result_filename, results)
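            # The rows written above should follow the MOTChallenge text
            # convention for single-class 'mot' data
            # (frame,id,x,y,w,h,score,...); the exact columns are defined in
            # mot.utils.write_mot_results.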

            result_filename = os.path.join(
                self.output_dir,
                video_out_name.split('.')[-2] + '_flow_statistic.txt')
            with open(result_filename, 'w') as f:
                for line in records:
                    f.write(line)
            print('Flow statistic saved in {}'.format(result_filename))

        writer.release()

    def predict_mtmct(self, mtmct_dir, mtmct_cfg):
        cameras_bias = mtmct_cfg['cameras_bias']
        cid_bias = parse_bias(cameras_bias)
        scene_cluster = list(cid_bias.keys())

        # 1.zone related parameters
        use_zone = mtmct_cfg.get('use_zone', False)
        zone_path = mtmct_cfg.get('zone_path', None)

        # 2.tricks parameters, can be used for other mtmct datasets
        use_ff = mtmct_cfg.get('use_ff', False)
        use_rerank = mtmct_cfg.get('use_rerank', False)

        # 3.camera related parameters
        use_camera = mtmct_cfg.get('use_camera', False)
        use_st_filter = mtmct_cfg.get('use_st_filter', False)

        # 4.zone related parameters
        use_roi = mtmct_cfg.get('use_roi', False)
        roi_dir = mtmct_cfg.get('roi_dir', False)

        mot_list_breaks = []
        cid_tid_dict = dict()

        output_dir = self.output_dir
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        seqs = os.listdir(mtmct_dir)
        for seq in sorted(seqs):
            fpath = os.path.join(mtmct_dir, seq)
            if os.path.isfile(fpath) and _is_valid_video(fpath):
                seq = seq.split('.')[-2]
                print('ffmpeg processing of video {}'.format(fpath))
                frames_path = video2frames(
                    video_path=fpath, outpath=mtmct_dir, frame_rate=25)
                fpath = os.path.join(mtmct_dir, seq)

            if not os.path.isdir(fpath):
                print('{} is not an image folder.'.format(fpath))
                continue
            if os.path.exists(os.path.join(fpath, 'img1')):
                fpath = os.path.join(fpath, 'img1')

            assert os.path.isdir(fpath), '{} should be a directory'.format(
                fpath)
            image_list = glob.glob(os.path.join(fpath, '*.jpg'))
            image_list.sort()
            assert len(image_list) > 0, '{} has no images.'.format(fpath)
            print('start tracking seq: {}'.format(seq))

            mot_features_dict = self.predict_image(
                image_list, visual=False, seq_name=seq)

            cid = int(re.sub('[a-z,A-Z]', "", seq))
            tid_data, mot_list_break = trajectory_fusion(
                mot_features_dict,
                cid,
                cid_bias,
                use_zone=use_zone,
                zone_path=zone_path)
            mot_list_breaks.append(mot_list_break)
            # single seq process
            for line in tid_data:
                tracklet = tid_data[line]
                tid = tracklet['tid']
                if (cid, tid) not in cid_tid_dict:
                    cid_tid_dict[(cid, tid)] = tracklet

        map_tid = sub_cluster(
            cid_tid_dict,
            scene_cluster,
            use_ff=use_ff,
            use_rerank=use_rerank,
            use_camera=use_camera,
            use_st_filter=use_st_filter)

        pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt')
        if use_camera:
            gen_res(pred_mtmct_file, scene_cluster, map_tid, mot_list_breaks)
        else:
            gen_res(
                pred_mtmct_file,
                scene_cluster,
                map_tid,
                mot_list_breaks,
                use_roi=use_roi,
                roi_dir=roi_dir)

        camera_results, cid_tid_fid_res = get_mtmct_matching_results(
            pred_mtmct_file)

        crops_dir = os.path.join(output_dir, 'mtmct_crops')
        save_mtmct_crops(
            cid_tid_fid_res, images_dir=mtmct_dir, crops_dir=crops_dir)

        save_dir = os.path.join(output_dir, 'mtmct_vis')
        save_mtmct_vis_results(
            camera_results,
            images_dir=mtmct_dir,
            save_dir=save_dir,
            save_videos=FLAGS.save_images)


def main():
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    detector = SDE_Detector(
        FLAGS.model_dir,
        tracker_config=FLAGS.tracker_config,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=1,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir,
        threshold=FLAGS.threshold,
        save_images=FLAGS.save_images,
        save_mot_txts=FLAGS.save_mot_txts,
        draw_center_traj=FLAGS.draw_center_traj,
        secs_interval=FLAGS.secs_interval,
        do_entrance_counting=FLAGS.do_entrance_counting,
        reid_model_dir=FLAGS.reid_model_dir,
        mtmct_dir=FLAGS.mtmct_dir, )

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    elif FLAGS.mtmct_dir is not None:
        with open(FLAGS.mtmct_cfg) as f:
            mtmct_cfg = yaml.safe_load(f)
        detector.predict_mtmct(FLAGS.mtmct_dir, mtmct_cfg)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "--batch_size should be 1 in MOT models."
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        seq_name = FLAGS.image_dir.split('/')[-1] if FLAGS.image_dir else None
        detector.predict_image(
            img_list, FLAGS.run_benchmark, repeats=10, seq_name=seq_name)

        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='MOT')
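

# A minimal sketch of a typical invocation (paths are hypothetical
# placeholders; see mot_utils.argsparser for the full flag list):
#
#   python mot_sde_infer.py \
#       --model_dir=output_inference/bytetrack_model \
#       --tracker_config=tracker_config.yml \
#       --video_file=test.mp4 \
#       --device=GPU \
#       --save_mot_txts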


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'
                            ], "device should be CPU, GPU or XPU"

    main()