config_parser.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include <fstream>
  16. #include <iostream>
  17. #include <map>
  18. #include <string>
  19. #include <vector>
  20. #include "json/json.h"
  21. #ifdef _WIN32
  22. #define OS_PATH_SEP "\\"
  23. #else
  24. #define OS_PATH_SEP "/"
  25. #endif
  26. namespace PaddleDetection {
  27. void load_jsonf(std::string jsonfile, Json::Value& jsondata);
  28. // Inference model configuration parser
  29. class ConfigPaser {
  30. public:
  31. ConfigPaser() {}
  32. ~ConfigPaser() {}
  33. bool load_config(const std::string& model_dir,
  34. const std::string& cfg = "infer_cfg") {
  35. Json::Value config;
  36. load_jsonf(model_dir + OS_PATH_SEP + cfg + ".json", config);
  37. // Get model arch : YOLO, SSD, RetinaNet, RCNN, Face, PicoDet, HRNet
  38. if (config.isMember("arch")) {
  39. arch_ = config["arch"].as<std::string>();
  40. } else {
  41. std::cerr
  42. << "Please set model arch,"
  43. << "support value : YOLO, SSD, RetinaNet, RCNN, Face, PicoDet, HRNet."
  44. << std::endl;
  45. return false;
  46. }
  47. // Get draw_threshold for visualization
  48. if (config.isMember("draw_threshold")) {
  49. draw_threshold_ = config["draw_threshold"].as<float>();
  50. } else {
  51. std::cerr << "Please set draw_threshold." << std::endl;
  52. return false;
  53. }
  54. // Get Preprocess for preprocessing
  55. if (config.isMember("Preprocess")) {
  56. preprocess_info_ = config["Preprocess"];
  57. } else {
  58. std::cerr << "Please set Preprocess." << std::endl;
  59. return false;
  60. }
  61. // Get label_list for visualization
  62. if (config.isMember("label_list")) {
  63. label_list_.clear();
  64. for (auto item : config["label_list"]) {
  65. label_list_.emplace_back(item.as<std::string>());
  66. }
  67. } else {
  68. std::cerr << "Please set label_list." << std::endl;
  69. return false;
  70. }
  71. // Get NMS for postprocess
  72. if (config.isMember("NMS")) {
  73. nms_info_ = config["NMS"];
  74. }
  75. // Get fpn_stride in PicoDet
  76. if (config.isMember("fpn_stride")) {
  77. fpn_stride_.clear();
  78. for (auto item : config["fpn_stride"]) {
  79. fpn_stride_.emplace_back(item.as<int>());
  80. }
  81. }
  82. return true;
  83. }
  84. float draw_threshold_;
  85. std::string arch_;
  86. Json::Value preprocess_info_;
  87. Json::Value nms_info_;
  88. std::vector<std::string> label_list_;
  89. std::vector<int> fpn_stride_;
  90. };
  91. } // namespace PaddleDetection