3
0

tracker.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import glob
  19. import re
  20. import paddle
  21. import numpy as np
  22. from tqdm import tqdm
  23. from collections import defaultdict
  24. from ppdet.core.workspace import create
  25. from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
  26. from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
  27. from ppdet.modeling.mot.utils import MOTTimer, load_det_results, write_mot_results, save_vis_results
  28. from ppdet.modeling.mot.tracker import JDETracker, DeepSORTTracker
  29. from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric, MCMOTMetric
  30. import ppdet.utils.stats as stats
  31. from .callbacks import Callback, ComposeCallback
  32. from ppdet.utils.logger import setup_logger
  logger = setup_logger(__name__)

  # Architectures whose single model outputs both detections and embeddings
  # (tracked by Tracker._eval_seq_jde).
  MOT_ARCH_JDE = ['JDE', 'FairMOT']
  # Architectures with a separate detector and an optional ReID model
  # (tracked by Tracker._eval_seq_sde).
  MOT_ARCH_SDE = ['DeepSORT', 'ByteTrack']
  # Every architecture accepted by Tracker.mot_evaluate / mot_predict_seq.
  MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
  # Accepted values of the ``data_type`` argument.
  MOT_DATA_TYPE = ['mot', 'mcmot', 'kitti']

  __all__ = ['Tracker']
  class Tracker(object):
      """Run multi-object tracking (MOT) evaluation or inference.

      The model is built from ``cfg.architecture``. Each sequence is
      processed by one of two paths:

      * ``_eval_seq_jde`` for ``MOT_ARCH_JDE`` models, whose single forward
        pass returns both detections and embeddings;
      * ``_eval_seq_sde`` for ``MOT_ARCH_SDE`` models, which use a separate
        detector (or a detection results file) and an optional ReID model.

      Args:
          cfg: global config. This class reads ``cfg.architecture``,
              ``cfg.metric``, ``cfg.num_classes`` and the
              ``'TestMOTDataset'`` / ``'EvalMOTDataset'`` entry.
          mode (str): 'test' or 'eval' (case-insensitive).
      """

      def __init__(self, cfg, mode='eval'):
          self.cfg = cfg
          assert mode.lower() in ['test', 'eval'], \
              "mode should be 'test' or 'eval'"
          self.mode = mode.lower()
          self.optimizer = None

          # build MOT data loader ('TestMOTDataset' or 'EvalMOTDataset')
          self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]

          # build model
          self.model = create(cfg.architecture)

          self.status = {}
          self.start_epoch = 0

          # initial default callbacks
          self._init_callbacks()

          # initial default metrics
          self._init_metrics()
          self._reset_metrics()

      def _init_callbacks(self):
          """Start with no callbacks registered."""
          self._callbacks = []
          self._compose_callback = None

      def _init_metrics(self):
          """Create the metric list matching ``cfg.metric``.

          'test' mode uses no metric. An unsupported metric type logs a
          warning and leaves the list empty.
          """
          if self.mode in ['test']:
              self._metrics = []
              return

          if self.cfg.metric == 'MOT':
              self._metrics = [MOTMetric(), ]
          elif self.cfg.metric == 'MCMOT':
              self._metrics = [MCMOTMetric(self.cfg.num_classes), ]
          elif self.cfg.metric == 'KITTI':
              self._metrics = [KITTIMOTMetric(), ]
          else:
              logger.warning("Metric not support for metric type {}".format(
                  self.cfg.metric))
              self._metrics = []

      def _reset_metrics(self):
          """Reset the accumulated state of every registered metric."""
          for metric in self._metrics:
              metric.reset()

      def register_callbacks(self, callbacks):
          """Append ``callbacks`` and rebuild the composed callback.

          ``None`` entries are skipped.

          Raises:
              AssertionError: if an entry is not a ``Callback`` instance.
          """
          callbacks = [c for c in list(callbacks) if c is not None]
          for c in callbacks:
              assert isinstance(c, Callback), \
                  "callbacks should be instances of subclass of Callback"
          self._callbacks.extend(callbacks)
          self._compose_callback = ComposeCallback(self._callbacks)

      def register_metrics(self, metrics):
          """Append ``metrics``; ``None`` entries are skipped.

          Raises:
              AssertionError: if an entry is not a ``Metric`` instance.
          """
          metrics = [m for m in list(metrics) if m is not None]
          for m in metrics:
              assert isinstance(m, Metric), \
                  "metrics should be instances of subclass of Metric"
          self._metrics.extend(metrics)

      def load_weights_jde(self, weights):
          """Load ``weights`` into the whole (JDE-style) model."""
          load_weight(self.model, weights, self.optimizer)

      def load_weights_sde(self, det_weights, reid_weights):
          """Load detector and/or ReID weights for an SDE-style model.

          Each sub-model is loaded only when it exists on ``self.model``.
          """
          has_detector = self.model.detector is not None
          has_reid = self.model.reid is not None
          if has_detector:
              load_weight(self.model.detector, det_weights)
          if has_reid:
              load_weight(self.model.reid, reid_weights)
          if not has_detector and not has_reid:
              logger.warning(
                  'Model has neither a detector nor a reid model, '
                  'no weights loaded.')

      @staticmethod
      def _collect_jde_targets(online_targets_dict, tracker, num_classes,
                               frame_id, results):
          """Filter per-class ``JDETracker`` outputs and record them.

          A track is dropped if its box area is <= ``tracker.min_box_area``.
          When ``tracker.vertical_ratio > 0``, a track is also dropped if its
          width / height exceeds that ratio.

          Args:
              online_targets_dict: class id -> list of online tracks, each
                  exposing ``tlwh``, ``track_id`` and ``score``.
              tracker: object providing ``min_box_area`` and
                  ``vertical_ratio``.
              num_classes (int): number of class ids to collect.
              frame_id (int): 0-based frame index; stored 1-based.
              results (dict): class id -> list. One
                  ``(frame_id + 1, tlwhs, scores, ids)`` tuple is appended
                  per class, in place.

          Returns:
              tuple: ``(online_tlwhs, online_scores, online_ids)``, each a
              ``defaultdict(list)`` keyed by class id.
          """
          online_tlwhs = defaultdict(list)
          online_scores = defaultdict(list)
          online_ids = defaultdict(list)
          for cls_id in range(num_classes):
              for t in online_targets_dict[cls_id]:
                  tlwh = t.tlwh
                  tid = t.track_id
                  tscore = t.score
                  if tlwh[2] * tlwh[3] <= tracker.min_box_area:
                      continue
                  # area check above already skipped zero-height boxes
                  if (tracker.vertical_ratio > 0 and
                          tlwh[2] / tlwh[3] > tracker.vertical_ratio):
                      continue
                  online_tlwhs[cls_id].append(tlwh)
                  online_ids[cls_id].append(tid)
                  online_scores[cls_id].append(tscore)
              # save results (an entry is recorded even for empty classes)
              results[cls_id].append(
                  (frame_id + 1, online_tlwhs[cls_id], online_scores[cls_id],
                   online_ids[cls_id]))
          return online_tlwhs, online_scores, online_ids

      @staticmethod
      def _read_frame_rate(seqinfo, default=30):
          """Read the integer ``frameRate`` entry of a ``seqinfo.ini`` file.

          Returns ``default`` if the file does not exist or has no such entry.
          """
          if not os.path.exists(seqinfo):
              return default
          with open(seqinfo) as f:
              meta_info = f.read()
          match = re.search(r'frameRate\s*=\s*(\d+)', meta_info)
          return int(match.group(1)) if match else default

      @staticmethod
      def _save_video(save_dir, seq):
          """Encode ``save_dir/%05d.jpg`` frames into
          ``save_dir/../<seq>_vis.mp4`` with the external ``ffmpeg`` tool."""
          output_video_path = os.path.join(save_dir, '..',
                                           '{}_vis.mp4'.format(seq))
          cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
              save_dir, output_video_path)
          os.system(cmd_str)
          logger.info('Save video in {}.'.format(output_video_path))

      def _eval_seq_jde(self,
                        dataloader,
                        save_dir=None,
                        show_image=False,
                        frame_rate=30,
                        draw_threshold=0):
          """Track one sequence with a JDE-style model (JDE, FairMOT).

          Args:
              dataloader: iterable of input batches for one sequence.
              save_dir (str|None): visualization output directory; created
                  if missing.
              show_image (bool): forwarded to ``save_vis_results``.
              frame_rate (int): sets ``tracker.max_time_lost`` to
                  ``int(frame_rate / 30 * tracker.track_buffer)``.
              draw_threshold (float): accepted for interface parity with
                  ``_eval_seq_sde``; not used by this path.

          Returns:
              tuple: ``(results, frame_count, average_time, timer_calls)``.
              ``results`` maps class id to a list of per-frame
              ``(frame_id, tlwhs, scores, track_ids)``, with 1-based
              ``frame_id``.
          """
          if save_dir:
              os.makedirs(save_dir, exist_ok=True)
          tracker = self.model.tracker
          tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)

          timer = MOTTimer()
          frame_id = 0
          self.status['mode'] = 'track'
          self.model.eval()
          results = defaultdict(list)  # support single class and multi classes

          for step_id, data in enumerate(tqdm(dataloader)):
              self.status['step_id'] = step_id
              # forward
              timer.tic()
              pred_dets, pred_embs = self.model(data)
              pred_dets, pred_embs = pred_dets.numpy(), pred_embs.numpy()
              online_targets_dict = tracker.update(pred_dets, pred_embs)
              online_tlwhs, online_scores, online_ids = \
                  self._collect_jde_targets(online_targets_dict, tracker,
                                            self.cfg.num_classes, frame_id,
                                            results)
              timer.toc()
              save_vis_results(data, frame_id, online_ids, online_tlwhs,
                               online_scores, timer.average_time, show_image,
                               save_dir, self.cfg.num_classes)
              frame_id += 1

          return results, frame_id, timer.average_time, timer.calls

      def _eval_seq_sde(self,
                        dataloader,
                        save_dir=None,
                        show_image=False,
                        frame_rate=30,
                        seq_name='',
                        scaled=False,
                        det_file='',
                        draw_threshold=0):
          """Track one sequence with an SDE-style model (DeepSORT, ByteTrack).

          Detections come from ``self.model.detector`` when it exists.
          Otherwise they are read from ``det_file``. When ``self.model.reid``
          exists, embeddings are computed from image crops of the clipped
          detections.

          Args:
              dataloader: iterable of input batches for one sequence. Each
                  batch provides 'ori_image', 'image', 'im_shape' and
                  'scale_factor'.
              save_dir (str|None): visualization output directory; created
                  if missing.
              show_image (bool): forwarded to ``save_vis_results``.
              frame_rate (int): accepted for interface parity; not used by
                  this path.
              seq_name (str): sequence name, used to select MOTChallenge
                  trick hyperparameters for ``JDETracker``.
              scaled (bool): whether detector boxes are already in original
                  image coordinates. If False they are rescaled with
                  ``scale_coords``.
              det_file (str): detection results file, used only when there
                  is no detector.
              draw_threshold (float): ``DeepSORTTracker`` tracks scoring
                  below this value are left out of the results.

          Returns:
              tuple: ``(results, frame_count, average_time, timer_calls)``.
              ``results`` maps class id to a list of per-frame
              ``(frame_id, tlwhs, scores, track_ids)``, with 1-based
              ``frame_id``.
          """
          if save_dir:
              os.makedirs(save_dir, exist_ok=True)
          use_detector = bool(self.model.detector)
          use_reid = bool(self.model.reid)

          timer = MOTTimer()
          results = defaultdict(list)
          frame_id = 0
          self.status['mode'] = 'track'
          self.model.eval()
          if use_reid:
              self.model.reid.eval()
          if not use_detector:
              dets_list = load_det_results(det_file, len(dataloader))
              logger.info('Finish loading detection results file {}.'.format(
                  det_file))

          tracker = self.model.tracker
          is_jde_tracker = isinstance(tracker, JDETracker)
          if is_jde_tracker:
              # self.model.tracker is reused for every sequence, so keep the
              # configured values: the per-sequence tricks are derived from
              # them and they are restored once this sequence is done.
              ori_track_buffer = tracker.track_buffer
              ori_conf_thres = tracker.conf_thres

          for step_id, data in enumerate(tqdm(dataloader)):
              self.status['step_id'] = step_id
              ori_image = data['ori_image']  # [bs, H, W, 3]
              ori_image_shape = data['ori_image'].shape[1:3]  # [H, W]
              # input_shape: [h, w], before data transforms, set in model config
              input_shape = data['image'].shape[2:]
              # im_shape: [new_h, new_w], after data transforms
              im_shape = data['im_shape'][0].numpy()
              scale_factor = data['scale_factor'][0].numpy()

              # when it has no detected bboxes, will not inference reid model
              # and if visualize, use original image instead
              empty_detections = False

              # forward
              timer.tic()
              if not use_detector:
                  dets = dets_list[frame_id]
                  bbox_tlwh = np.array(dets['bbox'], dtype='float32')
                  if bbox_tlwh.shape[0] > 0:
                      # detector outputs: pred_cls_ids, pred_scores, pred_bboxes
                      pred_cls_ids = np.array(dets['cls_id'], dtype='float32')
                      pred_scores = np.array(dets['score'], dtype='float32')
                      # [x, y, w, h] -> [x1, y1, x2, y2]
                      pred_bboxes = np.concatenate(
                          (bbox_tlwh[:, 0:2],
                           bbox_tlwh[:, 2:4] + bbox_tlwh[:, 0:2]),
                          axis=1)
                  else:
                      logger.warning(
                          'Frame {} has not object, try to modify score threshold.'.
                          format(frame_id))
                      empty_detections = True
              else:
                  outs = self.model.detector(data)
                  outs['bbox'] = outs['bbox'].numpy()
                  outs['bbox_num'] = outs['bbox_num'].numpy()
                  if len(outs['bbox']) > 0:
                      # detector outputs: pred_cls_ids, pred_scores, pred_bboxes
                      pred_cls_ids = outs['bbox'][:, 0:1]
                      pred_scores = outs['bbox'][:, 1:2]
                      if not scaled:
                          # Note: scaled=False only in JDE YOLOv3 or other
                          # detectors with LetterBoxResize and
                          # JDEBBoxPostProcess.
                          #
                          # 'scaled' means whether the coords after detector
                          # outputs have been scaled back to the original
                          # image, set True in general detector, set False in
                          # JDE YOLOv3.
                          pred_bboxes = scale_coords(outs['bbox'][:, 2:],
                                                     input_shape, im_shape,
                                                     scale_factor)
                      else:
                          pred_bboxes = outs['bbox'][:, 2:]
                  else:
                      logger.warning(
                          'Frame {} has not detected object, try to modify score threshold.'.
                          format(frame_id))
                      empty_detections = True

              if not empty_detections:
                  # Unclipped detections, consumed by JDETracker below. Built
                  # for both detection sources so the detection-file path does
                  # not hit an undefined name.
                  pred_dets_old = np.concatenate(
                      (pred_cls_ids, pred_scores, pred_bboxes), axis=1)
                  pred_xyxys, keep_idx = clip_box(pred_bboxes, ori_image_shape)
                  if len(keep_idx[0]) == 0:
                      logger.warning(
                          'Frame {} has not detected object left after clip_box.'.
                          format(frame_id))
                      empty_detections = True

              if empty_detections:
                  timer.toc()
                  # if visualize, use original image instead
                  online_ids, online_tlwhs, online_scores = None, None, None
                  save_vis_results(data, frame_id, online_ids, online_tlwhs,
                                   online_scores, timer.average_time,
                                   show_image, save_dir, self.cfg.num_classes)
                  frame_id += 1
                  # thus will not inference reid model
                  continue

              pred_cls_ids = pred_cls_ids[keep_idx[0]]
              pred_scores = pred_scores[keep_idx[0]]
              pred_dets = np.concatenate(
                  (pred_cls_ids, pred_scores, pred_xyxys), axis=1)

              if use_reid:
                  crops = get_crops(
                      pred_xyxys,
                      ori_image,
                      w=tracker.input_size[0],
                      h=tracker.input_size[1])
                  crops = paddle.to_tensor(crops)
                  data.update({'crops': crops})
                  pred_embs = self.model(data)['embeddings'].numpy()
              else:
                  pred_embs = None

              if isinstance(tracker, DeepSORTTracker):
                  online_tlwhs, online_scores, online_ids = [], [], []
                  tracker.predict()
                  online_targets = tracker.update(pred_dets, pred_embs)
                  for t in online_targets:
                      if not t.is_confirmed() or t.time_since_update > 1:
                          continue
                      tlwh = t.to_tlwh()
                      tscore = t.score
                      tid = t.track_id
                      if tscore < draw_threshold:
                          continue
                      if tlwh[2] * tlwh[3] <= tracker.min_box_area:
                          continue
                      if (tracker.vertical_ratio > 0 and
                              tlwh[2] / tlwh[3] > tracker.vertical_ratio):
                          continue
                      online_tlwhs.append(tlwh)
                      online_scores.append(tscore)
                      online_ids.append(tid)
                  timer.toc()

                  # save results
                  results[0].append(
                      (frame_id + 1, online_tlwhs, online_scores, online_ids))
                  save_vis_results(data, frame_id, online_ids, online_tlwhs,
                                   online_scores, timer.average_time,
                                   show_image, save_dir, self.cfg.num_classes)

              elif is_jde_tracker:
                  # trick hyperparams only used for MOTChallenge (MOT17, MOT20)
                  # Test-set; always derived from the configured values
                  tracker.track_buffer, tracker.conf_thres = \
                      get_trick_hyperparams(seq_name, ori_track_buffer,
                                            ori_conf_thres)
                  online_targets_dict = tracker.update(pred_dets_old, pred_embs)
                  online_tlwhs, online_scores, online_ids = \
                      self._collect_jde_targets(online_targets_dict, tracker,
                                                self.cfg.num_classes,
                                                frame_id, results)
                  timer.toc()
                  save_vis_results(data, frame_id, online_ids, online_tlwhs,
                                   online_scores, timer.average_time,
                                   show_image, save_dir, self.cfg.num_classes)

              frame_id += 1

          if is_jde_tracker:
              # do not leak this sequence's tricks into the next sequence
              tracker.track_buffer = ori_track_buffer
              tracker.conf_thres = ori_conf_thres

          return results, frame_id, timer.average_time, timer.calls

      def mot_evaluate(self,
                       data_root,
                       seqs,
                       output_dir,
                       data_type='mot',
                       model_type='JDE',
                       save_images=False,
                       save_videos=False,
                       show_image=False,
                       scaled=False,
                       det_results_dir=''):
          """Track every sequence in ``seqs`` and evaluate the results.

          Per-sequence results are written to
          ``<output_dir>/mot_results/<seq>.txt`` and fed to every registered
          metric. The accumulated metrics are logged and then reset.

          Args:
              data_root (str): directory containing one sub-directory per
                  sequence. Images are read from ``<seq>/img1`` when that
                  exists, else from ``<seq>``.
              seqs (list[str]): sequence names. Missing ones are skipped with
                  a warning.
              output_dir (str): root output directory.
              data_type (str): one of ``MOT_DATA_TYPE``.
              model_type (str): one of ``MOT_ARCH``.
              save_images (bool): save visualized frames.
              save_videos (bool): save visualized frames and encode them into
                  a video with ffmpeg.
              show_image (bool): forwarded to the visualization.
              scaled (bool): see ``_eval_seq_sde``.
              det_results_dir (str): directory of ``<seq>.txt`` detection
                  files, used by SDE models without a detector.
          """
          os.makedirs(output_dir, exist_ok=True)
          result_root = os.path.join(output_dir, 'mot_results')
          os.makedirs(result_root, exist_ok=True)
          assert data_type in MOT_DATA_TYPE, \
              "data_type should be 'mot', 'mcmot' or 'kitti'"
          assert model_type in MOT_ARCH, \
              "model_type should be 'JDE', 'DeepSORT', 'FairMOT' or 'ByteTrack'"

          # run tracking
          n_frame = 0
          timer_avgs, timer_calls = [], []
          for seq in seqs:
              infer_dir = os.path.join(data_root, seq)
              if not os.path.isdir(infer_dir):
                  logger.warning("Seq {} error, {} has no images.".format(
                      seq, infer_dir))
                  continue
              if os.path.exists(os.path.join(infer_dir, 'img1')):
                  infer_dir = os.path.join(infer_dir, 'img1')

              frame_rate = self._read_frame_rate(
                  os.path.join(data_root, seq, 'seqinfo.ini'))

              save_dir = os.path.join(output_dir, 'mot_outputs',
                                      seq) if save_images or save_videos else None

              logger.info('Evaluate seq: {}'.format(seq))

              self.dataset.set_images(self.get_infer_images(infer_dir))
              dataloader = create('EvalMOTReader')(self.dataset, 0)

              result_filename = os.path.join(result_root, '{}.txt'.format(seq))

              with paddle.no_grad():
                  if model_type in MOT_ARCH_JDE:
                      results, nf, ta, tc = self._eval_seq_jde(
                          dataloader,
                          save_dir=save_dir,
                          show_image=show_image,
                          frame_rate=frame_rate)
                  elif model_type in MOT_ARCH_SDE:
                      results, nf, ta, tc = self._eval_seq_sde(
                          dataloader,
                          save_dir=save_dir,
                          show_image=show_image,
                          frame_rate=frame_rate,
                          seq_name=seq,
                          scaled=scaled,
                          det_file=os.path.join(det_results_dir,
                                                '{}.txt'.format(seq)))
                  else:
                      raise ValueError(model_type)

              write_mot_results(result_filename, results, data_type,
                                self.cfg.num_classes)

              n_frame += nf
              timer_avgs.append(ta)
              timer_calls.append(tc)

              if save_videos:
                  self._save_video(save_dir, seq)

              # update metrics
              for metric in self._metrics:
                  metric.update(data_root, seq, data_type, result_root,
                                result_filename)

          timer_avgs = np.asarray(timer_avgs)
          timer_calls = np.asarray(timer_calls)
          all_time = np.dot(timer_avgs, timer_calls)
          total_calls = np.sum(timer_calls)
          if total_calls > 0:
              avg_time = all_time / total_calls
              logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(
                  all_time, 1.0 / avg_time))
          else:
              logger.warning('No frame was tracked, skip timing summary.')

          # accumulate metric to log out
          for metric in self._metrics:
              metric.accumulate()
              metric.log()
          # reset metric states for metric may performed multiple times
          self._reset_metrics()

      def get_infer_images(self, infer_dir):
          """Return the sorted image paths (jpg/jpeg/png/bmp, any case) directly
          inside ``infer_dir``.

          Raises:
              AssertionError: if ``infer_dir`` is not a directory or contains
                  no image.
          """
          assert infer_dir is not None and os.path.isdir(infer_dir), \
              "infer_dir {} is not a directory".format(infer_dir)
          exts = ['jpg', 'jpeg', 'png', 'bmp']
          exts += [ext.upper() for ext in exts]
          # a set, because on case-insensitive file systems the lower- and
          # upper-case patterns can match the same file
          images = set()
          for ext in exts:
              images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
          images = sorted(images)

          assert len(images) > 0, "no image found in {}".format(infer_dir)
          logger.info("Found {} inference images in total.".format(len(images)))
          return images

      def mot_predict_seq(self,
                          video_file,
                          frame_rate,
                          image_dir,
                          output_dir,
                          data_type='mot',
                          model_type='JDE',
                          save_images=False,
                          save_videos=True,
                          show_image=False,
                          scaled=False,
                          det_results_dir='',
                          draw_threshold=0.5):
          """Track a single video file or image folder and save the results.

          Args:
              video_file (str|None): video to track; takes precedence over
                  ``image_dir``.
              frame_rate (int): frame rate passed to the dataset. -1 means
                  use ``self.dataset.frame_rate``.
              image_dir (str|None): image folder to track. Its ``img1``
                  sub-folder is used when present.
              output_dir (str): root output directory. Results go to
                  ``<output_dir>/mot_results/<seq>.txt``.
              data_type (str): one of ``MOT_DATA_TYPE``.
              model_type (str): one of ``MOT_ARCH``.
              save_images (bool): save visualized frames.
              save_videos (bool): save visualized frames and encode them into
                  a video with ffmpeg.
              show_image (bool): forwarded to the visualization.
              scaled (bool): see ``_eval_seq_sde``.
              det_results_dir (str): directory of ``<seq>.txt`` detection
                  files, used by SDE models without a detector.
              draw_threshold (float): forwarded to the tracking routine.
          """
          assert video_file is not None or image_dir is not None, \
              "--video_file or --image_dir should be set."
          assert video_file is None or os.path.isfile(video_file), \
                  "{} is not a file".format(video_file)
          assert image_dir is None or os.path.isdir(image_dir), \
                  "{} is not a directory".format(image_dir)

          os.makedirs(output_dir, exist_ok=True)
          result_root = os.path.join(output_dir, 'mot_results')
          os.makedirs(result_root, exist_ok=True)
          assert data_type in MOT_DATA_TYPE, \
              "data_type should be 'mot', 'mcmot' or 'kitti'"
          assert model_type in MOT_ARCH, \
              "model_type should be 'JDE', 'DeepSORT', 'FairMOT' or 'ByteTrack'"

          # run tracking
          if video_file:
              seq = os.path.basename(video_file).split('.')[0]
              self.dataset.set_video(video_file, frame_rate)
              logger.info('Starting tracking video {}'.format(video_file))
          elif image_dir:
              # normpath drops a trailing separator that would otherwise
              # produce an empty sequence name
              seq = os.path.basename(os.path.normpath(image_dir)).split('.')[0]
              if os.path.exists(os.path.join(image_dir, 'img1')):
                  image_dir = os.path.join(image_dir, 'img1')
              images = self.get_infer_images(image_dir)
              self.dataset.set_images(images)
              logger.info('Starting tracking folder {}, found {} images'.format(
                  image_dir, len(images)))
          else:
              raise ValueError('--video_file or --image_dir should be set.')

          save_dir = os.path.join(output_dir, 'mot_outputs',
                                  seq) if save_images or save_videos else None

          dataloader = create('TestMOTReader')(self.dataset, 0)
          result_filename = os.path.join(result_root, '{}.txt'.format(seq))
          if frame_rate == -1:
              frame_rate = self.dataset.frame_rate

          with paddle.no_grad():
              if model_type in MOT_ARCH_JDE:
                  results, nf, ta, tc = self._eval_seq_jde(
                      dataloader,
                      save_dir=save_dir,
                      show_image=show_image,
                      frame_rate=frame_rate,
                      draw_threshold=draw_threshold)
              elif model_type in MOT_ARCH_SDE:
                  results, nf, ta, tc = self._eval_seq_sde(
                      dataloader,
                      save_dir=save_dir,
                      show_image=show_image,
                      frame_rate=frame_rate,
                      seq_name=seq,
                      scaled=scaled,
                      det_file=os.path.join(det_results_dir,
                                            '{}.txt'.format(seq)),
                      draw_threshold=draw_threshold)
              else:
                  raise ValueError(model_type)

          if save_videos:
              self._save_video(save_dir, seq)

          write_mot_results(result_filename, results, data_type,
                            self.cfg.num_classes)
  498. cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
  499. save_dir, output_video_path)
  500. os.system(cmd_str)
  501. logger.info('Save video in {}'.format(output_video_path))
  502. write_mot_results(result_filename, results, data_type,
  503. self.cfg.num_classes)
  504. def get_trick_hyperparams(video_name, ori_buffer, ori_thresh):
  505. if video_name[:3] != 'MOT':
  506. # only used for MOTChallenge (MOT17, MOT20) Test-set
  507. return ori_buffer, ori_thresh
  508. video_name = video_name[:8]
  509. if 'MOT17-05' in video_name:
  510. track_buffer = 14
  511. elif 'MOT17-13' in video_name:
  512. track_buffer = 25
  513. else:
  514. track_buffer = ori_buffer
  515. if 'MOT17-01' in video_name:
  516. track_thresh = 0.65
  517. elif 'MOT17-06' in video_name:
  518. track_thresh = 0.65
  519. elif 'MOT17-12' in video_name:
  520. track_thresh = 0.7
  521. elif 'MOT17-14' in video_name:
  522. track_thresh = 0.67
  523. else:
  524. track_thresh = ori_thresh
  525. if 'MOT20-06' in video_name or 'MOT20-08' in video_name:
  526. track_thresh = 0.3
  527. else:
  528. track_thresh = ori_thresh
  529. return track_buffer, ori_thresh