easy_track.h 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /*************************************************************************
  2. * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. /**
  21. * @file easy_track.h
 * This file declares the EasyTrack interface and its two implementations,
 * FeatureMatchTrack and KcfTrack, which track detected objects from frame to frame.
  24. */
  25. #ifndef EASYTRACK_EASY_TRACK_H_
  26. #define EASYTRACK_EASY_TRACK_H_
#include <cstdint>
#include <memory>
#include <ostream>
#include <vector>

#include "cxxutil/exception.h"
#include "easyinfer/model_loader.h"
  31. namespace edk {
  32. /**
  33. * @brief Struct of BoundingBox
  34. */
  35. struct BoundingBox {
  36. float x; ///< Topleft coordinates x
  37. float y; ///< Topleft coordinates y
  38. float width; ///< BoundingBox width
  39. float height; ///< BoundingBox height
  40. };
  41. /**
  42. * @brief Struct of detection objects.
  43. */
  44. struct DetectObject {
  45. /// Object detection label
  46. int label;
  47. /// Object detection confidence rate
  48. float score;
  49. /// Struct BoundingBox
  50. BoundingBox bbox;
  51. /// Object track identification
  52. int track_id;
  53. /// Object index in input vector
  54. int detect_id;
  55. /**
  56. * @brief Features of object extraction.
  57. * @attention The dimension of the feature vector is 128.
  58. */
  59. std::vector<float> feature;
  60. /// internal
  61. mutable float feat_mold;
  62. };
  63. /// Alias of vector stored DetectObject
  64. using Objects = std::vector<DetectObject>;
  65. /**
  66. * @brief Track frame stored frame data and information needed in track
  67. */
  68. struct TrackFrame {
  69. /**
  70. * @brief The data of frame.
  71. * @attention This parameter is used for KcfTrack only.
  72. */
  73. void *data;
  74. /// The width of trackframe
  75. uint32_t width;
  76. /// The height of trackframe
  77. uint32_t height;
  78. /// The identification of trackframe
  79. int64_t frame_id;
  80. /// The identification of device
  81. int device_id;
  82. /**
  83. * @brief Color space enumeration.
  84. */
  85. enum class ColorSpace { GRAY, NV21, NV12, RGB24, BGR24 } format;
  86. /**
  87. * @brief Device type enumeration.
  88. */
  89. enum class DevType {
  90. CPU = 0,
  91. MLU,
  92. } dev_type;
  93. };
  94. /**
  95. * @brief EasyTrack class, help for tracking objects.
  96. */
  97. class EasyTrack {
  98. public:
  99. /**
  100. * @brief Destroy the EasyTrack object.
  101. */
  102. virtual ~EasyTrack() {}
  103. /**
  104. * @brief Update object status and do track
  105. *
  106. * @param frame Track frame
  107. * @param detects Detected objects
  108. * @param tracks Tracked objects
  109. */
  110. virtual void UpdateFrame(const TrackFrame &frame, const Objects &detects, Objects *tracks) noexcept(false) = 0;
  111. }; // class EasyTrack
  112. class FeatureMatchPrivate;
  113. /**
  114. * @brief Track objects based on match feature.
  115. *
  116. * @note Match tentative and featureless objects using IOU,
  117. * and cascade-match confirmed objects using feature cosine distance
  118. */
  119. class FeatureMatchTrack : public EasyTrack {
  120. public:
  121. /**
  122. * @brief Constructor of the FeatureMatchTrack class.
  123. */
  124. FeatureMatchTrack();
  125. /**
  126. * @brief Destroy the FeatureMatchTrack object.
  127. */
  128. ~FeatureMatchTrack();
  129. /**
  130. * @brief Set params related to Tracking algorithm.
  131. *
  132. * @param max_cosine_distance Threshold of cosine distance
  133. * @param nn_budget Tracker only saves the latest [nn_budget] samples of feature for each object
  134. * @param max_iou_distance Threshold of iou distance
  135. * @param max_age Object stay alive for [max_age] after disappeared
  136. * @param n_init After matched [n_init] times in a row, object is turned from TENTATIVE to CONFIRMED
  137. */
  138. void SetParams(float max_cosine_distance, int nn_budget, float max_iou_distance, int max_age, int n_init);
  139. /**
  140. * @brief Update object status and do tracking using cascade matching and IOU matching.
  141. *
  142. * @param frame Track frame
  143. * @param detects Detected objects
  144. * @param tracks Tracked objects
  145. */
  146. void UpdateFrame(const TrackFrame &frame, const Objects &detects, Objects *tracks) override;
  147. private:
  148. FeatureMatchPrivate *fm_p_;
  149. friend class FeatureMatchPrivate;
  150. float max_cosine_distance_ = 0.2;
  151. float max_iou_distance_ = 0.7;
  152. int max_age_ = 30;
  153. int n_init_ = 3;
  154. uint32_t nn_budget_ = 100;
  155. }; // class FeatureMatchTrack
  156. class KcfTrackPrivate;
  157. /**
  158. * @brief Track objects based on KCF
  159. *
  160. * @note Track objects using KCF, and match them using IOU
  161. */
  162. class KcfTrack : public EasyTrack {
  163. public:
  164. /**
  165. * @brief Constructor of the KcfTrack class.
  166. */
  167. KcfTrack();
  168. /**
  169. * @brief Destroy the KcfTrack object.
  170. */
  171. ~KcfTrack();
  172. /**
  173. * @brief Set params related to offline model.
  174. *
  175. * @param model ModelLoader
  176. * @param dev_id the id of device
  177. * @param batch_size Batch size
  178. */
  179. void SetModel(std::shared_ptr<ModelLoader> model, int dev_id = 0, uint32_t batch_size = 1);
  180. /**
  181. * @brief Set params related to KcfTrack.
  182. * @param max_iou_distance Threshold of iou distance
  183. */
  184. void SetParams(float max_iou_distance);
  185. /**
  186. * @brief Update result of objects tracking after kcf and IOU matching.
  187. * @see edk::EasyTrack::UpdateFrame
  188. */
  189. void UpdateFrame(const TrackFrame &frame, const Objects &detects, Objects *tracks) override;
  190. private:
  191. KcfTrackPrivate *kcf_p_;
  192. friend class KcfTrackPrivate;
  193. float max_iou_distance_ = 0.7;
  194. }; // class KcfTrack
  195. /**
  196. * @brief Insert DetectObject into the ostream
  197. *
  198. * @param os output stream to insert data to
  199. * @param obj reference to an DetectObjec to insert
  200. *
  201. * @return reference to output stream
  202. */
  203. inline std::ostream &operator<<(std::ostream &os, const DetectObject &obj) {
  204. os << "[Object] label: " << obj.label << " score: " << obj.score << " track_id: " << obj.track_id << '\t'
  205. << "bbox: " << obj.bbox.x << " " << obj.bbox.y << " " << obj.bbox.width << " " << obj.bbox.height;
  206. return os;
  207. }
  208. } // namespace edk
  209. #endif // EASYTRACK_EASY_TRACK_H_