// BYTETracker.h
  #pragma once
  #include <climits>
  #include <vector>
  #include "STrack.h"
  3. struct Object
  4. {
  5. cv::Rect_<float> rect;
  6. int label;
  7. float prob;
  8. };
  9. class BYTETracker
  10. {
  11. public:
  12. BYTETracker(int frame_rate = 30, int track_buffer = 30);
  13. ~BYTETracker();
  14. vector<STrack> update(const vector<Object>& objects);
  15. Scalar get_color(int idx);
  16. private:
  17. vector<STrack*> joint_stracks(vector<STrack*> &tlista, vector<STrack> &tlistb);
  18. vector<STrack> joint_stracks(vector<STrack> &tlista, vector<STrack> &tlistb);
  19. vector<STrack> sub_stracks(vector<STrack> &tlista, vector<STrack> &tlistb);
  20. void remove_duplicate_stracks(vector<STrack> &resa, vector<STrack> &resb, vector<STrack> &stracksa, vector<STrack> &stracksb);
  21. void linear_assignment(vector<vector<float> > &cost_matrix, int cost_matrix_size, int cost_matrix_size_size, float thresh,
  22. vector<vector<int> > &matches, vector<int> &unmatched_a, vector<int> &unmatched_b);
  23. vector<vector<float> > iou_distance(vector<STrack*> &atracks, vector<STrack> &btracks, int &dist_size, int &dist_size_size);
  24. vector<vector<float> > iou_distance(vector<STrack> &atracks, vector<STrack> &btracks);
  25. vector<vector<float> > ious(vector<vector<float> > &atlbrs, vector<vector<float> > &btlbrs);
  26. double lapjv(const vector<vector<float> > &cost, vector<int> &rowsol, vector<int> &colsol,
  27. bool extend_cost = false, float cost_limit = LONG_MAX, bool return_cost = true);
  28. private:
  29. float track_thresh;
  30. float high_thresh;
  31. float match_thresh;
  32. int frame_id;
  33. int max_time_lost;
  34. vector<STrack> tracked_stracks;
  35. vector<STrack> lost_stracks;
  36. vector<STrack> removed_stracks;
  37. byte_kalman::KalmanFilter kalman_filter;
  38. };