BYTETracker.cpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  #include "BYTETracker.h"

  #include <cstddef>
  2. BYTETracker::BYTETracker(int frame_rate, int track_buffer) {
  3. track_thresh = 0;
  4. high_thresh = 0.2;
  5. match_thresh = 0.8;
  6. frame_id = 0;
  7. max_time_lost = int(frame_rate / 30.0 * track_buffer);
  8. }
  9. BYTETracker::~BYTETracker() {
  10. }
  11. vector<STrack> BYTETracker::update(const vector<NvObject> &nvObjects) {
  12. ////////////////// Step 1: Get detections //////////////////
  13. this->frame_id++;
  14. vector<STrack> activated_stracks;
  15. vector<STrack> refind_stracks;
  16. vector<STrack> removed_stracks;
  17. vector<STrack> lost_stracks;
  18. vector<STrack> detections;
  19. vector<STrack> detections_low;
  20. vector<STrack> detections_cp;
  21. vector<STrack> tracked_stracks_swap;
  22. vector<STrack> resa, resb;
  23. vector<STrack> output_stracks;
  24. vector<STrack *> unconfirmed;
  25. vector<STrack *> tracked_stracks;
  26. vector<STrack *> strack_pool;
  27. vector<STrack *> r_tracked_stracks;
  28. if (nvObjects.size() > 0) {
  29. for (int i = 0; i < nvObjects.size(); i++) {
  30. vector<float> tlwh_;
  31. tlwh_.resize(4);
  32. tlwh_[0] = nvObjects[i].rect[0];
  33. tlwh_[1] = nvObjects[i].rect[1];
  34. tlwh_[2] = nvObjects[i].rect[2];
  35. tlwh_[3] = nvObjects[i].rect[3];
  36. float score = nvObjects[i].prob;
  37. STrack strack(tlwh_, score, nvObjects[i].label, nvObjects[i].associatedObjectIn);
  38. if (score >= track_thresh) {
  39. detections.push_back(strack);
  40. } else {
  41. detections_low.push_back(strack);
  42. }
  43. }
  44. }
  45. // Add newly detected tracklets to tracked_stracks
  46. for (int i = 0; i < this->tracked_stracks.size(); i++) {
  47. if (!this->tracked_stracks[i].is_activated)
  48. unconfirmed.push_back(&this->tracked_stracks[i]);
  49. else
  50. tracked_stracks.push_back(&this->tracked_stracks[i]);
  51. }
  52. ////////////////// Step 2: First association, with IoU //////////////////
  53. strack_pool = joint_stracks(tracked_stracks, this->lost_stracks);
  54. STrack::multi_predict(strack_pool, this->kalman_filter);
  55. vector<vector<float>> dists;
  56. int dist_size = 0, dist_size_size = 0;
  57. dists = iou_distance(strack_pool, detections, dist_size, dist_size_size);
  58. vector<vector<int>> matches;
  59. vector<int> u_track, u_detection;
  60. linear_assignment(dists, dist_size, dist_size_size, match_thresh, matches, u_track, u_detection);
  61. for (int i = 0; i < matches.size(); i++) {
  62. STrack *track = strack_pool[matches[i][0]];
  63. STrack *det = &detections[matches[i][1]];
  64. if (track->state == TrackState::Tracked) {
  65. track->update(*det, this->frame_id);
  66. activated_stracks.push_back(*track);
  67. } else {
  68. track->re_activate(*det, this->frame_id, false);
  69. refind_stracks.push_back(*track);
  70. }
  71. }
  72. ////////////////// Step 3: Second association, using low score dets //////////////////
  73. for (int i = 0; i < u_detection.size(); i++) {
  74. detections_cp.push_back(detections[u_detection[i]]);
  75. }
  76. detections.clear();
  77. detections.assign(detections_low.begin(), detections_low.end());
  78. for (int i = 0; i < u_track.size(); i++) {
  79. if (strack_pool[u_track[i]]->state == TrackState::Tracked) {
  80. r_tracked_stracks.push_back(strack_pool[u_track[i]]);
  81. }
  82. }
  83. dists.clear();
  84. dists = iou_distance(r_tracked_stracks, detections, dist_size, dist_size_size);
  85. matches.clear();
  86. u_track.clear();
  87. u_detection.clear();
  88. linear_assignment(dists, dist_size, dist_size_size, 0.5, matches, u_track, u_detection);
  89. for (int i = 0; i < matches.size(); i++) {
  90. STrack *track = r_tracked_stracks[matches[i][0]];
  91. STrack *det = &detections[matches[i][1]];
  92. if (track->state == TrackState::Tracked) {
  93. track->update(*det, this->frame_id);
  94. activated_stracks.push_back(*track);
  95. } else {
  96. track->re_activate(*det, this->frame_id, false);
  97. refind_stracks.push_back(*track);
  98. }
  99. }
  100. for (int i = 0; i < u_track.size(); i++) {
  101. STrack *track = r_tracked_stracks[u_track[i]];
  102. if (track->state != TrackState::Lost) {
  103. track->mark_lost();
  104. lost_stracks.push_back(*track);
  105. }
  106. }
  107. // Deal with unconfirmed tracks, usually tracks with only one beginning frame
  108. detections.clear();
  109. detections.assign(detections_cp.begin(), detections_cp.end());
  110. dists.clear();
  111. dists = iou_distance(unconfirmed, detections, dist_size, dist_size_size);
  112. matches.clear();
  113. vector<int> u_unconfirmed;
  114. u_detection.clear();
  115. linear_assignment(dists, dist_size, dist_size_size, 0.7, matches, u_unconfirmed, u_detection);
  116. for (int i = 0; i < matches.size(); i++) {
  117. unconfirmed[matches[i][0]]->update(detections[matches[i][1]], this->frame_id);
  118. activated_stracks.push_back(*unconfirmed[matches[i][0]]);
  119. }
  120. for (int i = 0; i < u_unconfirmed.size(); i++) {
  121. STrack *track = unconfirmed[u_unconfirmed[i]];
  122. track->mark_removed();
  123. removed_stracks.push_back(*track);
  124. }
  125. ////////////////// Step 4: Init new stracks //////////////////
  126. for (int i = 0; i < u_detection.size(); i++) {
  127. STrack *track = &detections[u_detection[i]];
  128. if (track->score < this->high_thresh)
  129. continue;
  130. track->activate(this->kalman_filter, this->frame_id);
  131. activated_stracks.push_back(*track);
  132. }
  133. ////////////////// Step 5: Update state //////////////////
  134. for (int i = 0; i < this->lost_stracks.size(); i++) {
  135. if (this->frame_id - this->lost_stracks[i].end_frame() > this->max_time_lost) {
  136. this->lost_stracks[i].mark_removed();
  137. removed_stracks.push_back(this->lost_stracks[i]);
  138. }
  139. }
  140. for (int i = 0; i < this->tracked_stracks.size(); i++) {
  141. if (this->tracked_stracks[i].state == TrackState::Tracked) {
  142. tracked_stracks_swap.push_back(this->tracked_stracks[i]);
  143. }
  144. }
  145. this->tracked_stracks.clear();
  146. this->tracked_stracks.assign(tracked_stracks_swap.begin(), tracked_stracks_swap.end());
  147. this->tracked_stracks = joint_stracks(this->tracked_stracks, activated_stracks);
  148. this->tracked_stracks = joint_stracks(this->tracked_stracks, refind_stracks);
  149. //std::cout << activated_stracks.size() << std::endl;
  150. this->lost_stracks = sub_stracks(this->lost_stracks, this->tracked_stracks);
  151. for (int i = 0; i < lost_stracks.size(); i++) {
  152. this->lost_stracks.push_back(lost_stracks[i]);
  153. }
  154. this->lost_stracks = sub_stracks(this->lost_stracks, this->removed_stracks);
  155. for (int i = 0; i < removed_stracks.size(); i++) {
  156. this->removed_stracks.push_back(removed_stracks[i]);
  157. }
  158. remove_duplicate_stracks(resa, resb, this->tracked_stracks, this->lost_stracks);
  159. this->tracked_stracks.clear();
  160. this->tracked_stracks.assign(resa.begin(), resa.end());
  161. this->lost_stracks.clear();
  162. this->lost_stracks.assign(resb.begin(), resb.end());
  163. for (int i = 0; i < this->tracked_stracks.size(); i++) {
  164. if (this->tracked_stracks[i].is_activated) {
  165. output_stracks.push_back(this->tracked_stracks[i]);
  166. }
  167. }
  168. // clean up old objects
  169. vector<STrack > filtered_output_stracks;
  170. std::copy_if(output_stracks.begin(),
  171. output_stracks.end(),
  172. std::back_inserter(filtered_output_stracks),
  173. [](STrack track) {
  174. return track.associatedObjectIn != NULL &&
  175. track.associatedObjectIn->classId == 0;
  176. });
  177. return filtered_output_stracks;
  178. }