evaluation.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import os
  2. import numpy as np
  3. import copy
  4. import motmetrics as mm
# Make the `lap` package the default solver for motmetrics' linear
# assignment (e.g. the mm.lap.linear_sum_assignment call in
# Evaluator.eval_frame). NOTE(review): presumably requires the `lap`
# package to be installed — confirm in the environment setup.
mm.lap.default_solver = 'lap'
  6. class Evaluator(object):
  7. def __init__(self, data_root, seq_name, data_type):
  8. self.data_root = data_root
  9. self.seq_name = seq_name
  10. self.data_type = data_type
  11. self.load_annotations()
  12. self.reset_accumulator()
  13. def load_annotations(self):
  14. assert self.data_type == 'mot'
  15. gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
  16. self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
  17. self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
  18. def reset_accumulator(self):
  19. self.acc = mm.MOTAccumulator(auto_id=True)
  20. def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
  21. # results
  22. trk_tlwhs = np.copy(trk_tlwhs)
  23. trk_ids = np.copy(trk_ids)
  24. # gts
  25. gt_objs = self.gt_frame_dict.get(frame_id, [])
  26. gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
  27. # ignore boxes
  28. ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
  29. ignore_tlwhs = unzip_objs(ignore_objs)[0]
  30. # remove ignored results
  31. keep = np.ones(len(trk_tlwhs), dtype=bool)
  32. iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
  33. if len(iou_distance) > 0:
  34. match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  35. match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  36. match_ious = iou_distance[match_is, match_js]
  37. match_js = np.asarray(match_js, dtype=int)
  38. match_js = match_js[np.logical_not(np.isnan(match_ious))]
  39. keep[match_js] = False
  40. trk_tlwhs = trk_tlwhs[keep]
  41. trk_ids = trk_ids[keep]
  42. #match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  43. #match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  44. #match_ious = iou_distance[match_is, match_js]
  45. #match_js = np.asarray(match_js, dtype=int)
  46. #match_js = match_js[np.logical_not(np.isnan(match_ious))]
  47. #keep[match_js] = False
  48. #trk_tlwhs = trk_tlwhs[keep]
  49. #trk_ids = trk_ids[keep]
  50. # get distance matrix
  51. iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
  52. # acc
  53. self.acc.update(gt_ids, trk_ids, iou_distance)
  54. if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
  55. events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
  56. else:
  57. events = None
  58. return events
  59. def eval_file(self, filename):
  60. self.reset_accumulator()
  61. result_frame_dict = read_results(filename, self.data_type, is_gt=False)
  62. #frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
  63. frames = sorted(list(set(result_frame_dict.keys())))
  64. for frame_id in frames:
  65. trk_objs = result_frame_dict.get(frame_id, [])
  66. trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
  67. self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
  68. return self.acc
  69. @staticmethod
  70. def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
  71. names = copy.deepcopy(names)
  72. if metrics is None:
  73. metrics = mm.metrics.motchallenge_metrics
  74. metrics = copy.deepcopy(metrics)
  75. mh = mm.metrics.create()
  76. summary = mh.compute_many(
  77. accs,
  78. metrics=metrics,
  79. names=names,
  80. generate_overall=True
  81. )
  82. return summary
  83. @staticmethod
  84. def save_summary(summary, filename):
  85. import pandas as pd
  86. writer = pd.ExcelWriter(filename)
  87. summary.to_excel(writer)
  88. writer.save()
  89. def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
  90. if data_type in ('mot', 'lab'):
  91. read_fun = read_mot_results
  92. else:
  93. raise ValueError('Unknown data type: {}'.format(data_type))
  94. return read_fun(filename, is_gt, is_ignore)
"""
MOTChallenge (MOT16/MOT17) ground-truth class labels, stored in the 8th
column of gt.txt (ids are 1-based):
     1: 'ped'
     2: 'person_on_vhcl'
     3: 'car'
     4: 'bicycle'
     5: 'mbike'
     6: 'non_mot_vhcl'
     7: 'static_person'
     8: 'distractor'
     9: 'occluder'
    10: 'occluder_on_grnd'
    11: 'occluder_full'
    12: 'reflection'
    13: 'crowd'
read_mot_results scores label 1 and treats labels 2, 7, 8 and 12 as
ignore regions.
"""
  111. def read_mot_results(filename, is_gt, is_ignore):
  112. valid_labels = {1}
  113. ignore_labels = {2, 7, 8, 12}
  114. results_dict = dict()
  115. if os.path.isfile(filename):
  116. with open(filename, 'r') as f:
  117. for line in f.readlines():
  118. linelist = line.split(',')
  119. if len(linelist) < 7:
  120. continue
  121. fid = int(linelist[0])
  122. if fid < 1:
  123. continue
  124. results_dict.setdefault(fid, list())
  125. box_size = float(linelist[4]) * float(linelist[5])
  126. if is_gt:
  127. if 'MOT16-' in filename or 'MOT17-' in filename:
  128. label = int(float(linelist[7]))
  129. mark = int(float(linelist[6]))
  130. if mark == 0 or label not in valid_labels:
  131. continue
  132. score = 1
  133. elif is_ignore:
  134. if 'MOT16-' in filename or 'MOT17-' in filename:
  135. label = int(float(linelist[7]))
  136. vis_ratio = float(linelist[8])
  137. if label not in ignore_labels and vis_ratio >= 0:
  138. continue
  139. else:
  140. continue
  141. score = 1
  142. else:
  143. score = float(linelist[6])
  144. #if box_size > 7000:
  145. #if box_size <= 7000 or box_size >= 15000:
  146. #if box_size < 15000:
  147. #continue
  148. tlwh = tuple(map(float, linelist[2:6]))
  149. target_id = int(linelist[1])
  150. results_dict[fid].append((tlwh, target_id, score))
  151. return results_dict
  152. def unzip_objs(objs):
  153. if len(objs) > 0:
  154. tlwhs, ids, scores = zip(*objs)
  155. else:
  156. tlwhs, ids, scores = [], [], []
  157. tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
  158. return tlwhs, ids, scores