// object_detector.h

  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include <ctime>
  16. #include <memory>
  17. #include <string>
  18. #include <utility>
  19. #include <vector>
  20. #include <opencv2/core/core.hpp>
  21. #include <opencv2/highgui/highgui.hpp>
  22. #include <opencv2/imgproc/imgproc.hpp>
  23. #include "paddle_api.h" // NOLINT
  24. #include "include/config_parser.h"
  25. #include "include/preprocess_op.h"
  26. #include "include/utils.h"
  27. #include "include/picodet_postprocess.h"
  28. using namespace paddle::lite_api; // NOLINT
  29. namespace PaddleDetection {
  // Produces the colormap used to visualize detections, sized for
  // `num_class` classes.
  //
  // NOTE(review): the layout of the returned vector (how many ints are
  // stored per class, and in which channel order) is decided by the
  // implementation, which is not part of this header -- confirm there.
  std::vector<int> GenerateColorMap(int num_class);
  32. // Visualiztion Detection Result
  33. cv::Mat VisualizeResult(const cv::Mat& img,
  34. const std::vector<PaddleDetection::ObjectResult>& results,
  35. const std::vector<std::string>& lables,
  36. const std::vector<int>& colormap,
  37. const bool is_rbox);
  38. class ObjectDetector {
  39. public:
  40. explicit ObjectDetector(const std::string& model_dir,
  41. int cpu_threads = 1,
  42. const int batch_size = 1) {
  43. config_.load_config(model_dir);
  44. printf("config created\n");
  45. threshold_ = config_.draw_threshold_;
  46. preprocessor_.Init(config_.preprocess_info_);
  47. printf("before object detector\n");
  48. LoadModel(model_dir, cpu_threads);
  49. printf("create object detector\n");
  50. }
  51. // Load Paddle inference model
  52. void LoadModel(std::string model_file, int num_theads);
  53. // Run predictor
  54. void Predict(const std::vector<cv::Mat>& imgs,
  55. const double threshold = 0.5,
  56. const int warmup = 0,
  57. const int repeats = 1,
  58. std::vector<PaddleDetection::ObjectResult>* result = nullptr,
  59. std::vector<int>* bbox_num = nullptr,
  60. std::vector<double>* times = nullptr);
  61. // Get Model Label list
  62. const std::vector<std::string>& GetLabelList() const {
  63. return config_.label_list_;
  64. }
  65. private:
  66. // Preprocess image and copy data to input buffer
  67. void Preprocess(const cv::Mat& image_mat);
  68. // Postprocess result
  69. void Postprocess(const std::vector<cv::Mat> mats,
  70. std::vector<PaddleDetection::ObjectResult>* result,
  71. std::vector<int> bbox_num,
  72. bool is_rbox);
  73. std::shared_ptr<PaddlePredictor> predictor_;
  74. Preprocessor preprocessor_;
  75. ImageBlob inputs_;
  76. std::vector<float> output_data_;
  77. std::vector<int> out_bbox_num_data_;
  78. float threshold_;
  79. ConfigPaser config_;
  80. };
  81. } // namespace PaddleDetection