// postprocess_vehicle_cts.cpp
/*************************************************************************
 * Copyright (C) [2021] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/
#include <algorithm>
#include <array>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cnstream_frame_va.hpp"
#include "cnstream_logging.hpp"
#include "postproc.hpp"
  30. /*
  31. * @brief
  32. * Postprocessing for model cnstream/data/models/vehicle_cts_b4c4_bgra_mlu270.cambricon
  33. **/
  34. class PostprocVehicleCts : public cnstream::ObjPostproc {
  35. public:
  36. int Execute(const std::vector<float*>& net_outputs, const std::shared_ptr<edk::ModelLoader>& model,
  37. const cnstream::CNFrameInfoPtr& finfo, const std::shared_ptr<cnstream::CNInferObject>& obj) override;
  38. DECLARE_REFLEX_OBJECT_EX(PostprocVehicleCts, cnstream::ObjPostproc)
  39. }; // classd ObjPostprocClassification
  40. IMPLEMENT_REFLEX_OBJECT_EX(PostprocVehicleCts, cnstream::ObjPostproc)
  41. int PostprocVehicleCts::Execute(const std::vector<float*>& net_outputs,
  42. const std::shared_ptr<edk::ModelLoader>& model,
  43. const cnstream::CNFrameInfoPtr& finfo,
  44. const std::shared_ptr<cnstream::CNInferObject>& obj) {
  45. static const std::array<std::string, 3> category_names = {"COLOR", "TYPE", "TOWARDS"};
  46. static const std::array<std::vector<std::string>, 3> categories = {
  47. std::vector<std::string>({ /* colors */
  48. "BROWN", "DARK_GREY", "GREY", "WHITE", "PINK", "PURPLE",
  49. "RED", "GREEN", "BLUE", "GOLD", "CYAN", "YELLOW", "BLACK"
  50. }),
  51. std::vector<std::string>({ /* types */
  52. "MPV", "MEGA_BUS", "HGV", "MINI_BUS", "COMPACT_VAN", "MINI_VAN",
  53. "PICKUP", "SUV", "LIGHT_BUS", "CAR"
  54. }),
  55. std::vector<std::string>({ /* sides */
  56. "BACK", "FRONT", "SIDE", "BACK_LEFT", "BACK_RIGHT", "FRONT_LEFT",
  57. "FRONT_RIGHT"
  58. }),
  59. };
  60. bool check_model = true;
  61. if (model->OutputNum() != categories.size()) {
  62. check_model = false;
  63. } else {
  64. for (uint32_t output_idx = 0; output_idx < model->OutputNum(); ++output_idx) {
  65. if (static_cast<size_t>(model->OutputShape(output_idx).DataCount()) != categories[output_idx].size()) {
  66. check_model = false;
  67. break;
  68. }
  69. }
  70. }
  71. if (!check_model)
  72. LOGF(POSTPROC_VEHICLE_CTS) << "Model mismatched.";
  73. auto ArgMax = [] (float* data, size_t size) {
  74. return std::distance(data, std::max_element(data, data + size));
  75. };
  76. for (uint32_t output_idx = 0; output_idx < model->OutputNum(); ++output_idx) {
  77. float* net_output = net_outputs[output_idx];
  78. auto max_score_idx = ArgMax(net_output, model->OutputShape(output_idx).DataCount());
  79. if (net_output[max_score_idx] < 0.3) {
  80. obj->AddExtraAttribute(category_names[output_idx], "uncertain");
  81. } else {
  82. std::string score_str = std::to_string(net_output[max_score_idx]);
  83. score_str = score_str.substr(0, std::min(size_t(4), score_str.size()));
  84. std::string str = categories[output_idx][max_score_idx] +
  85. " score[" + score_str + "]";
  86. obj->AddExtraAttribute(category_names[output_idx], str);
  87. }
  88. }
  89. return 0;
  90. }