track_motdt.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. from loguru import logger
  2. import torch
  3. import torch.backends.cudnn as cudnn
  4. from torch.nn.parallel import DistributedDataParallel as DDP
  5. from yolox.core import launch
  6. from yolox.exp import get_exp
  7. from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
  8. from yolox.evaluators import MOTEvaluator
  9. import argparse
  10. import os
  11. import random
  12. import warnings
  13. import glob
  14. import motmetrics as mm
  15. from collections import OrderedDict
  16. from pathlib import Path
  17. def make_parser():
  18. parser = argparse.ArgumentParser("YOLOX Eval")
  19. parser.add_argument("-expn", "--experiment-name", type=str, default=None)
  20. parser.add_argument("-n", "--name", type=str, default=None, help="model name")
  21. # distributed
  22. parser.add_argument(
  23. "--dist-backend", default="nccl", type=str, help="distributed backend"
  24. )
  25. parser.add_argument(
  26. "--dist-url",
  27. default=None,
  28. type=str,
  29. help="url used to set up distributed training",
  30. )
  31. parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
  32. parser.add_argument(
  33. "-d", "--devices", default=None, type=int, help="device for training"
  34. )
  35. parser.add_argument(
  36. "--local_rank", default=0, type=int, help="local rank for dist training"
  37. )
  38. parser.add_argument(
  39. "--num_machines", default=1, type=int, help="num of node for training"
  40. )
  41. parser.add_argument(
  42. "--machine_rank", default=0, type=int, help="node rank for multi-node training"
  43. )
  44. parser.add_argument(
  45. "-f",
  46. "--exp_file",
  47. default=None,
  48. type=str,
  49. help="pls input your expriment description file",
  50. )
  51. parser.add_argument(
  52. "--fp16",
  53. dest="fp16",
  54. default=False,
  55. action="store_true",
  56. help="Adopting mix precision evaluating.",
  57. )
  58. parser.add_argument(
  59. "--fuse",
  60. dest="fuse",
  61. default=False,
  62. action="store_true",
  63. help="Fuse conv and bn for testing.",
  64. )
  65. parser.add_argument(
  66. "--trt",
  67. dest="trt",
  68. default=False,
  69. action="store_true",
  70. help="Using TensorRT model for testing.",
  71. )
  72. parser.add_argument(
  73. "--test",
  74. dest="test",
  75. default=False,
  76. action="store_true",
  77. help="Evaluating on test-dev set.",
  78. )
  79. parser.add_argument(
  80. "--speed",
  81. dest="speed",
  82. default=False,
  83. action="store_true",
  84. help="speed test only.",
  85. )
  86. parser.add_argument(
  87. "opts",
  88. help="Modify config_files options using the command-line",
  89. default=None,
  90. nargs=argparse.REMAINDER,
  91. )
  92. # det args
  93. parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
  94. parser.add_argument("--conf", default=0.1, type=float, help="test conf")
  95. parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
  96. parser.add_argument("--tsize", default=None, type=int, help="test img size")
  97. parser.add_argument("--seed", default=None, type=int, help="eval seed")
  98. # tracking args
  99. parser.add_argument("--track_thresh", type=float, default=0.6, help="tracking confidence threshold")
  100. parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
  101. parser.add_argument("--match_thresh", type=float, default=0.9, help="matching threshold for tracking")
  102. parser.add_argument('--min-box-area', type=float, default=100, help='filter out tiny boxes')
  103. # deepsort args
  104. parser.add_argument("--model_folder", type=str, default='pretrained/googlenet_part8_all_xavier_ckpt_56.h5', help="reid model folder")
  105. return parser
  106. def compare_dataframes(gts, ts):
  107. accs = []
  108. names = []
  109. for k, tsacc in ts.items():
  110. if k in gts:
  111. logger.info('Comparing {}...'.format(k))
  112. accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5))
  113. names.append(k)
  114. else:
  115. logger.warning('No ground truth for {}, skipping.'.format(k))
  116. return accs, names
  117. @logger.catch
  118. def main(exp, args, num_gpu):
  119. if args.seed is not None:
  120. random.seed(args.seed)
  121. torch.manual_seed(args.seed)
  122. cudnn.deterministic = True
  123. warnings.warn(
  124. "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
  125. )
  126. is_distributed = num_gpu > 1
  127. # set environment variables for distributed training
  128. cudnn.benchmark = True
  129. rank = args.local_rank
  130. # rank = get_local_rank()
  131. file_name = os.path.join(exp.output_dir, args.experiment_name)
  132. if rank == 0:
  133. os.makedirs(file_name, exist_ok=True)
  134. results_folder = os.path.join(file_name, "track_results_motdt")
  135. os.makedirs(results_folder, exist_ok=True)
  136. model_folder = args.model_folder
  137. setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
  138. logger.info("Args: {}".format(args))
  139. if args.conf is not None:
  140. exp.test_conf = args.conf
  141. if args.nms is not None:
  142. exp.nmsthre = args.nms
  143. if args.tsize is not None:
  144. exp.test_size = (args.tsize, args.tsize)
  145. model = exp.get_model()
  146. logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
  147. #logger.info("Model Structure:\n{}".format(str(model)))
  148. #evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test)
  149. val_loader = exp.get_eval_loader(args.batch_size, is_distributed, args.test)
  150. evaluator = MOTEvaluator(
  151. args=args,
  152. dataloader=val_loader,
  153. img_size=exp.test_size,
  154. confthre=exp.test_conf,
  155. nmsthre=exp.nmsthre,
  156. num_classes=exp.num_classes,
  157. )
  158. torch.cuda.set_device(rank)
  159. model.cuda(rank)
  160. model.eval()
  161. if not args.speed and not args.trt:
  162. if args.ckpt is None:
  163. ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
  164. else:
  165. ckpt_file = args.ckpt
  166. logger.info("loading checkpoint")
  167. loc = "cuda:{}".format(rank)
  168. ckpt = torch.load(ckpt_file, map_location=loc)
  169. # load the model state dict
  170. model.load_state_dict(ckpt["model"])
  171. logger.info("loaded checkpoint done.")
  172. if is_distributed:
  173. model = DDP(model, device_ids=[rank])
  174. if args.fuse:
  175. logger.info("\tFusing model...")
  176. model = fuse_model(model)
  177. if args.trt:
  178. assert (
  179. not args.fuse and not is_distributed and args.batch_size == 1
  180. ), "TensorRT model is not support model fusing and distributed inferencing!"
  181. trt_file = os.path.join(file_name, "model_trt.pth")
  182. assert os.path.exists(
  183. trt_file
  184. ), "TensorRT model is not found!\n Run tools/trt.py first!"
  185. model.head.decode_in_inference = False
  186. decoder = model.head.decode_outputs
  187. else:
  188. trt_file = None
  189. decoder = None
  190. # start evaluate
  191. *_, summary = evaluator.evaluate_motdt(
  192. model, is_distributed, args.fp16, trt_file, decoder, exp.test_size, results_folder, model_folder
  193. )
  194. logger.info("\n" + summary)
  195. # evaluate MOTA
  196. mm.lap.default_solver = 'lap'
  197. gt_type = '_val_half'
  198. #gt_type = ''
  199. print('gt_type', gt_type)
  200. gtfiles = glob.glob(
  201. os.path.join('datasets/mot/train', '*/gt/gt{}.txt'.format(gt_type)))
  202. print('gt_files', gtfiles)
  203. tsfiles = [f for f in glob.glob(os.path.join(results_folder, '*.txt')) if not os.path.basename(f).startswith('eval')]
  204. logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
  205. logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
  206. logger.info('Default LAP solver \'{}\''.format(mm.lap.default_solver))
  207. logger.info('Loading files.')
  208. gt = OrderedDict([(Path(f).parts[-3], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=1)) for f in gtfiles])
  209. ts = OrderedDict([(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=-1)) for f in tsfiles])
  210. mh = mm.metrics.create()
  211. accs, names = compare_dataframes(gt, ts)
  212. logger.info('Running metrics')
  213. metrics = ['recall', 'precision', 'num_unique_objects', 'mostly_tracked',
  214. 'partially_tracked', 'mostly_lost', 'num_false_positives', 'num_misses',
  215. 'num_switches', 'num_fragmentations', 'mota', 'motp', 'num_objects']
  216. summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
  217. # summary = mh.compute_many(accs, names=names, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
  218. # print(mm.io.render_summary(
  219. # summary, formatters=mh.formatters,
  220. # namemap=mm.io.motchallenge_metric_names))
  221. div_dict = {
  222. 'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
  223. 'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost']}
  224. for divisor in div_dict:
  225. for divided in div_dict[divisor]:
  226. summary[divided] = (summary[divided] / summary[divisor])
  227. fmt = mh.formatters
  228. change_fmt_list = ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations', 'mostly_tracked',
  229. 'partially_tracked', 'mostly_lost']
  230. for k in change_fmt_list:
  231. fmt[k] = fmt['mota']
  232. print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))
  233. metrics = mm.metrics.motchallenge_metrics + ['num_objects']
  234. summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
  235. print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))
  236. logger.info('Completed')
  237. if __name__ == "__main__":
  238. args = make_parser().parse_args()
  239. exp = get_exp(args.exp_file, args.name)
  240. exp.merge(args.opts)
  241. if not args.experiment_name:
  242. args.experiment_name = exp.exp_name
  243. num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
  244. assert num_gpu <= torch.cuda.device_count()
  245. launch(
  246. main,
  247. num_gpu,
  248. args.num_machines,
  249. args.machine_rank,
  250. backend=args.dist_backend,
  251. dist_url=args.dist_url,
  252. args=(exp, args, num_gpu),
  253. )