mota.py

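# Standalone MOTA evaluation: loads MOTChallenge-style ground truth and
# per-sequence tracker result files, then computes MOT metrics
# (MOTA, MOTP, ID switches, ...) with the motmetrics package.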
from loguru import logger

import torch
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel as DDP

from yolox.core import launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
from yolox.evaluators import MOTEvaluator

import argparse
import os
import random
import warnings
import glob

import motmetrics as mm
from collections import OrderedDict
from pathlib import Path
def compare_dataframes(gts, ts):
    """Build a motmetrics accumulator for every tracker sequence that has
    matching ground truth (paired by sequence name, IoU distance, distth=0.5)."""
    accs = []
    names = []
    for k, tsacc in ts.items():
        if k in gts:
            logger.info('Comparing {}...'.format(k))
            accs.append(mm.utils.compare_to_groundtruth(gts[k], tsacc, 'iou', distth=0.5))
            names.append(k)
        else:
            logger.warning('No ground truth for {}, skipping.'.format(k))
    return accs, names

# evaluate MOTA
results_folder = 'YOLOX_outputs/yolox_x_ablation/track_results'
mm.lap.default_solver = 'lap'

gt_type = '_val_half'
# gt_type = ''
print('gt_type', gt_type)
gtfiles = glob.glob(
    os.path.join('datasets/mot/train', '*/gt/gt{}.txt'.format(gt_type)))
print('gt_files', gtfiles)
tsfiles = [f for f in glob.glob(os.path.join(results_folder, '*.txt'))
           if not os.path.basename(f).startswith('eval')]

logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
logger.info('Default LAP solver \'{}\''.format(mm.lap.default_solver))
logger.info('Loading files.')

gt = OrderedDict([(Path(f).parts[-3], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=1)) for f in gtfiles])
ts = OrderedDict([(os.path.splitext(Path(f).parts[-1])[0], mm.io.loadtxt(f, fmt='mot15-2D', min_confidence=-1.0)) for f in tsfiles])
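# Data layout note (an assumption based on the MOTChallenge text format that
# mm.io.loadtxt(..., fmt='mot15-2D') parses): each row is
#   <frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>
# min_confidence=1 drops GT rows whose confidence/"consider" flag is 0, while
# min_confidence=-1.0 keeps every tracker hypothesis regardless of score.
# Ground truth is keyed by the sequence folder name (Path(f).parts[-3]) and the
# tracker results by the result file's stem, so the two must match for
# compare_dataframes to pair them up.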

mh = mm.metrics.create()
accs, names = compare_dataframes(gt, ts)

logger.info('Running metrics')
metrics = ['recall', 'precision', 'num_unique_objects', 'mostly_tracked',
           'partially_tracked', 'mostly_lost', 'num_false_positives', 'num_misses',
           'num_switches', 'num_fragmentations', 'mota', 'motp', 'num_objects']
summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
# summary = mh.compute_many(accs, names=names, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
# print(mm.io.render_summary(
#     summary, formatters=mh.formatters,
#     namemap=mm.io.motchallenge_metric_names))
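
# The block below turns the raw counts into ratios so they print like the usual
# MOTChallenge percentages: FP / FN / ID-switch / fragmentation counts are divided
# by the total number of GT boxes, MT / PT / ML counts by the number of unique GT
# identities, and all of them then reuse the MOTA formatter.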
div_dict = {
    'num_objects': ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations'],
    'num_unique_objects': ['mostly_tracked', 'partially_tracked', 'mostly_lost']}
for divisor in div_dict:
    for divided in div_dict[divisor]:
        summary[divided] = (summary[divided] / summary[divisor])

fmt = mh.formatters
change_fmt_list = ['num_false_positives', 'num_misses', 'num_switches', 'num_fragmentations',
                   'mostly_tracked', 'partially_tracked', 'mostly_lost']
for k in change_fmt_list:
    fmt[k] = fmt['mota']
print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))

# Recompute and print the summary once more with the full MOTChallenge metric set.
metrics = mm.metrics.motchallenge_metrics + ['num_objects']
summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
print(mm.io.render_summary(summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names))
logger.info('Completed')
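
# Usage sketch (the directory layout below is implied by the hard-coded paths
# above; adjust results_folder / gt_type for your own runs):
#   datasets/mot/train/<sequence>/gt/gt_val_half.txt             per-sequence ground truth
#   YOLOX_outputs/yolox_x_ablation/track_results/<sequence>.txt  per-sequence tracker output
#   python mota.py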