123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- from loguru import logger
- import torch
- import torch.backends.cudnn as cudnn
- from torch.nn.parallel import DistributedDataParallel as DDP
- from yolox.core import launch
- from yolox.exp import get_exp
- from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
- from yolox.evaluators import MOTEvaluator
- import argparse
- import os
- import random
- import warnings
- import glob
- import motmetrics as mm
- from collections import OrderedDict
- from pathlib import Path
def compare_dataframes(gts, ts, dist='iou', distth=0.5):
    """Pair every tracker result with its ground truth and compare them.

    Entries of ``ts`` whose key is missing from ``gts`` are skipped with a
    warning instead of raising.

    Args:
        gts: Mapping from sequence name to ground-truth data (the caller in
            this file builds it from ``mm.io.loadtxt`` results).
        ts: Mapping from sequence name to tracker output, keyed the same way.
        dist: Distance measure forwarded to
            ``mm.utils.compare_to_groundtruth``. Defaults to ``'iou'``, the
            value that used to be hard-coded here.
        distth: Distance threshold forwarded to
            ``mm.utils.compare_to_groundtruth``. Defaults to ``0.5``, the
            value that used to be hard-coded here.

    Returns:
        A ``(accs, names)`` tuple of parallel lists: the objects returned by
        ``mm.utils.compare_to_groundtruth`` and the sequence names they
        belong to, in the iteration order of ``ts``.
    """
    accs = []
    names = []
    for name, tracked in ts.items():
        if name not in gts:
            # Nothing to score against; report it and move on.
            logger.warning('No ground truth for {}, skipping.'.format(name))
            continue
        logger.info('Comparing {}...'.format(name))
        accs.append(
            mm.utils.compare_to_groundtruth(gts[name], tracked, dist, distth=distth)
        )
        names.append(name)
    return accs, names
# ---------------------------------------------------------------------------
# MOTA evaluation: score the tracker outputs found in `results_folder`
# against the ground-truth files of the sequences under datasets/mot/train,
# then print two summary tables.
# ---------------------------------------------------------------------------
results_folder = 'YOLOX_outputs/yolox_x_ablation/track_results'
mm.lap.default_solver = 'lap'

# Suffix selecting which ground-truth file is read per sequence:
# '_val_half' picks gt/gt_val_half.txt, while '' would pick gt/gt.txt.
gt_type = '_val_half'
print('gt_type', gt_type)

gtfiles = glob.glob(os.path.join('datasets/mot/train', '*/gt/gt{}.txt'.format(gt_type)))
print('gt_files', gtfiles)

# Every .txt file in the results folder counts as tracker output, except
# files whose name starts with 'eval'.
tsfiles = [
    path
    for path in glob.glob(os.path.join(results_folder, '*.txt'))
    if not os.path.basename(path).startswith('eval')
]

logger.info('Found {} groundtruths and {} test files.'.format(len(gtfiles), len(tsfiles)))
logger.info('Available LAP solvers {}'.format(mm.lap.available_solvers))
logger.info("Default LAP solver '{}'".format(mm.lap.default_solver))
logger.info('Loading files.')

# Ground truth is keyed by the sequence directory (<seq>/gt/<file>), tracker
# output by the file name without its extension, so matching keys pair up.
# NOTE(review): min_confidence presumably filters rows by their confidence
# column (1 for ground truth, -1.0 for tracker output) -- confirm against the
# motmetrics documentation.
gt = OrderedDict(
    (
        Path(gt_path).parent.parent.name,
        mm.io.loadtxt(gt_path, fmt='mot15-2D', min_confidence=1),
    )
    for gt_path in gtfiles
)
ts = OrderedDict(
    (
        os.path.splitext(Path(ts_path).name)[0],
        mm.io.loadtxt(ts_path, fmt='mot15-2D', min_confidence=-1.0),
    )
    for ts_path in tsfiles
)

mh = mm.metrics.create()
accs, names = compare_dataframes(gt, ts)

# --- Table 1: selected metrics, with count columns turned into ratios ------
logger.info('Running metrics')
metrics = [
    'recall',
    'precision',
    'num_unique_objects',
    'mostly_tracked',
    'partially_tracked',
    'mostly_lost',
    'num_false_positives',
    'num_misses',
    'num_switches',
    'num_fragmentations',
    'mota',
    'motp',
    'num_objects',
]
summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)

# Each key is a denominator column; its value lists the columns divided by it.
div_dict = {
    'num_objects': [
        'num_false_positives',
        'num_misses',
        'num_switches',
        'num_fragmentations',
    ],
    'num_unique_objects': [
        'mostly_tracked',
        'partially_tracked',
        'mostly_lost',
    ],
}
for divisor, divided_columns in div_dict.items():
    denominator = summary[divisor]
    for column in divided_columns:
        summary[column] = summary[column] / denominator

# The divided columns are now ratios, so render them with the same formatter
# that 'mota' uses.
fmt = mh.formatters
change_fmt_list = [
    'num_false_positives',
    'num_misses',
    'num_switches',
    'num_fragmentations',
    'mostly_tracked',
    'partially_tracked',
    'mostly_lost',
]
fmt.update(dict.fromkeys(change_fmt_list, fmt['mota']))
print(mm.io.render_summary(summary, formatters=fmt, namemap=mm.io.motchallenge_metric_names))

# --- Table 2: the MOTChallenge metric set plus num_objects, raw values -----
metrics = mm.metrics.motchallenge_metrics + ['num_objects']
summary = mh.compute_many(accs, names=names, metrics=metrics, generate_overall=True)
print(
    mm.io.render_summary(
        summary,
        formatters=mh.formatters,
        namemap=mm.io.motchallenge_metric_names,
    )
)
logger.info('Completed')
|