classification_runner.cpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. /*************************************************************************
  2. * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #include "classification_runner.h"
  21. #include <opencv2/opencv.hpp>
  22. #include <chrono>
  23. #include <memory>
  24. #include <string>
  25. #include <thread>
  26. #include <utility>
  27. #include <vector>
  28. #include "cxxutil/log.h"
  29. #if CV_VERSION_EPOCH == 2
  30. #define OPENCV_MAJOR_VERSION 2
  31. #elif CV_VERSION_MAJOR >= 3
  32. #define OPENCV_MAJOR_VERSION CV_VERSION_MAJOR
  33. #endif
  34. static const cv::Size g_out_video_size = cv::Size(1280, 720);
  35. ClassificationRunner::ClassificationRunner(const std::string& model_path, const std::string& func_name,
  36. const std::string& label_path, const std::string& data_path, bool show,
  37. bool save_video)
  38. : StreamRunner(data_path), show_(show), save_video_(save_video) {
  39. // load offline model
  40. model_ = std::make_shared<edk::ModelLoader>(model_path.c_str(), func_name.c_str());
  41. // prepare mlu memory operator and memory
  42. mem_op_.SetModel(model_);
  43. // init easy_infer
  44. infer_.Init(model_, 0);
  45. // create mlu resize and convert operator
  46. auto& in_shape = model_->InputShape(0);
  47. edk::MluResizeConvertOp::Attr rc_attr;
  48. rc_attr.dst_h = in_shape.H();
  49. rc_attr.dst_w = in_shape.W();
  50. rc_attr.batch_size = 1;
  51. rc_attr.core_version = env_.GetCoreVersion();
  52. rc_op_.SetMluQueue(infer_.GetMluQueue());
  53. if (!rc_op_.Init(rc_attr)) {
  54. THROW_EXCEPTION(edk::Exception::INTERNAL, rc_op_.GetLastError());
  55. }
  56. // init postproc
  57. postproc_.reset(new edk::ClassificationPostproc);
  58. postproc_->set_threshold(0.2);
  59. CHECK(SAMPLES, postproc_);
  60. // init osd
  61. osd_.LoadLabels(label_path);
  62. // video writer
  63. if (save_video_) {
  64. #if OPENCV_MAJOR_VERSION > 2
  65. video_writer_.reset(
  66. new cv::VideoWriter("out.avi", cv::VideoWriter::fourcc('M', 'J', 'P', 'G'), 25, g_out_video_size));
  67. #else
  68. video_writer_.reset(new cv::VideoWriter("out.avi", CV_FOURCC('M', 'J', 'P', 'G'), 25, g_out_video_size));
  69. #endif
  70. if (!video_writer_->isOpened()) {
  71. THROW_EXCEPTION(edk::Exception::Code::INIT_FAILED, "create output video file failed");
  72. }
  73. }
  74. mlu_input_ = mem_op_.AllocMluInput();
  75. mlu_output_ = mem_op_.AllocMluOutput();
  76. cpu_output_ = mem_op_.AllocCpuOutput();
  77. Start();
  78. }
  79. ClassificationRunner::~ClassificationRunner() {
  80. Stop();
  81. if (nullptr != mlu_output_) mem_op_.FreeMluOutput(mlu_output_);
  82. if (nullptr != cpu_output_) mem_op_.FreeCpuOutput(cpu_output_);
  83. if (nullptr != mlu_input_) mem_op_.FreeMluInput(mlu_input_);
  84. }
  85. void ClassificationRunner::Process(edk::CnFrame frame) {
  86. // run resize and convert
  87. void* rc_output = mlu_input_[0];
  88. edk::MluResizeConvertOp::InputData input;
  89. input.planes[0] = frame.ptrs[0];
  90. input.planes[1] = frame.ptrs[1];
  91. input.src_w = frame.width;
  92. input.src_h = frame.height;
  93. input.src_stride = frame.strides[0];
  94. rc_op_.BatchingUp(input);
  95. if (!rc_op_.SyncOneOutput(rc_output)) {
  96. decode_->ReleaseBuffer(frame.buf_id);
  97. THROW_EXCEPTION(edk::Exception::INTERNAL, rc_op_.GetLastError());
  98. }
  99. // run inference
  100. infer_.Run(mlu_input_, mlu_output_);
  101. mem_op_.MemcpyOutputD2H(cpu_output_, mlu_output_);
  102. // alloc memory to store image
  103. auto img_data = new uint8_t[frame.strides[0] * frame.height * 3 / 2];
  104. // copy out frame
  105. decode_->CopyFrameD2H(img_data, frame);
  106. // release codec buffer
  107. decode_->ReleaseBuffer(frame.buf_id);
  108. // yuv to bgr
  109. cv::Mat yuv(frame.height * 3 / 2, frame.strides[0], CV_8UC1, img_data);
  110. cv::Mat img;
  111. cv::cvtColor(yuv, img, cv::COLOR_YUV2BGR_NV21);
  112. delete[] img_data;
  113. // resize to show
  114. cv::resize(img, img, cv::Size(1280, 720));
  115. // post process
  116. std::vector<edk::DetectObject> detect_result;
  117. std::vector<std::pair<float*, uint64_t>> postproc_param;
  118. postproc_param.push_back(
  119. std::make_pair(reinterpret_cast<float*>(cpu_output_[0]), model_->OutputShape(0).DataCount()));
  120. detect_result = postproc_->Execute(postproc_param);
  121. std::cout << "----- Classification Result:\n";
  122. int show_number = 2;
  123. for (auto& obj : detect_result) {
  124. std::cout << "[Object] label: " << obj.label << " score: " << obj.score << "\n";
  125. if (!(--show_number)) break;
  126. }
  127. std::cout << "-----------------------------------\n" << std::endl;
  128. osd_.DrawLabel(img, detect_result);
  129. if (show_) {
  130. auto window_name = "classification";
  131. cv::imshow(window_name, img);
  132. cv::waitKey(5);
  133. // std::string fn = std::to_string(frame.frame_id) + ".jpg";
  134. // cv::imwrite(fn.c_str(), img);
  135. }
  136. if (save_video_) {
  137. video_writer_->write(img);
  138. }
  139. }