test_apply_stride_align_for_scaler.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /*************************************************************************
  2. * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  #include <atomic>
  #include <cstring>
  #include <future>
  #include <memory>
  #include <string>
  #include <utility>
  #include <vector>

  #include <gtest/gtest.h>

  #include "cnstream_frame_va.hpp"
  #include "cnstream_module.hpp"
  #include "cnstream_pipeline.hpp"
  #include "data_source.hpp"
  #include "test_base.hpp"
  30. static constexpr const char *gmp4_path = "../../modules/unitest/source/data/img_300x300.mp4";
  31. class MsgObserverForTest : cnstream::StreamMsgObserver {
  32. public:
  33. void Update(const cnstream::StreamMsg& smsg) override {
  34. if (smsg.type == cnstream::StreamMsgType::EOS_MSG) {
  35. wakener_.set_value();
  36. }
  37. }
  38. void WaitForEos() {
  39. wakener_.get_future().get();
  40. }
  41. private:
  42. std::promise<void> wakener_;
  43. };
  44. class ImageReceiver : public cnstream::Module, public cnstream::ModuleCreator<ImageReceiver> {
  45. public:
  46. explicit ImageReceiver(const std::string& mname) : cnstream::Module(mname) {}
  47. bool Open(cnstream::ModuleParamSet param_set) override { return true; }
  48. void Close() override {}
  49. int Process(std::shared_ptr<cnstream::CNFrameInfo> data) override {
  50. cnstream::CNDataFramePtr frame = data->collection.Get<cnstream::CNDataFramePtr>(cnstream::kCNDataFrameTag);
  51. frames.push_back(frame);
  52. return 0;
  53. }
  54. const std::vector<cnstream::CNDataFramePtr>& GetFrames() const { return frames; }
  55. void Clear() { frames.clear(); }
  56. private:
  57. std::vector<cnstream::CNDataFramePtr> frames;
  58. };
  59. bool CompareFrames(const std::vector<cnstream::CNDataFramePtr> &src_frames,
  60. const std::vector<cnstream::CNDataFramePtr> &aligned_frames) {
  61. EXPECT_EQ(src_frames.size(), aligned_frames.size());
  62. size_t frame_num = src_frames.size();
  63. for (size_t fi = 0; fi < frame_num; ++fi) {
  64. auto src_frame = src_frames[fi];
  65. auto aligned_frame = aligned_frames[fi];
  66. EXPECT_FALSE(aligned_frame->stride[0] % 128);
  67. EXPECT_FALSE(aligned_frame->stride[1] % 128);
  68. auto src_mat = src_frame->ImageBGR();
  69. auto dst_mat = aligned_frame->ImageBGR();
  70. EXPECT_EQ(0, memcmp(src_mat.data, dst_mat.data, src_mat.total() * src_mat.elemSize()));
  71. }
  72. return true;
  73. }
  74. std::vector<cnstream::CNDataFramePtr> GetFrames(const cnstream::ModuleParamSet &source_params) {
  75. cnstream::Pipeline pipeline("pipeline");
  76. cnstream::CNModuleConfig receiver_config;
  77. receiver_config.name = "receiver";
  78. receiver_config.className = "ImageReceiver";
  79. receiver_config.maxInputQueueSize = 5;
  80. receiver_config.parallelism = 1;
  81. cnstream::CNModuleConfig source_config;
  82. source_config.name = "source";
  83. source_config.className = "cnstream::DataSource";
  84. source_config.next = {"receiver"};
  85. source_config.parameters = source_params;
  86. source_config.maxInputQueueSize = 0;
  87. source_config.parallelism = 0;
  88. EXPECT_TRUE(pipeline.BuildPipeline({source_config, receiver_config}));
  89. cnstream::DataSource* source = dynamic_cast<cnstream::DataSource*>(pipeline.GetModule("source"));
  90. ImageReceiver* receiver = dynamic_cast<ImageReceiver*>(pipeline.GetModule("receiver"));
  91. EXPECT_NE(nullptr, source);
  92. EXPECT_NE(nullptr, receiver);
  93. MsgObserverForTest observer;
  94. pipeline.SetStreamMsgObserver(reinterpret_cast<cnstream::StreamMsgObserver*>(&observer));
  95. EXPECT_TRUE(pipeline.Start());
  96. std::string filename = GetExePath() + gmp4_path;
  97. auto handler =
  98. cnstream::FileHandler::Create(source, "0", filename, 30, false);
  99. EXPECT_NE(nullptr, handler);
  100. EXPECT_EQ(0, source->AddSource(handler));
  101. observer.WaitForEos();
  102. pipeline.Stop();
  103. return receiver->GetFrames();
  104. }
  105. static bool TestFunc(const std::string &decoder_type, const std::string &output_type) {
  106. cnstream::ModuleParamSet source_params = {
  107. std::make_pair("decoder_type", decoder_type),
  108. std::make_pair("output_type", decoder_type),
  109. std::make_pair("device_id", "0")
  110. };
  111. auto origin_frames = GetFrames(source_params);
  112. source_params["apply_stride_align_for_scaler"] = "true";
  113. source_params["output_type"] = output_type;
  114. auto aligned_frames = GetFrames(source_params);
  115. // compare frames
  116. return CompareFrames(origin_frames, aligned_frames);
  117. }
  118. TEST(Source_StrideAlign, mlu_decoder_output_cpu) {
  119. // EXPECT_TRUE(TestFunc("mlu", "cpu"));
  120. }
  121. TEST(Source_StrideAlign, mlu_decoder_output_mlu) {
  122. EXPECT_TRUE(TestFunc("mlu", "mlu"));
  123. }
  124. TEST(Source_StrideAlign, cpu_decoder_output_cpu) {
  125. EXPECT_TRUE(TestFunc("cpu", "cpu"));
  126. }
  127. TEST(Source_StrideAlign, cpu_decoder_output_mlu) {
  128. EXPECT_TRUE(TestFunc("cpu", "mlu"));
  129. }