postprocess_classification.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. /*************************************************************************
  2. * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  #include <algorithm>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <utility>
  #include <vector>
  #include "cnstream_frame_va.hpp"
  #include "postproc.hpp"
  #include "cnstream_logging.hpp"
  27. class PostprocClassification : public cnstream::Postproc {
  28. public:
  29. int Execute(const std::vector<float*>& net_outputs, const std::shared_ptr<edk::ModelLoader>& model,
  30. const cnstream::CNFrameInfoPtr& package) override;
  31. DECLARE_REFLEX_OBJECT_EX(PostprocClassification, cnstream::Postproc)
  32. }; // classd PostprocClassification
  33. IMPLEMENT_REFLEX_OBJECT_EX(PostprocClassification, cnstream::Postproc)
  34. int PostprocClassification::Execute(const std::vector<float*>& net_outputs,
  35. const std::shared_ptr<edk::ModelLoader>& model,
  36. const cnstream::CNFrameInfoPtr& package) {
  37. if (net_outputs.size() != 1) {
  38. LOGE(DEMO) << "[Warning] classification neuron network only has one output,"
  39. " but get " +
  40. std::to_string(net_outputs.size());
  41. return -1;
  42. }
  43. auto data = net_outputs[0];
  44. auto len = model->OutputShape(0).DataCount();
  45. auto pscore = data;
  46. float mscore = 0;
  47. int label = 0;
  48. for (decltype(len) i = 0; i < len; ++i) {
  49. auto score = *(pscore + i);
  50. if (score > mscore) {
  51. mscore = score;
  52. label = i;
  53. }
  54. }
  55. auto obj = std::make_shared<cnstream::CNInferObject>();
  56. obj->id = std::to_string(label);
  57. obj->score = mscore;
  58. cnstream::CNInferObjsPtr objs_holder = package->collection.Get<cnstream::CNInferObjsPtr>(cnstream::kCNInferObjsTag);
  59. std::lock_guard<std::mutex> objs_mutex(objs_holder->mutex_);
  60. objs_holder->objs_.push_back(obj);
  61. return 0;
  62. }
  63. class ObjPostprocClassification : public cnstream::ObjPostproc {
  64. public:
  65. int Execute(const std::vector<float*>& net_outputs, const std::shared_ptr<edk::ModelLoader>& model,
  66. const cnstream::CNFrameInfoPtr& finfo, const std::shared_ptr<cnstream::CNInferObject>& obj) override;
  67. DECLARE_REFLEX_OBJECT_EX(ObjPostprocClassification, cnstream::ObjPostproc)
  68. }; // classd ObjPostprocClassification
  69. IMPLEMENT_REFLEX_OBJECT_EX(ObjPostprocClassification, cnstream::ObjPostproc)
  70. int ObjPostprocClassification::Execute(const std::vector<float*>& net_outputs,
  71. const std::shared_ptr<edk::ModelLoader>& model,
  72. const cnstream::CNFrameInfoPtr& finfo,
  73. const std::shared_ptr<cnstream::CNInferObject>& obj) {
  74. if (net_outputs.size() != 1) {
  75. LOGE(DEMO) << "[Warning] classification neuron network only has one output,"
  76. " but get " + std::to_string(net_outputs.size());
  77. return -1;
  78. }
  79. auto data = net_outputs[0];
  80. auto len = model->OutputShape(0).DataCount();
  81. auto pscore = data;
  82. float mscore = 0;
  83. int label = 0;
  84. for (decltype(len) i = 0; i < len; ++i) {
  85. auto score = *(pscore + i);
  86. if (score > mscore) {
  87. mscore = score;
  88. label = i;
  89. }
  90. }
  91. cnstream::CNInferAttr attr;
  92. attr.id = 0;
  93. attr.value = label;
  94. attr.score = mscore;
  95. obj->AddAttribute("classification", attr);
  96. return 0;
  97. }