# track_sort.py (9.9 KB)
  1. from loguru import logger
  2. import torch
  3. import torch.backends.cudnn as cudnn
  4. from torch.nn.parallel import DistributedDataParallel as DDP
  5. from yolox.core import launch
  6. from yolox.exp import get_exp
  7. from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
  8. from yolox.evaluators import MOTEvaluator
  9. import argparse
  10. import os
  11. import random
  12. import warnings
  13. import glob
  14. import motmetrics as mm
  15. from collections import OrderedDict
  16. from pathlib import Path
  17. def make_parser():
  18. parser = argparse.ArgumentParser("YOLOX Eval")
  19. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  20. parser.add_argument("-n", "--name", type=str, default=None, help="model name")
  21. # distributed
  22. parser.add_argument(
  23. "--dist-backend", default="nccl", type=str, help="distributed backend"
  24. )
  25. parser.add_argument(
  26. "--dist-url",
  27. default=None,
  28. type=str,
  29. help="url used to set up distributed training",
  30. )
  31. parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
  32. parser.add_argument(
  33. "-d", "--devices", default=None, type=int, help="device for training"
  34. )
  35. parser.add_argument(
  36. "--local_rank", default=0, type=int, help="local rank for dist training"
  37. )
  38. parser.add_argument(
  39. "--num_machines", default=1, type=int, help="num of node for training"
  40. )
  41. parser.add_argument(
  42. "--machine_rank", default=0, type=int, help="node rank for multi-node training"
  43. )
  44. parser.add_argument(
  45. "-f",
  46. "--exp_file",
  47. default=None,
  48. type=str,
  49. help="pls input your expriment description file",
  50. )
  51. parser.add_argument(
  52. "--fp16",
  53. dest="fp16",
  54. default=False,
  55. action="store_true",
  56. help="Adopting mix precision evaluating.",
  57. )
  58. parser.add_argument(
  59. "--fuse",
  60. dest="fuse",
  61. default=False,
  62. action="store_true",
  63. help="Fuse conv and bn for testing.",
  64. )
  65. parser.add_argument(
  66. "--trt",
  67. dest="trt",
  68. default=False,
  69. action="store_true",
  70. help="Using TensorRT model for testing.",
  71. )
  72. parser.add_argument(
  73. "--test",
  74. dest="test",
  75. default=False,
  76. action="store_true",
  77. help="Evaluating on test-dev set.",
  78. )
  79. parser.add_argument(
  80. "--speed",
  81. dest="speed",
  82. default=False,
  83. action="store_true",
  84. help="speed test only.",
  85. )
  86. parser.add_argument(
  87. "opts",
  88. help="Modify config_files options using the command-line",
  89. default=None,
  90. nargs=argparse.REMAINDER,
  91. )
  92. # det args
  93. parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
  94. parser.add_argument("--conf", default=0.1, type=float, help="test conf")
  95. parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
  96. parser.add_argument("--tsize", default=None, type=int, help="test img size")
  97. parser.add_argument("--seed", default=None, type=int, help="eval seed")
  98. # tracking args
  99. parser.add_argument("--track_thresh", type=float, default=0.4, help="tracking confidence threshold")
  100. parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
  101. parser.add_argument("--match_thresh", type=float, default=0.9, help="matching threshold for tracking")
  102. parser.add_argument('--min-box-area', type=float, default=100, help='filter out tiny boxes')
  103. return parser
  104. def compare_dataframes(gts, ts):
  105. accs = []
  106. names = []
  107. for k, tsacc in ts.items():
  108. if k in gts:
  109. logger.info('Comparing {}...'.format(k))
  110. accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5))
  111. names.append(k)
  112. else:
  113. logger.warning('No ground truth for {}, skipping.'.format(k))
  114. return accs, names
  115. @logger.catch
  116. def main(exp, args, num_gpu):
  117. if args.seed is not None:
  118. random.seed(args.seed)
  119. torch.manual_seed(args.seed)
  120. cudnn.deterministic = True
  121. warnings.warn(
  122. "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
  123. )
  124. is_distributed = num_gpu > 1
  125. # set environment variables for distributed training
  126. cudnn.benchmark = True
  127. rank = args.local_rank
  128. # rank = get_local_rank()
  129. file_name = os.path.join(exp.output_dir, args.experiment_name)
  130. if rank == 0:
  131. os.makedirs(file_name, exist_ok=True)
  132. results_folder = os.path.join(file_name, "track_results_sort")
  133. os.makedirs(results_folder, exist_ok=True)
  134. setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
  135. logger.info("Args: {}".format(args))
  136. if args.conf is not None:
  137. exp.test_conf = args.conf
  138. if args.nms is not None:
  139. exp.nmsthre = args.nms
  140. if args.tsize is not None:
  141. exp.test_size = (args.tsize, args.tsize)
  142. model = exp.get_model()
  143. logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
  144. #logger.info("Model Structure:\n{}".format(str(model)))
  145. #evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test)
  146. val_loader = exp.get_eval_loader(args.batch_size, is_distributed, args.test)
  147. evaluator = MOTEvaluator(
  148. args=args,
  149. dataloader=val_loader,
  150. img_size=exp.test_size,
  151. confthre=exp.test_conf,
  152. nmsthre=exp.nmsthre,
  153. num_classes=exp.num_classes,
  154. )
  155. torch.cuda.set_device(rank)
  156. model.cuda(rank)
  157. model.eval()
  158. if not args.speed and not args.trt:
  159. if args.ckpt is None:
  160. ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
  161. else:
  162. ckpt_file = args.ckpt
  163. logger.info("loading checkpoint")
  164. loc = "cuda:{}".format(rank)
  165. ckpt = torch.load(ckpt_file, map_location=loc)
  166. # load the model state dict
  167. model.load_state_dict(ckpt["model"])
  168. logger.info("loaded checkpoint done.")
  169. if is_distributed:
  170. model = DDP(model, device_ids=[rank])
  171. if args.fuse:
  172. logger.info("\tFusing model...")
  173. model = fuse_model(model)
  174. if args.trt:
  175. assert (
  176. not args.fuse and not is_distributed and args.batch_size == 1
  177. ), "TensorRT model is not support model fusing and distributed inferencing!"
  178. trt_file = os.path.join(file_name, "model_trt.pth")
  179. assert os.path.exists(
  180. trt_file
  181. ), "TensorRT model is not found!\n Run tools/trt.py first!"
  182. model.head.decode_in_inference = False
  183. decoder = model.head.decode_outputs
  184. else:
  185. trt_file = None
  186. decoder = None
  187. # start evaluate
  188. *_, summary = evaluator.evaluate_sort(
  189. model, is_distributed, args.fp16, trt_file, decoder, exp.test_size, results_folder
  190. )
  191. logger.info("\n" + summary)
  192. # evaluate MOTA
  193. mm.lap.default_solver = 'lap'
  194. gt_type = '_val_half'
  195. #gt_type = ''
  196. print('gt_type', gt_type)
  197. gtfiles = glob.glob(
  198. os.path.join('datasets/mot/train', '*/gt/gt{}.txt'.format(gt_type)))
  199. print('gt_files', gtfiles)
  200. tsfiles = [f for f in glob.glob(os.path.join(results_folder, '*.txt')) if not os.path.basename(f).startswith('eval')]
  201. logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
  202. logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
  203. logger.info('Default LAP solver \'{}\''.format(mm.lap.default_solver))
  204. logger.info('Loading files.')
  205. gt = OrderedDict([(Path(f).parts[-3], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=1)) for f in gtfiles])
  206. ts = OrderedDict([(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=-1)) for f in tsfiles])
  207. mh = mm.metrics.create()
  208. accs, names = compare_dataframes(gt, ts)
  209. logger.info('Running metrics')
  210. metrics = ['recall', 'precision', 'num_unique_objects', 'mostly_tracked',
  211. 'partially_tracked', 'mostly_lost', 'num_false_positives', 'num_misses',
  212. 'num_switches', 'num_fragmentations', 'mota', 'motp', 'num_objects']
  213. summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
  214. # summary = mh.compute_many(accs, names=names, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
  215. # print(mm.io.render_summary(
  216. # summary, formatters=mh.formatters,
  217. # namemap=mm.io.motchallenge_metric_names))
  218. div_dict = {
  219. 'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
  220. 'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost']}
  221. for divisor in div_dict:
  222. for divided in div_dict[divisor]:
  223. summary[divided] = (summary[divided] / summary[divisor])
  224. fmt = mh.formatters
  225. change_fmt_list = ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations', 'mostly_tracked',
  226. 'partially_tracked', 'mostly_lost']
  227. for k in change_fmt_list:
  228. fmt[k] = fmt['mota']
  229. print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))
  230. metrics = mm.metrics.motchallenge_metrics + ['num_objects']
  231. summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
  232. print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))
  233. logger.info('Completed')
  234. if __name__ == "__main__":
  235. args = make_parser().parse_args()
  236. exp = get_exp(args.exp_file, args.name)
  237. exp.merge(args.opts)
  238. if not args.experiment_name:
  239. args.experiment_name = exp.exp_name
  240. num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
  241. assert num_gpu <= torch.cuda.device_count()
  242. launch(
  243. main,
  244. num_gpu,
  245. args.num_machines,
  246. args.machine_rank,
  247. backend=args.dist_backend,
  248. dist_url=args.dist_url,
  249. args=(exp, args, num_gpu),
  250. )