trajectory.h 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // The code is based on:
  15. // https://github.com/CnybTseng/JDE/blob/master/platforms/common/trajectory.h
  // The copyright of CnybTseng/JDE is as follows:
  17. // MIT License
  18. #pragma once
  19. #include <vector>
  20. #include <opencv2/core/core.hpp>
  21. #include <opencv2/highgui/highgui.hpp>
  22. #include <opencv2/imgproc/imgproc.hpp>
  23. #include "opencv2/video/tracking.hpp"
  24. namespace PaddleDetection {
  // Lifecycle state of a trajectory. Trajectory::mark_lost() and
  // Trajectory::mark_removed() move a trajectory into Lost / Removed.
  // The enumerators are unscoped, so they are referenced unqualified
  // (e.g. `New`, `Lost`).
  enum TrajectoryState { New = 0, Tracked = 1, Lost = 2, Removed = 3 };
  class Trajectory;

  // Collections used by the tracker. A TrajectoryPool stores trajectories by
  // value. A TrajectoryPtrPool stores pointers to trajectories. Both come
  // with their iterator types.
  using TrajectoryPool = std::vector<Trajectory>;
  using TrajectoryPoolIterator = TrajectoryPool::iterator;
  using TrajectoryPtrPool = std::vector<Trajectory *>;
  using TrajectoryPtrPoolIterator = TrajectoryPtrPool::iterator;
  31. class TKalmanFilter : public cv::KalmanFilter {
  32. public:
  33. TKalmanFilter(void);
  34. virtual ~TKalmanFilter(void) {}
  35. virtual void init(const cv::Mat &measurement);
  36. virtual const cv::Mat &predict();
  37. virtual const cv::Mat &correct(const cv::Mat &measurement);
  38. virtual void project(cv::Mat *mean, cv::Mat *covariance) const;
  39. private:
  40. float std_weight_position;
  41. float std_weight_velocity;
  42. };
  43. inline TKalmanFilter::TKalmanFilter(void) : cv::KalmanFilter(8, 4) {
  44. cv::KalmanFilter::transitionMatrix = cv::Mat::eye(8, 8, CV_32F);
  45. for (int i = 0; i < 4; ++i)
  46. cv::KalmanFilter::transitionMatrix.at<float>(i, i + 4) = 1;
  47. cv::KalmanFilter::measurementMatrix = cv::Mat::eye(4, 8, CV_32F);
  48. std_weight_position = 1 / 20.f;
  49. std_weight_velocity = 1 / 160.f;
  50. }
  51. class Trajectory : public TKalmanFilter {
  52. public:
  53. Trajectory();
  54. Trajectory(const cv::Vec4f &ltrb, float score, const cv::Mat &embedding);
  55. Trajectory(const Trajectory &other);
  56. Trajectory &operator=(const Trajectory &rhs);
  57. virtual ~Trajectory(void) {}
  58. static int next_id();
  59. virtual const cv::Mat &predict(void);
  60. virtual void update(Trajectory *traj,
  61. int timestamp,
  62. bool update_embedding = true);
  63. virtual void activate(int timestamp);
  64. virtual void reactivate(Trajectory *traj, int timestamp, bool newid = false);
  65. virtual void mark_lost(void);
  66. virtual void mark_removed(void);
  67. friend TrajectoryPool operator+(const TrajectoryPool &a,
  68. const TrajectoryPool &b);
  69. friend TrajectoryPool operator+(const TrajectoryPool &a,
  70. const TrajectoryPtrPool &b);
  71. friend TrajectoryPool &operator+=(TrajectoryPool &a, // NOLINT
  72. const TrajectoryPtrPool &b);
  73. friend TrajectoryPool operator-(const TrajectoryPool &a,
  74. const TrajectoryPool &b);
  75. friend TrajectoryPool &operator-=(TrajectoryPool &a, // NOLINT
  76. const TrajectoryPool &b);
  77. friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
  78. const TrajectoryPtrPool &b);
  79. friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
  80. TrajectoryPool *b);
  81. friend TrajectoryPtrPool operator-(const TrajectoryPtrPool &a,
  82. const TrajectoryPtrPool &b);
  83. friend cv::Mat embedding_distance(const TrajectoryPool &a,
  84. const TrajectoryPool &b);
  85. friend cv::Mat embedding_distance(const TrajectoryPtrPool &a,
  86. const TrajectoryPtrPool &b);
  87. friend cv::Mat embedding_distance(const TrajectoryPtrPool &a,
  88. const TrajectoryPool &b);
  89. friend cv::Mat mahalanobis_distance(const TrajectoryPool &a,
  90. const TrajectoryPool &b);
  91. friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
  92. const TrajectoryPtrPool &b);
  93. friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
  94. const TrajectoryPool &b);
  95. friend cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b);
  96. friend cv::Mat iou_distance(const TrajectoryPtrPool &a,
  97. const TrajectoryPtrPool &b);
  98. friend cv::Mat iou_distance(const TrajectoryPtrPool &a,
  99. const TrajectoryPool &b);
  100. private:
  101. void update_embedding(const cv::Mat &embedding);
  102. public:
  103. TrajectoryState state;
  104. cv::Vec4f ltrb;
  105. cv::Mat smooth_embedding;
  106. int id;
  107. bool is_activated;
  108. int timestamp;
  109. int starttime;
  110. float score;
  111. private:
  112. static int count;
  113. cv::Vec4f xyah;
  114. cv::Mat current_embedding;
  115. float eta;
  116. int length;
  117. };
  118. inline cv::Vec4f ltrb2xyah(const cv::Vec4f &ltrb) {
  119. cv::Vec4f xyah;
  120. xyah[0] = (ltrb[0] + ltrb[2]) * 0.5f;
  121. xyah[1] = (ltrb[1] + ltrb[3]) * 0.5f;
  122. xyah[3] = ltrb[3] - ltrb[1];
  123. xyah[2] = (ltrb[2] - ltrb[0]) / xyah[3];
  124. return xyah;
  125. }
  126. inline Trajectory::Trajectory()
  127. : state(New),
  128. ltrb(cv::Vec4f()),
  129. smooth_embedding(cv::Mat()),
  130. id(0),
  131. is_activated(false),
  132. timestamp(0),
  133. starttime(0),
  134. score(0),
  135. eta(0.9),
  136. length(0) {}
  137. inline Trajectory::Trajectory(const cv::Vec4f &ltrb_,
  138. float score_,
  139. const cv::Mat &embedding)
  140. : state(New),
  141. ltrb(ltrb_),
  142. smooth_embedding(cv::Mat()),
  143. id(0),
  144. is_activated(false),
  145. timestamp(0),
  146. starttime(0),
  147. score(score_),
  148. eta(0.9),
  149. length(0) {
  150. xyah = ltrb2xyah(ltrb);
  151. update_embedding(embedding);
  152. }
  153. inline Trajectory::Trajectory(const Trajectory &other)
  154. : state(other.state),
  155. ltrb(other.ltrb),
  156. id(other.id),
  157. is_activated(other.is_activated),
  158. timestamp(other.timestamp),
  159. starttime(other.starttime),
  160. xyah(other.xyah),
  161. score(other.score),
  162. eta(other.eta),
  163. length(other.length) {
  164. other.smooth_embedding.copyTo(smooth_embedding);
  165. other.current_embedding.copyTo(current_embedding);
  166. // copy state in KalmanFilter
  167. other.statePre.copyTo(cv::KalmanFilter::statePre);
  168. other.statePost.copyTo(cv::KalmanFilter::statePost);
  169. other.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
  170. other.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
  171. }
  172. inline Trajectory &Trajectory::operator=(const Trajectory &rhs) {
  173. this->state = rhs.state;
  174. this->ltrb = rhs.ltrb;
  175. rhs.smooth_embedding.copyTo(this->smooth_embedding);
  176. this->id = rhs.id;
  177. this->is_activated = rhs.is_activated;
  178. this->timestamp = rhs.timestamp;
  179. this->starttime = rhs.starttime;
  180. this->xyah = rhs.xyah;
  181. this->score = rhs.score;
  182. rhs.current_embedding.copyTo(this->current_embedding);
  183. this->eta = rhs.eta;
  184. this->length = rhs.length;
  185. // copy state in KalmanFilter
  186. rhs.statePre.copyTo(cv::KalmanFilter::statePre);
  187. rhs.statePost.copyTo(cv::KalmanFilter::statePost);
  188. rhs.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
  189. rhs.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
  190. return *this;
  191. }
  192. inline int Trajectory::next_id() {
  193. ++count;
  194. return count;
  195. }
  196. inline void Trajectory::mark_lost(void) { state = Lost; }
  197. inline void Trajectory::mark_removed(void) { state = Removed; }
  198. } // namespace PaddleDetection