cnpostproc.cpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. /*************************************************************************
  2. * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #include "cnpostproc.h"
  21. #include <algorithm> // sort
  22. #include <cstring> // memset
  23. #include <list>
  24. #include <string>
  25. #include <utility>
  26. #include <vector>
  27. #include "cxxutil/log.h"
  28. using std::pair;
  29. using std::vector;
  30. using std::to_string;
  31. namespace edk {
  32. #define CLIP(x) ((x) < 0 ? 0 : ((x) > 1 ? 1 : (x)))
  33. void CnPostproc::set_threshold(const float threshold) { threshold_ = threshold; }
  34. vector<DetectObject> CnPostproc::Execute(const vector<pair<float*, uint64_t>>& net_outputs) {
  35. return Postproc(net_outputs);
  36. }
  37. vector<DetectObject> ClassificationPostproc::Postproc(const vector<pair<float*, uint64_t>>& net_outputs) {
  38. if (net_outputs.size() != 1) {
  39. LOGW(SAMPLES) << "Classification neuron network only has one output but get " + to_string(net_outputs.size());
  40. }
  41. float* data = net_outputs[0].first;
  42. uint64_t len = net_outputs[0].second;
  43. std::list<DetectObject> objs;
  44. for (decltype(len) i = 0; i < len; ++i) {
  45. if (data[i] < threshold_) continue;
  46. DetectObject obj;
  47. memset(&obj.bbox, 0, sizeof(BoundingBox));
  48. obj.label = i;
  49. obj.score = data[i];
  50. objs.emplace_back(std::move(obj));
  51. }
  52. objs.sort([](const DetectObject& a, const DetectObject& b) { return a.score > b.score; });
  53. return std::vector<DetectObject>(objs.begin(), objs.end());
  54. }
  55. vector<DetectObject> SsdPostproc::Postproc(const vector<pair<float*, uint64_t>>& net_outputs) {
  56. if (net_outputs.size() != 1) {
  57. LOGW(SAMPLES) << "Ssd neuron network only has one output, but get " + to_string(net_outputs.size());
  58. }
  59. vector<DetectObject> objs;
  60. float* data = net_outputs[0].first;
  61. // auto len = net_outputs[0].second;
  62. float box_num = data[0]; // get box num by batch index
  63. data += 64; // skip box num of all batch
  64. for (decltype(box_num) bi = 0; bi < box_num; ++bi) {
  65. DetectObject obj;
  66. if (data[1] == 0) continue;
  67. obj.label = data[1] - 1;
  68. obj.score = data[2];
  69. if (threshold_ > 0 && obj.score < threshold_) continue;
  70. obj.bbox.x = CLIP(data[3]);
  71. obj.bbox.y = CLIP(data[4]);
  72. obj.bbox.width = CLIP(data[5]) - obj.bbox.x;
  73. obj.bbox.height = CLIP(data[6]) - obj.bbox.y;
  74. objs.push_back(obj);
  75. data += 7;
  76. }
  77. return objs;
  78. }
  79. namespace detail {
  80. template <typename dtype>
  81. struct Clip {
  82. Clip(dtype _down, dtype _up) : down(_down), up(_up) {}
  83. inline dtype operator()(dtype val) {
  84. return std::min(up, std::max(down, val));
  85. }
  86. dtype down;
  87. dtype up;
  88. };
  89. } // namespace detail
  90. detail::Clip<float> Clip0_1_float(0, 1);
  91. vector<DetectObject> Yolov3Postproc::Postproc(const vector<pair<float*, uint64_t>>& net_outputs) {
  92. vector<DetectObject> objs;
  93. float* data = net_outputs[0].first;
  94. uint64_t len = net_outputs[0].second;
  95. constexpr int box_step = 7;
  96. const int box_num = static_cast<int>(data[0]);
  97. CHECK(SAMPLES, static_cast<uint64_t>(64 + box_num * box_step) <= len);
  98. for (int bi = 0; bi < box_num; ++bi) {
  99. DetectObject obj;
  100. obj.label = static_cast<int>(data[64 + bi * box_step + 1]);
  101. obj.score = data[64 + bi * box_step + 2];
  102. if (obj.label == 0) continue;
  103. if (threshold_ > 0 && obj.score < threshold_) continue;
  104. obj.bbox.x = Clip0_1_float(data[64 + bi * box_step + 3]);
  105. obj.bbox.y = Clip0_1_float(data[64 + bi * box_step + 4]);
  106. obj.bbox.width = Clip0_1_float(data[64 + bi * box_step + 5]) - obj.bbox.x;
  107. obj.bbox.height = Clip0_1_float(data[64 + bi * box_step + 6]) - obj.bbox.y;
  108. obj.bbox.x = (obj.bbox.x - padl_ratio_) / (1 - padl_ratio_ - padr_ratio_);
  109. obj.bbox.y = (obj.bbox.y - padt_ratio_) / (1 - padb_ratio_ - padt_ratio_);
  110. obj.bbox.width /= (1 - padl_ratio_ - padr_ratio_);
  111. obj.bbox.height /= (1 - padb_ratio_ - padt_ratio_);
  112. obj.track_id = -1;
  113. if (obj.bbox.width <= 0) continue;
  114. if (obj.bbox.height <= 0) continue;
  115. objs.push_back(obj);
  116. }
  117. return objs;
  118. }
  119. } // namespace edk