detection_runner.cpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. /*************************************************************************
  2. * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #include "detection_runner.h"
  21. #include <opencv2/opencv.hpp>
  22. #include <chrono>
  23. #include <memory>
  24. #include <string>
  25. #include <thread>
  26. #include <utility>
  27. #include <vector>
  28. #include "cxxutil/log.h"
  29. #if CV_VERSION_EPOCH == 2
  30. #define OPENCV_MAJOR_VERSION 2
  31. #elif CV_VERSION_MAJOR >= 3
  32. #define OPENCV_MAJOR_VERSION CV_VERSION_MAJOR
  33. #endif
  34. static const cv::Size g_out_video_size = cv::Size(1280, 720);
  35. DetectionRunner::DetectionRunner(const std::string& model_path, const std::string& func_name,
  36. const std::string& label_path, const std::string& track_model_path,
  37. const std::string& track_func_name, const std::string& data_path,
  38. const std::string& net_type, bool show, bool save_video)
  39. : StreamRunner(data_path), show_(show), save_video_(save_video) {
  40. // set mlu environment
  41. env_.SetDeviceId(0);
  42. env_.BindDevice();
  43. // load offline model
  44. model_ = std::make_shared<edk::ModelLoader>(model_path.c_str(), func_name.c_str());
  45. // prepare mlu memory operator and memory
  46. mem_op_.SetModel(model_);
  47. // init easy_infer
  48. infer_.Init(model_, 0);
  49. // create mlu resize and convert operator
  50. auto& in_shape = model_->InputShape(0);
  51. edk::MluResizeConvertOp::Attr rc_attr;
  52. rc_attr.dst_h = in_shape.H();
  53. rc_attr.dst_w = in_shape.W();
  54. rc_attr.batch_size = 1;
  55. rc_attr.core_version = env_.GetCoreVersion();
  56. rc_op_.SetMluQueue(infer_.GetMluQueue());
  57. if (!rc_op_.Init(rc_attr)) {
  58. THROW_EXCEPTION(edk::Exception::INTERNAL, rc_op_.GetLastError());
  59. }
  60. // init postproc
  61. if (net_type == "SSD") {
  62. postproc_.reset(new edk::SsdPostproc);
  63. } else if (net_type == "YOLOv3") {
  64. postproc_.reset(new edk::Yolov3Postproc);
  65. } else {
  66. THROW_EXCEPTION(edk::Exception::INVALID_ARG, "unsupported net type: " + net_type);
  67. }
  68. postproc_->set_threshold(0.6);
  69. CHECK(SAMPLES, postproc_);
  70. // init tracker
  71. tracker_.reset(new edk::FeatureMatchTrack);
  72. feature_extractor_.reset(new FeatureExtractor);
  73. if (track_model_path != "" && track_model_path != "cpu") {
  74. feature_extractor_->Init(track_model_path.c_str(), track_func_name.c_str());
  75. }
  76. // init osd
  77. osd_.LoadLabels(label_path);
  78. // display or video writer
  79. if (save_video_) {
  80. #if OPENCV_MAJOR_VERSION > 2
  81. video_writer_.reset(
  82. new cv::VideoWriter("out.avi", cv::VideoWriter::fourcc('M', 'J', 'P', 'G'), 25, g_out_video_size));
  83. #else
  84. video_writer_.reset(new cv::VideoWriter("out.avi", CV_FOURCC('M', 'J', 'P', 'G'), 25, g_out_video_size));
  85. #endif
  86. if (!video_writer_->isOpened()) {
  87. THROW_EXCEPTION(edk::Exception::Code::INIT_FAILED, "create output video file failed");
  88. }
  89. }
  90. mlu_input_ = mem_op_.AllocMluInput();
  91. mlu_output_ = mem_op_.AllocMluOutput();
  92. cpu_output_ = mem_op_.AllocCpuOutput();
  93. Start();
  94. }
  95. DetectionRunner::~DetectionRunner() {
  96. Stop();
  97. if (nullptr != mlu_output_) mem_op_.FreeMluOutput(mlu_output_);
  98. if (nullptr != cpu_output_) mem_op_.FreeCpuOutput(cpu_output_);
  99. if (nullptr != mlu_input_) mem_op_.FreeMluInput(mlu_input_);
  100. }
  101. void DetectionRunner::Process(edk::CnFrame frame) {
  102. // run resize and convert
  103. void* rc_output = mlu_input_[0];
  104. edk::MluResizeConvertOp::InputData input;
  105. input.planes[0] = frame.ptrs[0];
  106. input.planes[1] = frame.ptrs[1];
  107. input.src_w = frame.width;
  108. input.src_h = frame.height;
  109. input.src_stride = frame.strides[0];
  110. rc_op_.BatchingUp(input);
  111. if (!rc_op_.SyncOneOutput(rc_output)) {
  112. decode_->ReleaseBuffer(frame.buf_id);
  113. THROW_EXCEPTION(edk::Exception::INTERNAL, rc_op_.GetLastError());
  114. }
  115. // run inference
  116. infer_.Run(mlu_input_, mlu_output_);
  117. mem_op_.MemcpyOutputD2H(cpu_output_, mlu_output_);
  118. // alloc memory to store image
  119. auto img_data = new uint8_t[frame.strides[0] * frame.height * 3 / 2];
  120. // copy out frame
  121. decode_->CopyFrameD2H(img_data, frame);
  122. // release codec buffer
  123. decode_->ReleaseBuffer(frame.buf_id);
  124. // yuv to bgr
  125. cv::Mat yuv(frame.height * 3 / 2, frame.strides[0], CV_8UC1, img_data);
  126. cv::Mat img;
  127. cv::cvtColor(yuv, img, cv::COLOR_YUV2BGR_NV21);
  128. delete[] img_data;
  129. // resize to show
  130. cv::resize(img, img, cv::Size(1280, 720));
  131. // post process
  132. std::vector<edk::DetectObject> detect_result;
  133. std::vector<std::pair<float*, uint64_t>> postproc_param;
  134. postproc_param.push_back(
  135. std::make_pair(reinterpret_cast<float*>(cpu_output_[0]), model_->OutputShape(0).BatchDataCount()));
  136. detect_result = postproc_->Execute(postproc_param);
  137. // track
  138. edk::TrackFrame track_img;
  139. track_img.data = img.data;
  140. track_img.width = img.cols;
  141. track_img.height = img.rows;
  142. track_img.format = edk::TrackFrame::ColorSpace::RGB24;
  143. static int64_t frame_id = 0;
  144. track_img.frame_id = frame_id++;
  145. // extract feature
  146. for (auto& obj : detect_result) {
  147. obj.feature = feature_extractor_->ExtractFeature(track_img, obj);
  148. }
  149. std::vector<edk::DetectObject> track_result;
  150. track_result.clear();
  151. tracker_->UpdateFrame(track_img, detect_result, &track_result);
  152. std::cout << "----- Object detected in one frame:\n";
  153. for (auto& obj : track_result) {
  154. std::cout << obj << "\n";
  155. }
  156. std::cout << "-----------------------------------\n" << std::endl;
  157. osd_.DrawLabel(img, track_result);
  158. if (show_) {
  159. auto window_name = "stream app";
  160. cv::imshow(window_name, img);
  161. cv::waitKey(5);
  162. // std::string fn = std::to_string(frame.frame_id) + ".jpg";
  163. // cv::imwrite(fn.c_str(), img);
  164. }
  165. if (save_video_) {
  166. video_writer_->write(img);
  167. }
  168. }