postprocess_mobilenet_ssd_plate_detection.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. /*************************************************************************
  2. * Copyright (C) [2021] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "cnstream_frame_va.hpp"
#include "postproc.hpp"
  29. class PostprocMSSDPlateDetection : public cnstream::ObjPostproc {
  30. public:
  31. /**
  32. * @brief Execute postproc on neural ssd network outputs
  33. *
  34. * @param net_outputs: neural network outputs
  35. * @param model: model information(you can get input shape and output shape from model)
  36. * @param package: smart pointer of struct to store processed result
  37. * @param obj: the object to be processed
  38. *
  39. * @return return 0 if succeed
  40. */
  41. int Execute(const std::vector<float*>& net_outputs, const std::shared_ptr<edk::ModelLoader>& model,
  42. const cnstream::CNFrameInfoPtr& package,
  43. const std::shared_ptr<cnstream::CNInferObject>& obj) override;
  44. DECLARE_REFLEX_OBJECT_EX(PostprocMSSDPlateDetection, cnstream::ObjPostproc)
  45. }; // class PostprocMSSDPlateDetection
  46. IMPLEMENT_REFLEX_OBJECT_EX(PostprocMSSDPlateDetection, cnstream::ObjPostproc)
  47. #define CLIP(x) ((x) < 0 ? 0 : ((x) > 1 ? 1 : (x)))
  48. int PostprocMSSDPlateDetection::Execute(const std::vector<float*>& net_outputs,
  49. const std::shared_ptr<edk::ModelLoader>& model,
  50. const cnstream::CNFrameInfoPtr& package,
  51. const std::shared_ptr<cnstream::CNInferObject>& obj) {
  52. auto data = net_outputs[0];
  53. auto box_num = data[0];
  54. data += 64;
  55. // find the plate with the highest score
  56. float max_score = -1.0f;
  57. cnstream::CNInferBoundingBox selected_bbox;
  58. for (decltype(box_num) bi = 0; bi < box_num; ++bi) {
  59. float cur_score = data[2];
  60. if (cur_score > max_score) {
  61. max_score = cur_score;
  62. selected_bbox.x = data[3];
  63. selected_bbox.y = data[4];
  64. selected_bbox.w = data[5] - selected_bbox.x;
  65. selected_bbox.h = data[6] - selected_bbox.y;
  66. }
  67. data += 7;
  68. }
  69. if (max_score < threshold_) return 0; // no plate found
  70. // coordinates to the original image
  71. const auto& vehicle_bbox = obj->bbox;
  72. selected_bbox.x = selected_bbox.x * vehicle_bbox.w + vehicle_bbox.x;
  73. selected_bbox.y = selected_bbox.y * vehicle_bbox.h + vehicle_bbox.y;
  74. selected_bbox.w = selected_bbox.w * vehicle_bbox.w;
  75. selected_bbox.h = selected_bbox.h * vehicle_bbox.h;
  76. selected_bbox.x = CLIP(selected_bbox.x);
  77. selected_bbox.y = CLIP(selected_bbox.y);
  78. selected_bbox.w = std::min(1 - selected_bbox.x, selected_bbox.w);
  79. selected_bbox.h = std::min(1 - selected_bbox.y, selected_bbox.h);
  80. if (selected_bbox.w <= 0.0f || selected_bbox.h <= 0.0f) return 0;
  81. std::shared_ptr<cnstream::CNInferObject> plate_object = std::make_shared<cnstream::CNInferObject>();
  82. plate_object->id = "80"; // the label index in CNStream/data/models/label_map_coco_add_license_plate.txt
  83. plate_object->score = max_score;
  84. plate_object->bbox = selected_bbox;
  85. // plate flag is used by PlateFilter
  86. // see CNStream/samples/common/obj_filter/plate_filter.cpp
  87. plate_object->collection.Add("plate_flag", true);
  88. // in order to facilitate the addition of the recognized license plate to the vehicle attributes.
  89. // see CNStream/samples/common/postprocess/postprocess_lprnet.cpp
  90. plate_object->collection.Add("plate_container", obj);
  91. cnstream::CNInferObjsPtr objs_holder = package->collection.Get<cnstream::CNInferObjsPtr>(cnstream::kCNInferObjsTag);
  92. cnstream::CNObjsVec& objs = objs_holder->objs_;
  93. std::lock_guard<std::mutex> lk(objs_holder->mutex_);
  94. objs.push_back(plate_object);
  95. return 0;
  96. }