save_track.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """
  2. Copyright (c) https://github.com/xingyizhou/CenterTrack
  3. Modified by Peize Sun, Rufeng Zhang
  4. """
  5. # coding: utf-8
  6. import os
  7. import json
  8. import logging
  9. from collections import defaultdict
  10. def save_track(results, out_root, video_to_images, video_names, data_split='val'):
  11. assert out_root is not None
  12. out_dir = os.path.join(out_root, data_split)
  13. if not os.path.exists(out_dir):
  14. os.mkdir(out_dir)
  15. # save json.
  16. # json_path = os.path.join(out_dir, "track_results.json")
  17. # with open(json_path, "w") as f:
  18. # f.write(json.dumps(results))
  19. # f.flush()
  20. # save it in standard format.
  21. track_dir = os.path.join(out_dir, "tracks")
  22. if not os.path.exists(track_dir):
  23. os.mkdir(track_dir)
  24. for video_id in video_to_images.keys():
  25. video_infos = video_to_images[video_id]
  26. video_name = video_names[video_id]
  27. file_path = os.path.join(track_dir, "{}.txt".format(video_name))
  28. f = open(file_path, "w")
  29. tracks = defaultdict(list)
  30. for video_info in video_infos:
  31. image_id, frame_id = video_info["image_id"], video_info["frame_id"]
  32. result = results[image_id]
  33. for item in result:
  34. if not ("tracking_id" in item):
  35. raise NotImplementedError
  36. tracking_id = item["tracking_id"]
  37. bbox = item["bbox"]
  38. bbox = [bbox[0], bbox[1], bbox[2], bbox[3], item['score'], item['active']]
  39. tracks[tracking_id].append([frame_id] + bbox)
  40. rename_track_id = 0
  41. for track_id in sorted(tracks):
  42. rename_track_id += 1
  43. for t in tracks[track_id]:
  44. if t[6] > 0:
  45. f.write("{},{},{:.2f},{:.2f},{:.2f},{:.2f},-1,-1,-1,-1\n".format(
  46. t[0], rename_track_id, t[1], t[2], t[3] - t[1], t[4] - t[2]))
  47. f.close()