test_session.cpp 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. /*************************************************************************
  2. * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #include <gtest/gtest.h>
  21. #include <memory>
  22. #include <string>
  23. #include <utility>
  24. #include "cnis/infer_server.h"
  25. #include "cnis/processor.h"
  26. #include "core/session.h"
  27. #include "test_base.h"
  28. namespace infer_server {
  29. auto g_empty_preproc_func = [](ModelIO*, const InferData&, const ModelInfo&) { return true; };
  30. constexpr int device_id = 0;
  31. #ifdef CNIS_USE_MAGICMIND
  32. static const char* model_url = "http://video.cambricon.com/models/MLU370/resnet50_nhwc_tfu_0.5_int8_fp16.model";
  33. #else
  34. static const char* model_url = "http://video.cambricon.com/models/MLU270/Primary_Detector/ssd/resnet34_ssd.cambricon";
  35. #endif
  36. class TestObserver : public Observer {
  37. public:
  38. TestObserver(std::condition_variable* response_cond, std::mutex* response_mutex, std::atomic<bool>* done)
  39. : response_cond_(response_cond), response_mutex_(response_mutex), done_(done) {}
  40. void Response(Status status, PackagePtr data, any user_data) noexcept override {
  41. std::unique_lock<std::mutex> lk(*response_mutex_);
  42. done_->store(true);
  43. lk.unlock();
  44. response_cond_->notify_one();
  45. }
  46. private:
  47. std::condition_variable* response_cond_;
  48. std::mutex* response_mutex_;
  49. std::atomic<bool>* done_;
  50. };
  51. static SessionDesc ReturnSessionDesc(const std::string& name, std::shared_ptr<Processor> preproc, size_t batch_timeout,
  52. BatchStrategy strategy, uint32_t engine_num) {
  53. SessionDesc desc;
  54. desc.name = name;
  55. desc.model = InferServer::LoadModel(model_url);
  56. desc.strategy = strategy;
  57. desc.postproc = Postprocessor::Create();
  58. desc.batch_timeout = 10;
  59. desc.engine_num = engine_num;
  60. desc.show_perf = true;
  61. desc.priority = 0;
  62. desc.host_output_layout = {infer_server::DataType::FLOAT32, infer_server::DimOrder::NCHW};
  63. if (preproc) {
  64. desc.preproc = preproc;
  65. desc.preproc->SetParams<PreprocessorHost::ProcessFunction>("process_function", g_empty_preproc_func);
  66. }
  67. return desc;
  68. }
  69. TEST(InferServerCore, SessionInit) {
  70. // Session init
  71. PriorityThreadPool tp(nullptr);
  72. auto preproc = std::make_shared<PreprocessorHost>();
  73. SessionDesc desc = ReturnSessionDesc("test session", preproc, 5, BatchStrategy::DYNAMIC, 1);
  74. std::unique_ptr<Executor> executor(new Executor(desc, &tp, 0));
  75. std::unique_ptr<Session> session(new Session("init session", executor.get(), false, true));
  76. executor->Link(session.get());
  77. // Session other function
  78. std::string get_session_name = session->GetName();
  79. ASSERT_EQ(session->GetName(), "init session");
  80. ASSERT_EQ(session->GetExecutor(), executor.get());
  81. ASSERT_EQ(session->IsSyncLink(), false);
  82. std::condition_variable response_cond;
  83. std::mutex response_mutex;
  84. std::atomic<bool> done(false);
  85. std::shared_ptr<Observer> test_observer = std::make_shared<TestObserver>(&response_cond, &response_mutex, &done);
  86. session->SetObserver(test_observer);
  87. ASSERT_EQ(session->GetRawObserver(), test_observer.get());
  88. executor->Unlink(session.get());
  89. }
  90. TEST(InferServerCore, SessionSend) {
  91. PriorityThreadPool tp([]() -> bool { return SetCurrentDevice(device_id); }, 3);
  92. SessionDesc desc =
  93. ReturnSessionDesc("test session", std::make_shared<PreprocessorHost>(), 5, BatchStrategy::DYNAMIC, 1);
  94. std::unique_ptr<Executor> executor(new Executor(desc, &tp, 0));
  95. std::unique_ptr<Session> session(new Session("init session", executor.get(), false, true));
  96. executor->Link(session.get());
  97. std::condition_variable response_cond;
  98. std::mutex response_mutex;
  99. std::atomic<bool> done(false);
  100. std::shared_ptr<Observer> test_observer = std::make_shared<TestObserver>(&response_cond, &response_mutex, &done);
  101. session->SetObserver(std::move(test_observer));
  102. // Session send sucess
  103. std::string tag = "test tag";
  104. auto input = Package::Create(1, tag);
  105. any user_data;
  106. ASSERT_TRUE(
  107. session->Send(std::move(input), std::bind(&Observer::Response, session->GetRawObserver(), std::placeholders::_1,
  108. std::placeholders::_2, std::move(user_data))));
  109. std::unique_lock<std::mutex> lk(response_mutex);
  110. response_cond.wait(lk, [&done]() { return done.load(); });
  111. ASSERT_NO_THROW(session->WaitTaskDone(tag));
  112. executor->Unlink(session.get());
  113. }
  114. TEST(InferServerCore, SessionCheckAndResponse) {
  115. PriorityThreadPool tp([]() -> bool { return SetCurrentDevice(device_id); }, 3);
  116. SessionDesc desc =
  117. ReturnSessionDesc("test session", std::make_shared<PreprocessorHost>(), 5, BatchStrategy::DYNAMIC, 1);
  118. std::unique_ptr<Executor> executor(new Executor(desc, &tp, 0));
  119. std::unique_ptr<Session> session(new Session("init session", executor.get(), false, true));
  120. executor->Link(session.get());
  121. std::condition_variable response_cond;
  122. std::mutex response_mutex;
  123. std::atomic<bool> done(false);
  124. std::shared_ptr<Observer> test_observer = std::make_shared<TestObserver>(&response_cond, &response_mutex, &done);
  125. session->SetObserver(std::move(test_observer));
  126. auto input = Package::Create(1, "test tag");
  127. auto ctrl = session->Send(std::move(input), std::bind(&Observer::Response, session->GetRawObserver(),
  128. std::placeholders::_1, std::placeholders::_2, nullptr));
  129. std::unique_lock<std::mutex> lk(response_mutex);
  130. response_cond.wait(lk, [&]() { return done.load(); });
  131. ASSERT_NO_THROW(session->CheckAndResponse(ctrl));
  132. executor->Unlink(session.get());
  133. }
  134. TEST(InferServerCore, SessionDiscardTask) {
  135. PriorityThreadPool tp([]() -> bool { return SetCurrentDevice(device_id); }, 3);
  136. SessionDesc desc =
  137. ReturnSessionDesc("test session", std::make_shared<PreprocessorHost>(), 5, BatchStrategy::DYNAMIC, 1);
  138. std::unique_ptr<Executor> executor(new Executor(desc, &tp, 0));
  139. std::unique_ptr<Session> session(new Session("init session", executor.get(), false, true));
  140. executor->Link(session.get());
  141. std::condition_variable response_cond;
  142. std::mutex response_mutex;
  143. std::atomic<bool> done(false);
  144. std::shared_ptr<TestObserver> test_observer = std::make_shared<TestObserver>(&response_cond, &response_mutex, &done);
  145. session->SetObserver(std::move(test_observer));
  146. // stream1
  147. std::string tag1 = "test tag1";
  148. auto input1 = Package::Create(20, tag1);
  149. // stream2
  150. std::string tag2 = "test tag2";
  151. auto input2 = Package::Create(20, tag2);
  152. session->Send(std::move(input1), std::bind(&Observer::Response, session->GetRawObserver(), std::placeholders::_1,
  153. std::placeholders::_2, nullptr));
  154. session->Send(std::move(input2), std::bind(&Observer::Response, session->GetRawObserver(), std::placeholders::_1,
  155. std::placeholders::_2, nullptr));
  156. ASSERT_NO_THROW(session->DiscardTask(tag1));
  157. std::unique_lock<std::mutex> lk(response_mutex);
  158. response_cond.wait(lk, [&]() { return done.load(); });
  159. executor->Unlink(session.get());
  160. }
  161. } // namespace infer_server