test_thread_pool.cpp

/*************************************************************************
 * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/
#include <gtest/gtest.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>

#include "infer_thread_pool.hpp"
namespace cnstream {

// Test helper that exposes InferThreadPool internals (the worker threads and
// the pending task queue) to the test cases below.
class InferThreadPoolTest {
 public:
  explicit InferThreadPoolTest(InferThreadPool* tp) : tp_(tp) {}
  InferTaskSptr PopTask() { return tp_->PopTask(); }
  int GetThreadNum() { return static_cast<int>(tp_->threads_.size()); }
  int GetTaskNum() {
    std::unique_lock<std::mutex> lk(tp_->mtx_);
    return static_cast<int>(tp_->task_q_.size());
  }

 private:
  InferThreadPool* tp_;
};
TEST(Inferencer, InferThreadPool_Constructor) {
  std::shared_ptr<InferThreadPool> tp = nullptr;
  EXPECT_NO_THROW(tp = std::make_shared<InferThreadPool>());
  InferThreadPoolTest tp_test(tp.get());
  EXPECT_EQ(tp_test.GetThreadNum(), 0);
}
TEST(Inferencer, InferThreadPool_Init) {
  InferThreadPool tp;
  EXPECT_NO_THROW(tp.Init(0, 0));
  InferThreadPoolTest tp_test(&tp);
  EXPECT_EQ(tp_test.GetThreadNum(), 0);
  tp.Destroy();

  EXPECT_NO_THROW(tp.Init(0, 5));
  EXPECT_EQ(tp_test.GetThreadNum(), 5);
  tp.Destroy();
}
TEST(Inferencer, InferThreadPool_Destroy) {
  InferThreadPool tp;
  EXPECT_NO_THROW(tp.Init(0, 1));
  EXPECT_NO_THROW(tp.Destroy());
  InferThreadPoolTest tp_test(&tp);
  EXPECT_EQ(tp_test.GetThreadNum(), 0);
}
TEST(Inferencer, InferThreadPool_SubmitTask) {
  InferTaskSptr task = std::make_shared<InferTask>([]() -> int { return 1; });
  InferThreadPool tp;
  tp.Destroy();
  /* pool not running: submitting a task fails and the queue stays empty */
  EXPECT_NO_THROW(tp.SubmitTask(task));
  InferThreadPoolTest tp_test(&tp);
  EXPECT_EQ(tp_test.GetTaskNum(), 0);

  /* pool running: submitting a task succeeds */
  std::condition_variable pause;
  std::mutex mtx;
  std::atomic<bool> task_run(false);
  task = std::make_shared<InferTask>([&]() -> int {
    // pause and block the only thread in the thread pool
    std::unique_lock<std::mutex> lk(mtx);
    task_run.store(true);
    pause.wait(lk);
    return 0;
  });
  auto task2 = std::make_shared<InferTask>([]() -> int { return 0; });
  tp.Init(0, 1);
  EXPECT_NO_THROW(tp.SubmitTask(task));
  while (!task_run.load()) {
    // wait until the first task is running
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  EXPECT_NO_THROW(tp.SubmitTask(task2));
  EXPECT_EQ(tp_test.GetTaskNum(), 1);
  {
    // acquire mtx once so the blocked task is guaranteed to have reached
    // pause.wait() before notifying; otherwise the wake-up could be lost
    std::lock_guard<std::mutex> lg(mtx);
  }
  pause.notify_one();
  tp.Destroy();
}
TEST(Inferencer, InferThreadPool_PopTask) {
  std::condition_variable pause;
  std::mutex mtx;
  std::atomic<bool> task_run(false);
  InferTaskSptr task = std::make_shared<InferTask>([&]() -> int {
    // pause and block the only thread in the thread pool
    std::unique_lock<std::mutex> lk(mtx);
    task_run.store(true);
    pause.wait(lk);
    return 1;
  });
  InferThreadPool tp;
  tp.Init(0, 1);
  tp.SubmitTask(task);
  InferTaskSptr task_for_pop = std::make_shared<InferTask>([&]() -> int { return 1; });
  task_for_pop->task_msg = "test_pop";
  tp.SubmitTask(task_for_pop);
  while (!task_run.load()) {
    // wait until the blocking task is running
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
  InferThreadPoolTest tp_test(&tp);
  auto task_popped = tp_test.PopTask();
  EXPECT_EQ(task_popped->task_msg, "test_pop");
  {
    // make sure the blocked task has reached pause.wait() before notifying
    std::lock_guard<std::mutex> lg(mtx);
  }
  pause.notify_one();
  tp.Destroy();
}
TEST(Inferencer, InferThreadPool_TaskSequence) {
  constexpr int kTaskNum = 5;
  InferThreadPool tp;
  tp.Init(0, kTaskNum);
  std::chrono::steady_clock::time_point ts[kTaskNum];  // NOLINT
  InferTaskSptr tasks[kTaskNum];                        // NOLINT
  std::function<int(std::chrono::steady_clock::time_point*)> func =
      [](std::chrono::steady_clock::time_point* t) -> int {
    *t = std::chrono::steady_clock::now();
    return 0;
  };
  for (int i = 0; i < kTaskNum; ++i) {
    tasks[i] = std::make_shared<InferTask>(std::bind(func, ts + i));
    if (i != 0) {
      // each task may only run after the previous one has finished
      tasks[i]->BindFrontTask(tasks[i - 1]);
    }
  }
  // submit in reverse order; the bound dependencies should still enforce the sequence
  for (int i = kTaskNum - 1; i >= 0; --i) {
    tp.SubmitTask(tasks[i]);
  }
  for (auto& task : tasks) {
    task->WaitForTaskComplete();
  }
  for (int i = 1; i < kTaskNum; ++i) {
    EXPECT_GT(ts[i].time_since_epoch().count(), ts[i - 1].time_since_epoch().count());
  }
  tp.Destroy();
}
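
// A minimal additional sanity check, sketched using only the interfaces
// exercised above (Init/SubmitTask/WaitForTaskComplete/Destroy): several
// independent tasks submitted to a multi-thread pool should all run to
// completion. The test name and thread count are illustrative choices.
TEST(Inferencer, InferThreadPool_MultiTaskCompletion) {
  constexpr int kTaskNum = 4;
  InferThreadPool tp;
  tp.Init(0, 2);  // two worker threads
  std::atomic<int> done_cnt(0);
  InferTaskSptr tasks[kTaskNum];  // NOLINT
  for (int i = 0; i < kTaskNum; ++i) {
    // each task simply bumps the shared counter
    tasks[i] = std::make_shared<InferTask>([&done_cnt]() -> int {
      done_cnt++;
      return 0;
    });
    EXPECT_NO_THROW(tp.SubmitTask(tasks[i]));
  }
  for (auto& task : tasks) {
    task->WaitForTaskComplete();
  }
  EXPECT_EQ(done_cnt.load(), kTaskNum);
  tp.Destroy();
}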
}  // namespace cnstream