demo_track.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. from loguru import logger
  2. import cv2
  3. import torch
  4. from yolox.data.data_augment import preproc
  5. from yolox.exp import get_exp
  6. from yolox.utils import fuse_model, get_model_info, postprocess, vis
  7. from yolox.utils.visualize import plot_tracking
  8. from yolox.tracker.byte_tracker import BYTETracker
  9. from yolox.tracking_utils.timer import Timer
  10. import argparse
  11. import os
  12. import time
  # Recognised image file extensions (leading dot included); get_image_list
  # keeps only files whose extension matches an entry of this list.
  IMAGE_EXT = [
      ".jpg",
      ".jpeg",
      ".webp",
      ".bmp",
      ".png",
  ]
  def make_parser():
      """Build the command-line parser for the ByteTrack demo.

      Returns:
          argparse.ArgumentParser: parser covering the demo mode and input
          source, the YOLOX experiment / checkpoint / inference options and the
          BYTETracker hyper-parameters.
      """
      parser = argparse.ArgumentParser("ByteTrack Demo!")
      # nargs="?" makes the positional optional so its default ("image") can
      # actually apply; a plain positional argument is always required.
      parser.add_argument(
          "demo", nargs="?", default="image", help="demo type, eg. image, video and webcam"
      )
      parser.add_argument("-expn", "--experiment-name", type=str, default=None)
      parser.add_argument("-n", "--name", type=str, default=None, help="model name")
      parser.add_argument(
          "--path", default="./videos/palace.mp4", help="path to images or video"
      )
      parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
      parser.add_argument(
          "--save_result",
          action="store_true",
          help="whether to save the inference result of image/video",
      )

      # exp file
      parser.add_argument(
          "-f",
          "--exp_file",
          default=None,
          type=str,
          help="pls input your expriment description file",
      )
      parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
      parser.add_argument(
          "--device",
          default="gpu",
          type=str,
          help="device to run our model, can either be cpu or gpu",
      )
      parser.add_argument("--conf", default=None, type=float, help="test conf")
      parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
      parser.add_argument("--tsize", default=None, type=int, help="test img size")
      parser.add_argument(
          "--fp16",
          dest="fp16",
          default=False,
          action="store_true",
          help="Adopting mix precision evaluating.",
      )
      parser.add_argument(
          "--fuse",
          dest="fuse",
          default=False,
          action="store_true",
          help="Fuse conv and bn for testing.",
      )
      parser.add_argument(
          "--trt",
          dest="trt",
          default=False,
          action="store_true",
          help="Using TensorRT model for testing.",
      )

      # tracking args
      parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
      parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
      # float, not int: the threshold is fractional (default 0.8), and type=int
      # made argparse reject any value such as "0.9" given on the command line.
      parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
      parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
      parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
      return parser
  def get_image_list(path, exts=None):
      """Recursively collect image files below ``path``.

      Args:
          path: Root directory to walk; sub-directories are included.
          exts: Optional iterable of accepted extensions (leading dot included).
              Defaults to the module-level ``IMAGE_EXT``.

      Returns:
          list[str]: Paths (``os.path.join`` of the walked directory and file
          name) of every file whose extension matches, compared
          case-insensitively so that e.g. ``IMG_01.JPG`` is found too. The order
          follows ``os.walk`` and is not sorted.
      """
      if exts is None:
          exts = IMAGE_EXT
      # Lower-case both sides once so ".JPG" / ".Png" files are not skipped.
      wanted = {ext.lower() for ext in exts}

      image_names = []
      for maindir, _subdirs, file_name_list in os.walk(path):
          for filename in file_name_list:
              if os.path.splitext(filename)[1].lower() in wanted:
                  image_names.append(os.path.join(maindir, filename))
      return image_names
  def write_results(filename, results):
      """Dump tracking results to a comma-separated text file (overwritten).

      Args:
          filename: Destination path.
          results: Iterable of ``(frame_id, tlwhs, track_ids, scores)`` tuples;
              ``tlwhs`` holds ``(x1, y1, w, h)`` boxes aligned with
              ``track_ids`` and ``scores``.

      Each kept box becomes one line ``frame,id,x1,y1,w,h,score,-1,-1,-1`` with
      coordinates rounded to 1 decimal and the score to 2. Boxes whose track id
      is negative are skipped.
      """
      row_template = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'

      def _rows():
          # Lazily format one output line per kept box, in input order.
          for frame_id, tlwhs, track_ids, scores in results:
              for box, track_id, score in zip(tlwhs, track_ids, scores):
                  if track_id < 0:
                      continue
                  left, top, box_w, box_h = box
                  yield row_template.format(
                      frame=frame_id,
                      id=track_id,
                      x1=round(left, 1),
                      y1=round(top, 1),
                      w=round(box_w, 1),
                      h=round(box_h, 1),
                      s=round(score, 2),
                  )

      with open(filename, 'w') as out_file:
          out_file.writelines(_rows())
      logger.info('save results to {}'.format(filename))
  class Predictor(object):
      """Run a YOLOX detector (optionally as a TensorRT engine) on single images.

      Args:
          model: YOLOX model; the caller has already placed it on the target
              device / precision.
          exp: YOLOX experiment; ``num_classes``, ``test_conf``, ``nmsthre`` and
              ``test_size`` are read from it.
          trt_file: Optional path to a ``torch2trt`` ``TRTModule`` state dict.
              When given, ``model`` is called once on a dummy CUDA input and then
              replaced by the TRT module.
          decoder: Optional callable applied to the raw model output before NMS
              (used in TensorRT mode).
          device: ``"gpu"`` moves inputs to CUDA; any other value keeps them on CPU.
          fp16: Cast GPU inputs to half precision.
      """

      def __init__(
          self,
          model,
          exp,
          trt_file=None,
          decoder=None,
          device="cpu",
          fp16=False
      ):
          self.model = model
          self.decoder = decoder
          self.num_classes = exp.num_classes
          self.confthre = exp.test_conf
          self.nmsthre = exp.nmsthre
          self.test_size = exp.test_size
          self.device = device
          self.fp16 = fp16
          if trt_file is not None:
              from torch2trt import TRTModule

              model_trt = TRTModule()
              model_trt.load_state_dict(torch.load(trt_file))

              x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
              # Warm-up pass through the PyTorch model before it is swapped out;
              # presumably this lets the head record the feature-map sizes that
              # the decoder needs later -- TODO confirm.
              self.model(x)
              self.model = model_trt
          # Per-channel mean / std handed to preproc for input normalisation.
          self.rgb_means = (0.485, 0.456, 0.406)
          self.std = (0.229, 0.224, 0.225)

      def inference(self, img, timer):
          """Detect objects in one image.

          Args:
              img: Image file path (loaded with ``cv2.imread``) or an
                  already-decoded image array (presumably BGR as produced by
                  OpenCV -- confirm with callers).
              timer: Timer whose ``tic()`` is called right before the forward
                  pass; the caller must call the matching ``toc()``.

          Returns:
              tuple: ``(outputs, img_info)``. ``outputs`` is what
              ``postprocess`` returns for the batch of one; ``img_info`` holds
              ``id``, ``file_name``, ``height``, ``width``, ``raw_img`` and the
              resize ``ratio`` returned by ``preproc``.

          Raises:
              ValueError: if ``img`` is a path that OpenCV cannot read.
          """
          img_info = {"id": 0}
          if isinstance(img, str):
              image_path = img
              img_info["file_name"] = os.path.basename(image_path)
              img = cv2.imread(image_path)
              # cv2.imread reports failure by returning None instead of raising;
              # fail here with a clear message rather than on `img.shape` below.
              if img is None:
                  raise ValueError("Could not read image: {}".format(image_path))
          else:
              img_info["file_name"] = None

          height, width = img.shape[:2]
          img_info["height"] = height
          img_info["width"] = width
          img_info["raw_img"] = img

          img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
          img_info["ratio"] = ratio
          # Add the batch dimension and convert to float32.
          img = torch.from_numpy(img).unsqueeze(0).float()
          if self.device == "gpu":
              img = img.cuda()
              if self.fp16:
                  img = img.half()  # to FP16

          with torch.no_grad():
              timer.tic()
              outputs = self.model(img)
              if self.decoder is not None:
                  outputs = self.decoder(outputs, dtype=outputs.type())
              outputs = postprocess(
                  outputs, self.num_classes, self.confthre, self.nmsthre
              )
          return outputs, img_info
  def image_demo(predictor, vis_folder, path, current_time, save_result):
      """Run detection + BYTE tracking over one image or a folder of images.

      Images are processed in sorted path order as consecutive frames.

      Args:
          predictor: ``Predictor`` used for per-image inference.
          vis_folder: Root folder for annotated output; only used when
              ``save_result`` is true.
          path: Image file, or directory searched recursively for images.
          current_time: ``time.struct_time`` used to name the output sub-folder.
          save_result: Write every annotated frame to disk.

      Note:
          Reads the module-level ``args`` (parsed CLI options) for the tracker
          settings and ``min_box_area``.
      """
      if os.path.isdir(path):
          files = get_image_list(path)
      else:
          files = [path]
      files.sort()

      # The output folder does not change between frames; create it once.
      save_folder = None
      if save_result:
          save_folder = os.path.join(
              vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
          )
          os.makedirs(save_folder, exist_ok=True)

      tracker = BYTETracker(args, frame_rate=30)
      timer = Timer()
      for frame_id, image_name in enumerate(files):
          if frame_id % 20 == 0:
              logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
          outputs, img_info = predictor.inference(image_name, timer)

          # YOLOX's postprocess leaves the entry as None when no detection
          # survives filtering; tracker.update would fail on None, so such
          # frames are shown unannotated instead of crashing the demo.
          if outputs[0] is not None:
              # predictor.test_size is the same value main() put in exp.test_size.
              online_targets = tracker.update(
                  outputs[0], [img_info['height'], img_info['width']], predictor.test_size
              )
              online_tlwhs = []
              online_ids = []
              for t in online_targets:
                  tlwh = t.tlwh
                  # Drop tiny boxes and boxes wider than 1.6x their height.
                  too_wide = tlwh[2] / tlwh[3] > 1.6
                  if tlwh[2] * tlwh[3] > args.min_box_area and not too_wide:
                      online_tlwhs.append(tlwh)
                      online_ids.append(t.track_id)
              timer.toc()
              online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                        fps=1. / timer.average_time)
          else:
              timer.toc()
              online_im = img_info['raw_img']

          if save_result:
              save_file_name = os.path.join(save_folder, os.path.basename(image_name))
              cv2.imwrite(save_file_name, online_im)

          # NOTE(review): no cv2.imshow window is opened here, so waitKey(0) may
          # return immediately depending on the HighGUI backend -- confirm intent.
          ch = cv2.waitKey(0)
          if ch == 27 or ch == ord("q") or ch == ord("Q"):
              break
  def imageflow_demo(predictor, vis_folder, current_time, args):
      """Run detection + BYTE tracking on a video file or a webcam stream.

      Each annotated frame is shown in a "demo" window; ESC, ``q`` or ``Q``
      stops early. With ``args.save_result`` the frames are also written to an
      mp4 file under ``vis_folder``.

      Args:
          predictor: ``Predictor`` used for per-frame inference.
          vis_folder: Root folder for the output video (used only when
              ``args.save_result`` is set).
          current_time: ``time.struct_time`` used to name the output sub-folder.
          args: Parsed CLI options (``demo``, ``path``, ``camid``,
              ``save_result``, tracker settings, ``min_box_area``).
      """
      cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
      width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
      height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
      # Output frame rate is fixed at 30; the stream's CAP_PROP_FPS is not consulted.
      fps = 30

      # Only create the output folder / writer when results are saved, so no
      # empty mp4 is left behind otherwise.
      vid_writer = None
      if args.save_result:
          save_folder = os.path.join(
              vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
          )
          os.makedirs(save_folder, exist_ok=True)
          if args.demo == "video":
              save_path = os.path.join(save_folder, os.path.basename(args.path))
          else:
              save_path = os.path.join(save_folder, "camera.mp4")
          logger.info(f"video save_path is {save_path}")
          vid_writer = cv2.VideoWriter(
              save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
          )

      tracker = BYTETracker(args, frame_rate=30)
      timer = Timer()
      frame_id = 0
      try:
          while True:
              if frame_id % 20 == 0:
                  logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
              ret_val, frame = cap.read()
              if not ret_val:
                  break

              outputs, img_info = predictor.inference(frame, timer)
              # YOLOX's postprocess leaves the entry as None when no detection
              # survives filtering; tracker.update would fail on None.
              if outputs[0] is not None:
                  # predictor.test_size is the same value main() put in exp.test_size.
                  online_targets = tracker.update(
                      outputs[0], [img_info['height'], img_info['width']], predictor.test_size
                  )
                  online_tlwhs = []
                  online_ids = []
                  for t in online_targets:
                      tlwh = t.tlwh
                      # Drop tiny boxes and boxes wider than 1.6x their height.
                      too_wide = tlwh[2] / tlwh[3] > 1.6
                      if tlwh[2] * tlwh[3] > args.min_box_area and not too_wide:
                          online_tlwhs.append(tlwh)
                          online_ids.append(t.track_id)
                  timer.toc()
                  online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                            fps=1. / timer.average_time)
              else:
                  timer.toc()
                  online_im = img_info['raw_img']

              cv2.imshow("demo", online_im)
              if vid_writer is not None:
                  vid_writer.write(online_im)
              ch = cv2.waitKey(1)
              if ch == 27 or ch == ord("q") or ch == ord("Q"):
                  break
              frame_id += 1
      finally:
          # Release the capture and writer even on early exit or error so the
          # camera is freed and the mp4 container gets finalised.
          cap.release()
          if vid_writer is not None:
              vid_writer.release()
  def main(exp, args):
      """Load the YOLOX model described by ``exp``/``args`` and run the chosen demo.

      Args:
          exp: YOLOX experiment object. ``test_conf``, ``nmsthre`` and
              ``test_size`` are overridden from ``--conf``/``--nms``/``--tsize``
              when given.
          args: Options parsed by ``make_parser()``. ``experiment_name`` and
              ``device`` may be updated in place.

      Raises:
          ValueError: for an unknown ``args.demo``, or ``--trt`` combined with
              ``--fuse``.
          FileNotFoundError: when ``--trt`` is set but no TensorRT engine exists.
      """
      # Validate option combinations before doing any expensive work.
      if args.demo not in ("image", "video", "webcam"):
          raise ValueError("Unsupported demo type: {!r}".format(args.demo))
      if args.trt:
          if args.fuse:
              raise ValueError("TensorRT model is not support model fusing!")
          args.device = "gpu"
      # Only touch CUDA when it is actually used, so --device cpu works on
      # machines without a GPU.
      if args.device == "gpu":
          torch.cuda.set_device(0)

      if not args.experiment_name:
          args.experiment_name = exp.exp_name

      file_name = os.path.join(exp.output_dir, args.experiment_name)
      os.makedirs(file_name, exist_ok=True)

      # Always bound (it is passed to the demo functions below); the directory
      # itself is only created when results are saved.
      vis_folder = os.path.join(file_name, "track_vis")
      if args.save_result:
          os.makedirs(vis_folder, exist_ok=True)

      logger.info("Args: {}".format(args))

      if args.conf is not None:
          exp.test_conf = args.conf
      if args.nms is not None:
          exp.nmsthre = args.nms
      if args.tsize is not None:
          exp.test_size = (args.tsize, args.tsize)

      model = exp.get_model()
      logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
      if args.device == "gpu":
          model.cuda()
      model.eval()

      if not args.trt:
          if args.ckpt is None:
              ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
          else:
              ckpt_file = args.ckpt
          logger.info("loading checkpoint")
          # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
          ckpt = torch.load(ckpt_file, map_location="cpu")
          # load the model state dict
          model.load_state_dict(ckpt["model"])
          logger.info("loaded checkpoint done.")

      if args.fuse:
          logger.info("\tFusing model...")
          model = fuse_model(model)

      if args.fp16:
          model = model.half()  # to FP16

      if args.trt:
          trt_file = os.path.join(file_name, "model_trt.pth")
          if not os.path.exists(trt_file):
              # Legacy fallback: the engine used to be looked up only under this
              # machine-specific checkout; keep it as a secondary location.
              legacy_trt_file = os.path.join("/data/humaocheng/monitor/ByteTrack", trt_file)
              if os.path.exists(legacy_trt_file):
                  trt_file = legacy_trt_file
          if not os.path.exists(trt_file):
              raise FileNotFoundError(
                  "TensorRT model is not found!\n Run python3 tools/trt.py first!"
              )
          # The TRT engine returns raw head outputs; decode them outside the engine.
          model.head.decode_in_inference = False
          decoder = model.head.decode_outputs
          logger.info("Using TensorRT to inference")
      else:
          trt_file = None
          decoder = None

      predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
      current_time = time.localtime()
      if args.demo == "image":
          image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
      else:  # "video" or "webcam", validated above
          imageflow_demo(predictor, vis_folder, current_time, args)
  if __name__ == "__main__":
      # Keep `args` (and `exp`) as module-level names: image_demo reads the
      # global `args` directly instead of receiving it as a parameter.
      cli_parser = make_parser()
      args = cli_parser.parse_args()
      exp = get_exp(args.exp_file, args.name)
      main(exp, args)