io.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import os
  2. from typing import Dict
  3. import numpy as np
  4. def write_results(filename, results_dict: Dict, data_type: str):
  5. if not filename:
  6. return
  7. path = os.path.dirname(filename)
  8. if not os.path.exists(path):
  9. os.makedirs(path)
  10. if data_type in ('mot', 'mcmot', 'lab'):
  11. save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
  12. elif data_type == 'kitti':
  13. save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
  14. else:
  15. raise ValueError(data_type)
  16. with open(filename, 'w') as f:
  17. for frame_id, frame_data in results_dict.items():
  18. if data_type == 'kitti':
  19. frame_id -= 1
  20. for tlwh, track_id in frame_data:
  21. if track_id < 0:
  22. continue
  23. x1, y1, w, h = tlwh
  24. x2, y2 = x1 + w, y1 + h
  25. line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
  26. f.write(line)
  27. def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
  28. if data_type in ('mot', 'lab'):
  29. read_fun = read_mot_results
  30. else:
  31. raise ValueError('Unknown data type: {}'.format(data_type))
  32. return read_fun(filename, is_gt, is_ignore)
  33. """
  34. labels={'ped', ... % 1
  35. 'person_on_vhcl', ... % 2
  36. 'car', ... % 3
  37. 'bicycle', ... % 4
  38. 'mbike', ... % 5
  39. 'non_mot_vhcl', ... % 6
  40. 'static_person', ... % 7
  41. 'distractor', ... % 8
  42. 'occluder', ... % 9
  43. 'occluder_on_grnd', ... %10
  44. 'occluder_full', ... % 11
  45. 'reflection', ... % 12
  46. 'crowd' ... % 13
  47. };
  48. """
  49. def read_mot_results(filename, is_gt, is_ignore):
  50. valid_labels = {1}
  51. ignore_labels = {2, 7, 8, 12}
  52. results_dict = dict()
  53. if os.path.isfile(filename):
  54. with open(filename, 'r') as f:
  55. for line in f.readlines():
  56. linelist = line.split(',')
  57. if len(linelist) < 7:
  58. continue
  59. fid = int(linelist[0])
  60. if fid < 1:
  61. continue
  62. results_dict.setdefault(fid, list())
  63. box_size = float(linelist[4]) * float(linelist[5])
  64. if is_gt:
  65. if 'MOT16-' in filename or 'MOT17-' in filename:
  66. label = int(float(linelist[7]))
  67. mark = int(float(linelist[6]))
  68. if mark == 0 or label not in valid_labels:
  69. continue
  70. score = 1
  71. elif is_ignore:
  72. if 'MOT16-' in filename or 'MOT17-' in filename:
  73. label = int(float(linelist[7]))
  74. vis_ratio = float(linelist[8])
  75. if label not in ignore_labels and vis_ratio >= 0:
  76. continue
  77. else:
  78. continue
  79. score = 1
  80. else:
  81. score = float(linelist[6])
  82. #if box_size > 7000:
  83. #if box_size <= 7000 or box_size >= 15000:
  84. #if box_size < 15000:
  85. #continue
  86. tlwh = tuple(map(float, linelist[2:6]))
  87. target_id = int(linelist[1])
  88. results_dict[fid].append((tlwh, target_id, score))
  89. return results_dict
  90. def unzip_objs(objs):
  91. if len(objs) > 0:
  92. tlwhs, ids, scores = zip(*objs)
  93. else:
  94. tlwhs, ids, scores = [], [], []
  95. tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
  96. return tlwhs, ids, scores