# track_half.py
import argparse
import logging
import os
import os.path as osp
import re
import subprocess

import cv2
import motmetrics as mm
import torch

#from tracker.multitracker import JDETracker
from tracker.byte_tracker import BYTETracker
from utils import visualization as vis
from dev.utils.log import logger
from dev.utils.timer import Timer
from utils.evaluation import Evaluator
from utils.parse_config import parse_model_cfg
import utils.datasets as datasets
from utils.utils import *


  17. def write_results(filename, results, data_type):
  18. if data_type == 'mot':
  19. save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
  20. elif data_type == 'kitti':
  21. save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
  22. else:
  23. raise ValueError(data_type)
  24. with open(filename, 'w') as f:
  25. for frame_id, tlwhs, track_ids in results:
  26. if data_type == 'kitti':
  27. frame_id -= 1
  28. for tlwh, track_id in zip(tlwhs, track_ids):
  29. if track_id < 0:
  30. continue
  31. x1, y1, w, h = tlwh
  32. x2, y2 = x1 + w, y1 + h
  33. line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
  34. f.write(line)
  35. logger.info('save results to {}'.format(filename))
  36. def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
  37. '''
  38. Processes the video sequence given and provides the output of tracking result (write the results in video file)
  39. It uses JDE model for getting information about the online targets present.
  40. Parameters
  41. ----------
  42. opt : Namespace
  43. Contains information passed as commandline arguments.
  44. dataloader : LoadVideo
  45. Instance of LoadVideo class used for fetching the image sequence and associated data.
  46. data_type : String
  47. Type of dataset corresponding(similar) to the given video.
  48. result_filename : String
  49. The name(path) of the file for storing results.
  50. save_dir : String
  51. Path to the folder for storing the frames containing bounding box information (Result frames).
  52. show_image : bool
  53. Option for shhowing individial frames during run-time.
  54. frame_rate : int
  55. Frame-rate of the given video.
  56. Returns
  57. -------
  58. (Returns are not significant here)
  59. frame_id : int
  60. Sequence number of the last sequence
  61. '''
  62. if save_dir:
  63. mkdir_if_missing(save_dir)
  64. tracker = BYTETracker(opt, frame_rate=frame_rate)
  65. timer = Timer()
  66. results = []
  67. len_all = len(dataloader)
  68. start_frame = int(len_all / 2)
  69. frame_id = int(len_all / 2)
  70. for i, (path, img, img0) in enumerate(dataloader):
  71. if i < start_frame:
  72. continue
  73. if frame_id % 20 == 0:
  74. logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1./max(1e-5, timer.average_time)))
  75. # run tracking
  76. timer.tic()
  77. blob = torch.from_numpy(img).cuda().unsqueeze(0)
  78. online_targets = tracker.update(blob, img0)
  79. online_tlwhs = []
  80. online_ids = []
  81. for t in online_targets:
  82. tlwh = t.tlwh
  83. tid = t.track_id
  84. vertical = tlwh[2] / tlwh[3] > 1.6
  85. if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
  86. online_tlwhs.append(tlwh)
  87. online_ids.append(tid)
  88. timer.toc()
  89. # save results
  90. results.append((frame_id + 1, online_tlwhs, online_ids))
  91. if show_image or save_dir is not None:
  92. online_im = vis.plot_tracking(img0, online_tlwhs, online_ids, frame_id=frame_id,
  93. fps=1. / timer.average_time)
  94. if show_image:
  95. cv2.imshow('online_im', online_im)
  96. if save_dir is not None:
  97. cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
  98. frame_id += 1
  99. # save results
  100. write_results(result_filename, results, data_type)
  101. return frame_id, timer.average_time, timer.calls
  102. def main(opt, data_root='/data/MOT16/train', det_root=None, seqs=('MOT16-05',), exp_name='demo',
  103. save_images=False, save_videos=False, show_image=True):
  104. logger.setLevel(logging.INFO)
  105. result_root = os.path.join(data_root, '..', 'results', exp_name)
  106. mkdir_if_missing(result_root)
  107. data_type = 'mot'
  108. # Read config_files
  109. cfg_dict = parse_model_cfg(opt.cfg)
  110. opt.img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]
  111. # run tracking
  112. accs = []
  113. n_frame = 0
  114. timer_avgs, timer_calls = [], []
  115. for seq in seqs:
  116. output_dir = os.path.join(data_root, '..','outputs', exp_name, seq) if save_images or save_videos else None
  117. logger.info('start seq: {}'.format(seq))
  118. dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'), opt.img_size)
  119. result_filename = os.path.join(result_root, '{}.txt'.format(seq))
  120. meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
  121. frame_rate = int(meta_info[meta_info.find('frameRate')+10:meta_info.find('\nseqLength')])
  122. nf, ta, tc = eval_seq(opt, dataloader, data_type, result_filename,
  123. save_dir=output_dir, show_image=show_image, frame_rate=frame_rate)
  124. n_frame += nf
  125. timer_avgs.append(ta)
  126. timer_calls.append(tc)
  127. # eval
  128. logger.info('Evaluate seq: {}'.format(seq))
  129. evaluator = Evaluator(data_root, seq, data_type)
  130. accs.append(evaluator.eval_file(result_filename))
  131. if save_videos:
  132. output_video_path = osp.join(output_dir, '{}.mp4'.format(seq))
  133. cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(output_dir, output_video_path)
  134. os.system(cmd_str)
  135. timer_avgs = np.asarray(timer_avgs)
  136. timer_calls = np.asarray(timer_calls)
  137. all_time = np.dot(timer_avgs, timer_calls)
  138. avg_time = all_time / np.sum(timer_calls)
  139. logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(all_time, 1.0 / avg_time))
  140. # get summary
  141. metrics = mm.metrics.motchallenge_metrics
  142. mh = mm.metrics.create()
  143. summary = Evaluator.get_summary(accs, seqs, metrics)
  144. strsummary = mm.io.render_summary(
  145. summary,
  146. formatters=mh.formatters,
  147. namemap=mm.io.motchallenge_metric_names
  148. )
  149. print(strsummary)
  150. Evaluator.save_summary(summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
  151. if __name__ == '__main__':
  152. parser = argparse.ArgumentParser(prog='track.py')
  153. parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
  154. parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
  155. parser.add_argument('--iou-thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
  156. parser.add_argument('--conf-thres', type=float, default=0.7, help='object confidence threshold')
  157. parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
  158. parser.add_argument('--min-box-area', type=float, default=200, help='filter out tiny boxes')
  159. parser.add_argument('--track-buffer', type=int, default=30, help='tracking buffer')
  160. parser.add_argument('--test-mot16', action='store_true', help='tracking buffer')
  161. parser.add_argument('--val-mot17', default=True, help='validation on MOT17')
  162. parser.add_argument('--save-images', action='store_true', help='save tracking results (image)')
  163. parser.add_argument('--save-videos', action='store_true', help='save tracking results (video)')
  164. opt = parser.parse_args()
  165. print(opt, end='\n\n')
  166. if not opt.test_mot16:
  167. seqs_str = '''MOT17-02-SDP
  168. MOT17-04-SDP
  169. MOT17-05-SDP
  170. MOT17-09-SDP
  171. MOT17-10-SDP
  172. MOT17-11-SDP
  173. MOT17-13-SDP
  174. '''
  175. #seqs_str = '''MOT17-02-SDP'''
  176. data_root = '/opt/tiger/demo/datasets/MOT17/images/train'
  177. else:
  178. seqs_str = '''MOT16-01
  179. MOT16-03
  180. MOT16-06
  181. MOT16-07
  182. MOT16-08
  183. MOT16-12
  184. MOT16-14'''
  185. data_root = '/home/wangzd/datasets/MOT/MOT16/images/test'
  186. seqs = [seq.strip() for seq in seqs_str.split()]
  187. main(opt,
  188. data_root=data_root,
  189. seqs=seqs,
  190. exp_name=opt.weights.split('/')[-2],
  191. show_image=False,
  192. save_images=opt.save_images,
  193. save_videos=opt.save_videos)