keypoint_detector.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include <ctime>
  16. #include <memory>
  17. #include <string>
  18. #include <utility>
  19. #include <vector>
  20. #include <opencv2/core/core.hpp>
  21. #include <opencv2/highgui/highgui.hpp>
  22. #include <opencv2/imgproc/imgproc.hpp>
  23. #include "paddle_api.h" // NOLINT
  24. #include "include/config_parser.h"
  25. #include "include/keypoint_postprocess.h"
  26. #include "include/preprocess_op.h"
  27. using namespace paddle::lite_api; // NOLINT
  28. namespace PaddleDetection {
  // Keypoint output for one detected object.
  struct KeyPointResult {
    // Flattened joint data: one [x, y, conf] triple per joint, so the buffer
    // is expected to hold num_joints * 3 floats.
    std::vector<float> keypoints;
    // Number of joints described by `keypoints`. Defaults to -1
    // (presumably meaning "not filled in yet" -- confirm with producers).
    int num_joints{-1};
  };
  35. // Visualiztion KeyPoint Result
  36. cv::Mat VisualizeKptsResult(const cv::Mat& img,
  37. const std::vector<KeyPointResult>& results,
  38. const std::vector<int>& colormap,
  39. float threshold = 0.2);
  40. class KeyPointDetector {
  41. public:
  42. explicit KeyPointDetector(const std::string& model_dir,
  43. int cpu_threads = 1,
  44. const int batch_size = 1,
  45. bool use_dark = true) {
  46. config_.load_config(model_dir);
  47. threshold_ = config_.draw_threshold_;
  48. use_dark_ = use_dark;
  49. preprocessor_.Init(config_.preprocess_info_);
  50. printf("before keypoint detector\n");
  51. LoadModel(model_dir, cpu_threads);
  52. printf("create keypoint detector\n");
  53. }
  54. // Load Paddle inference model
  55. void LoadModel(std::string model_file, int num_theads);
  56. // Run predictor
  57. void Predict(const std::vector<cv::Mat> imgs,
  58. std::vector<std::vector<float>>& center,
  59. std::vector<std::vector<float>>& scale,
  60. const int warmup = 0,
  61. const int repeats = 1,
  62. std::vector<KeyPointResult>* result = nullptr,
  63. std::vector<double>* times = nullptr);
  64. // Get Model Label list
  65. const std::vector<std::string>& GetLabelList() const {
  66. return config_.label_list_;
  67. }
  68. bool use_dark(){return this->use_dark_;}
  69. inline float get_threshold() {return threshold_;};
  70. private:
  71. // Preprocess image and copy data to input buffer
  72. void Preprocess(const cv::Mat& image_mat);
  73. // Postprocess result
  74. void Postprocess(std::vector<float>& output,
  75. std::vector<int64_t>& output_shape,
  76. std::vector<int64_t>& idxout,
  77. std::vector<int64_t>& idx_shape,
  78. std::vector<KeyPointResult>* result,
  79. std::vector<std::vector<float>>& center,
  80. std::vector<std::vector<float>>& scale);
  81. std::shared_ptr<PaddlePredictor> predictor_;
  82. Preprocessor preprocessor_;
  83. ImageBlob inputs_;
  84. std::vector<float> output_data_;
  85. std::vector<int64_t> idx_data_;
  86. float threshold_;
  87. ConfigPaser config_;
  88. bool use_dark_;
  89. };
  90. } // namespace PaddleDetection