123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
- from loguru import logger
- import cv2
- import torch
- from yolox.data.data_augment import preproc
- from yolox.exp import get_exp
- from yolox.utils import fuse_model, get_model_info, postprocess, vis
- from yolox.utils.visualize import plot_tracking
- from yolox.tracker.byte_tracker import BYTETracker
- from yolox.tracking_utils.timer import Timer
- import argparse
- import os
- import time
# Lowercase file suffixes (leading dot included) that get_image_list accepts
# as images.
IMAGE_EXT = [
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".png",
]
def make_parser():
    """Build the command-line parser for the ByteTrack demo.

    Returns:
        argparse.ArgumentParser: parser exposing the demo type, model /
        checkpoint selection, inference options and tracking thresholds.
    """
    parser = argparse.ArgumentParser("ByteTrack Demo!")
    parser.add_argument(
        "demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # The path may also point at an image folder, e.g.
    # ./datasets/mot/train/MOT17-05-FRCNN/img1
    parser.add_argument(
        "--path", default="./videos/palace.mp4", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your expriment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )

    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    # Must be float: the default is 0.8, and with type=int argparse rejects
    # any fractional value given on the command line.
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser
def get_image_list(path):
    """Recursively collect the image files found under *path*.

    Args:
        path: directory to walk with ``os.walk``.

    Returns:
        list[str]: paths (directory joined with file name) of every file
        whose suffix is listed in ``IMAGE_EXT``, in ``os.walk`` order.
    """
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            # IMAGE_EXT holds lowercase suffixes; fold the case so that
            # files such as "frame.JPG" are not silently skipped.
            ext = os.path.splitext(filename)[1].lower()
            if ext in IMAGE_EXT:
                image_names.append(os.path.join(maindir, filename))
    return image_names
def write_results(filename, results):
    """Dump tracking results to a text file, one line per track per frame.

    Each line has the layout ``frame,id,x1,y1,w,h,score,-1,-1,-1``.
    Coordinates are rounded to one decimal and the score to two.

    Args:
        filename: output path; the file is opened in write mode.
        results: iterable of ``(frame_id, tlwhs, track_ids, scores)`` tuples
            whose three sequences are walked in lockstep.
    """
    row_template = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as out:
        for frame_id, boxes, ids, confidences in results:
            for box, ident, confidence in zip(boxes, ids, confidences):
                # Negative ids are skipped.
                if ident < 0:
                    continue
                left, top, box_w, box_h = box
                out.write(
                    row_template.format(
                        frame=frame_id,
                        id=ident,
                        x1=round(left, 1),
                        y1=round(top, 1),
                        w=round(box_w, 1),
                        h=round(box_h, 1),
                        s=round(confidence, 2),
                    )
                )
    logger.info('save results to {}'.format(filename))
class Predictor(object):
    """Run a detection model on single images with pre- and post-processing.

    Thresholds and the network input size are copied from the experiment
    object at construction time.
    """

    def __init__(
        self,
        model,
        exp,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False
    ):
        """Store the model and inference settings.

        Args:
            model: callable detection model.
            exp: experiment object providing ``num_classes``, ``test_conf``,
                ``nmsthre`` and ``test_size``.
            trt_file: optional path to a torch2trt state dict; when given,
                the TensorRT module replaces ``model`` for inference.
            decoder: optional callable applied to the raw model output.
            device: ``"gpu"`` moves inputs to CUDA; anything else leaves
                them on the CPU.
            fp16: when true, inputs are converted to half precision.
        """
        self.model = model
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        if trt_file is not None:
            # Imported lazily so torch2trt is only required for TensorRT runs.
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            # The original model is run once on a dummy CUDA batch before it
            # is swapped out. NOTE(review): presumably this initialises
            # state that the decoder relies on - confirm.
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, timer):
        """Detect objects in one image.

        Args:
            img: either a path to an image file or an already loaded image
                array.
            timer: object whose ``tic()`` is called right before the forward
                pass; stopping it is left to the caller.

        Returns:
            tuple: ``(outputs, img_info)`` where ``outputs`` is the result of
            ``postprocess`` and ``img_info`` is a dict with the keys ``id``,
            ``file_name``, ``height``, ``width``, ``raw_img`` and ``ratio``.

        Raises:
            OSError: if ``img`` is a path that ``cv2.imread`` cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            img_path = img
            img_info["file_name"] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None; fail here
                # with the offending path instead of on `img.shape` below.
                raise OSError("cv2.imread could not read image: {}".format(img_path))
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0).float()
        if self.device == "gpu":
            img = img.cuda()
        if self.fp16:
            img = img.half()  # to FP16

        with torch.no_grad():
            timer.tic()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
        return outputs, img_info
def image_demo(predictor, vis_folder, path, current_time, save_result):
    """Track objects through a single image or a directory of images.

    Images are processed in sorted path order. Each one is run through the
    predictor and the tracker, and the annotated frame is optionally written
    to disk.

    NOTE: tracker settings and ``min_box_area`` are read from the
    module-level ``args`` namespace, which only exists when this file is
    run as a script.

    Args:
        predictor: object with ``inference(image, timer)`` and ``test_size``.
        vis_folder: parent directory for saved visualisations.
        path: image file or directory of images.
        current_time: ``time.struct_time`` used to name the output folder.
        save_result: whether annotated frames are written to disk.
    """
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()

    save_folder = None
    if save_result:
        # The folder name is loop-invariant, so build and create it once.
        save_folder = os.path.join(
            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
        )
        os.makedirs(save_folder, exist_ok=True)

    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    results = []
    for frame_id, image_name in enumerate(files):
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        outputs, img_info = predictor.inference(image_name, timer)
        if outputs[0] is not None:
            # predictor.test_size is the value the predictor copied from the
            # experiment, so no module-level `exp` is needed here.
            online_targets = tracker.update(
                outputs[0], [img_info['height'], img_info['width']], predictor.test_size
            )
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                # Drop boxes wider than 1.6x their height and tiny boxes.
                vertical = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
            timer.toc()
            # save results
            results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                fps=1. / max(1e-5, timer.average_time)
            )
        else:
            # No detections in this image: skip the tracker, which cannot
            # handle a None input, and keep the unannotated frame.
            timer.toc()
            online_im = img_info['raw_img']

        if save_result:
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            cv2.imwrite(save_file_name, online_im)
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
    # NOTE(review): `results` is collected but never written; the original
    # had a commented-out write_results(result_filename, results) call here.
def imageflow_demo(predictor, vis_folder, current_time, args):
    """Track objects in a video file or webcam stream and display the result.

    Frames are read until the stream ends or the user presses Esc / q / Q in
    the display window. With ``args.save_result`` set, annotated frames are
    also written to an mp4 file inside a time-stamped folder under
    ``vis_folder``.

    Args:
        predictor: object with ``inference(frame, timer)`` and ``test_size``.
        vis_folder: parent directory for the saved video (only used when
            ``args.save_result`` is set).
        current_time: ``time.struct_time`` used to name the output folder.
        args: parsed options; reads ``demo``, ``path``, ``camid``,
            ``save_result`` and ``min_box_area``, and is passed to the
            tracker.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    # Output frame rate is fixed rather than read from CAP_PROP_FPS.
    fps = 30

    # Only touch the filesystem when the caller asked for the result to be
    # saved; otherwise no directory or empty video file is created.
    vid_writer = None
    if args.save_result:
        save_folder = os.path.join(
            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
        )
        os.makedirs(save_folder, exist_ok=True)
        if args.demo == "video":
            save_path = os.path.join(save_folder, os.path.basename(args.path))
        else:
            save_path = os.path.join(save_folder, "camera.mp4")
        logger.info(f"video save_path is {save_path}")
        vid_writer = cv2.VideoWriter(
            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
        )

    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    try:
        while True:
            if frame_id % 20 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            ret_val, frame = cap.read()
            if not ret_val:
                break

            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                # predictor.test_size is the value the predictor copied from
                # the experiment, so no module-level `exp` is needed here.
                online_targets = tracker.update(
                    outputs[0], [img_info['height'], img_info['width']], predictor.test_size
                )
                online_tlwhs = []
                online_ids = []
                for t in online_targets:
                    tlwh = t.tlwh
                    # Drop boxes wider than 1.6x their height and tiny boxes.
                    vertical = tlwh[2] / tlwh[3] > 1.6
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(t.track_id)
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                    fps=1. / max(1e-5, timer.average_time)
                )
            else:
                # No detections in this frame: skip the tracker, which cannot
                # handle a None input, and show the unannotated frame.
                timer.toc()
                online_im = img_info['raw_img']

            cv2.imshow("demo", online_im)
            if vid_writer is not None:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
            frame_id += 1
    finally:
        # Release the capture device and finalise the output file even when
        # inference raises part-way through the stream.
        cap.release()
        if vid_writer is not None:
            vid_writer.release()
def main(exp, args):
    """Build the model and predictor, then run the requested demo.

    Args:
        exp: experiment object; provides ``exp_name``, ``output_dir``,
            ``get_model()`` and the test thresholds / size, which are
            overridden in place by ``--conf``, ``--nms`` and ``--tsize``.
        args: parsed command-line namespace from ``make_parser``;
            ``experiment_name`` and ``device`` may be rewritten here.
    """
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    # Always bind vis_folder: it is passed to the demo functions even when
    # nothing is saved. Only create the directory when saving was requested.
    vis_folder = os.path.join(file_name, "track_vis")
    if args.save_result:
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"
    if args.device == "gpu":
        # Select the CUDA device only for GPU runs so that `--device cpu`
        # does not require CUDA to be available.
        torch.cuda.set_device(0)

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16:
        model = model.half()  # to FP16

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        # NOTE(review): machine-specific root directory; if `file_name` is
        # already absolute, os.path.join discards this prefix.
        trt_file = os.path.join("/data/humaocheng/monitor/ByteTrack", trt_file)
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)
    else:
        # Previously an unknown demo type did nothing without any feedback.
        logger.error("unknown demo type: {}".format(args.demo))
if __name__ == "__main__":
    # `args` must stay a module-level name: image_demo reads it as a global.
    cli_parser = make_parser()
    args = cli_parser.parse_args()
    # The experiment is resolved from the exp file and/or the model name.
    exp = get_exp(args.exp_file, args.name)
    main(exp, args)
|