evaluation.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # ------------------------------------------------------------------------
  2. # Copyright (c) 2021 megvii-model. All Rights Reserved.
  3. # ------------------------------------------------------------------------
  4. # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
  5. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  6. # ------------------------------------------------------------------------
  7. # Modified from DETR (https://github.com/facebookresearch/detr)
  8. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  9. # ------------------------------------------------------------------------
  10. import os
  11. import numpy as np
  12. import copy
  13. import motmetrics as mm
  14. mm.lap.default_solver = 'lap'
  15. import os
  16. from typing import Dict
  17. import numpy as np
  18. import logging
  19. def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
  20. if data_type in ('mot', 'lab'):
  21. read_fun = read_mot_results
  22. else:
  23. raise ValueError('Unknown data type: {}'.format(data_type))
  24. return read_fun(filename, is_gt, is_ignore)
  25. # def read_mot_results(filename, is_gt, is_ignore):
  26. # results_dict = dict()
  27. # if os.path.isfile(filename):
  28. # with open(filename, 'r') as f:
  29. # for line in f.readlines():
  30. # linelist = line.split(',')
  31. # if len(linelist) < 7:
  32. # continue
  33. # fid = int(linelist[0])
  34. # if fid < 1:
  35. # continue
  36. # results_dict.setdefault(fid, list())
  37. # if is_gt:
  38. # mark = int(float(linelist[6]))
  39. # if mark == 0 :
  40. # continue
  41. # score = 1
  42. # elif is_ignore:
  43. # score = 1
  44. # else:
  45. # score = float(linelist[6])
  46. # tlwh = tuple(map(float, linelist[2:6]))
  47. # target_id = int(float(linelist[1]))
  48. # results_dict[fid].append((tlwh, target_id, score))
  49. # return results_dict
  50. def read_mot_results(filename, is_gt, is_ignore):
  51. valid_labels = {1}
  52. ignore_labels = {0, 2, 7, 8, 12}
  53. results_dict = dict()
  54. if os.path.isfile(filename):
  55. with open(filename, 'r') as f:
  56. for line in f.readlines():
  57. linelist = line.split(',')
  58. if len(linelist) < 7:
  59. continue
  60. fid = int(linelist[0])
  61. if fid < 1:
  62. continue
  63. results_dict.setdefault(fid, list())
  64. if is_gt:
  65. if 'MOT16-' in filename or 'MOT17-' in filename:
  66. label = int(float(linelist[7]))
  67. mark = int(float(linelist[6]))
  68. if mark == 0 or label not in valid_labels:
  69. continue
  70. score = 1
  71. elif is_ignore:
  72. if 'MOT16-' in filename or 'MOT17-' in filename:
  73. label = int(float(linelist[7]))
  74. vis_ratio = float(linelist[8])
  75. if label not in ignore_labels and vis_ratio >= 0:
  76. continue
  77. elif 'MOT15' in filename:
  78. label = int(float(linelist[6]))
  79. if label not in ignore_labels:
  80. continue
  81. else:
  82. continue
  83. score = 1
  84. else:
  85. score = float(linelist[6])
  86. tlwh = tuple(map(float, linelist[2:6]))
  87. target_id = int(linelist[1])
  88. results_dict[fid].append((tlwh, target_id, score))
  89. return results_dict
  90. def unzip_objs(objs):
  91. if len(objs) > 0:
  92. tlwhs, ids, scores = zip(*objs)
  93. else:
  94. tlwhs, ids, scores = [], [], []
  95. tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
  96. return tlwhs, ids, scores
  97. class Evaluator(object):
  98. def __init__(self, data_root, seq_name, data_type='mot'):
  99. self.data_root = data_root
  100. self.seq_name = seq_name
  101. self.data_type = data_type
  102. self.load_annotations()
  103. self.reset_accumulator()
  104. def load_annotations(self):
  105. assert self.data_type == 'mot'
  106. gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
  107. self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
  108. self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
  109. def reset_accumulator(self):
  110. self.acc = mm.MOTAccumulator(auto_id=True)
  111. def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
  112. # results
  113. trk_tlwhs = np.copy(trk_tlwhs)
  114. trk_ids = np.copy(trk_ids)
  115. # gts
  116. gt_objs = self.gt_frame_dict.get(frame_id, [])
  117. gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
  118. # ignore boxes
  119. ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
  120. ignore_tlwhs = unzip_objs(ignore_objs)[0]
  121. # remove ignored results
  122. keep = np.ones(len(trk_tlwhs), dtype=bool)
  123. iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
  124. if len(iou_distance) > 0:
  125. match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  126. match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  127. match_ious = iou_distance[match_is, match_js]
  128. match_js = np.asarray(match_js, dtype=int)
  129. match_js = match_js[np.logical_not(np.isnan(match_ious))]
  130. keep[match_js] = False
  131. trk_tlwhs = trk_tlwhs[keep]
  132. trk_ids = trk_ids[keep]
  133. # get distance matrix
  134. iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
  135. # acc
  136. self.acc.update(gt_ids, trk_ids, iou_distance)
  137. if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
  138. events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
  139. else:
  140. events = None
  141. return events
  142. def eval_file(self, filename):
  143. self.reset_accumulator()
  144. result_frame_dict = read_results(filename, self.data_type, is_gt=False)
  145. #frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
  146. frames = sorted(list(set(result_frame_dict.keys())))
  147. for frame_id in frames:
  148. trk_objs = result_frame_dict.get(frame_id, [])
  149. trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
  150. self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
  151. return self.acc
  152. @staticmethod
  153. def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
  154. names = copy.deepcopy(names)
  155. if metrics is None:
  156. metrics = mm.metrics.motchallenge_metrics
  157. metrics = copy.deepcopy(metrics)
  158. mh = mm.metrics.create()
  159. summary = mh.compute_many(
  160. accs,
  161. metrics=metrics,
  162. names=names,
  163. generate_overall=True
  164. )
  165. return summary
  166. @staticmethod
  167. def save_summary(summary, filename):
  168. import pandas as pd
  169. writer = pd.ExcelWriter(filename)
  170. summary.to_excel(writer)
  171. writer.save()