test_buffer.cpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. /*************************************************************************
  2. * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #include <gtest/gtest.h>
  21. #include <future>
  22. #include <memory>
  23. #include <thread>
  24. #include <vector>
  25. #include "cnis/buffer.h"
  26. #include "cxxutil/exception.h"
  27. #include "fixture.h"
  28. TEST_F(InferServerTestAPI, Buffer) {
  29. int times = 100;
  30. char *raw_str = new char[20 * 100 * 16];
  31. char *raw_str_out = new char[20 * 100 * 16];
  32. infer_server::Buffer str(20 * 100 * 16);
  33. infer_server::Buffer str_out(20 * 100 * 16);
  34. EXPECT_FALSE(str.OnMlu());
  35. EXPECT_FALSE(str_out.OnMlu());
  36. while (--times) {
  37. try {
  38. size_t str_size = 170 * times;
  39. snprintf(raw_str, str_size, "test MluMemory, s: %lu", str_size);
  40. infer_server::Buffer mlu_src(str_size, device_id_);
  41. infer_server::Buffer mlu_dst(str_size, device_id_);
  42. EXPECT_FALSE(mlu_src.OwnMemory());
  43. EXPECT_FALSE(mlu_dst.OwnMemory());
  44. (void)mlu_src.MutableData();
  45. (void)mlu_dst.MutableData();
  46. EXPECT_TRUE(mlu_src.OwnMemory());
  47. EXPECT_TRUE(mlu_dst.OwnMemory());
  48. EXPECT_TRUE(mlu_src.OnMlu());
  49. EXPECT_TRUE(mlu_dst.OnMlu());
  50. str.CopyFrom(raw_str, str_size);
  51. mlu_src.CopyFrom(str, str_size);
  52. mlu_dst.CopyFrom(mlu_src, str_size);
  53. mlu_dst.CopyTo(&str_out, str_size);
  54. str_out.CopyTo(raw_str_out, str_size);
  55. EXPECT_STREQ(raw_str, raw_str_out);
  56. } catch (edk::Exception &err) {
  57. EXPECT_TRUE(false) << err.what();
  58. }
  59. }
  60. delete[] raw_str;
  61. delete[] raw_str_out;
  62. }
  63. TEST_F(InferServerTestAPI, MluMemoryPoolBuffer) {
  64. int times = 100;
  65. constexpr size_t kStrLength = 20 * 100 * 16;
  66. constexpr size_t kBufferNum = 6;
  67. char *str = new char[kStrLength];
  68. char *str_out = new char[kStrLength];
  69. infer_server::MluMemoryPool pool(kStrLength, kBufferNum, device_id_);
  70. {
  71. std::vector<infer_server::Buffer> cache;
  72. for (size_t i = 0; i < kBufferNum; ++i) {
  73. EXPECT_NO_THROW(cache.emplace_back(pool.Request()));
  74. }
  75. EXPECT_THROW(pool.Request(10), edk::Exception) << "pool should be empty";
  76. cache.clear();
  77. EXPECT_NO_THROW(pool.Request(10));
  78. }
  79. // test destruct
  80. std::future<void> res2;
  81. {
  82. infer_server::MluMemoryPool p(kStrLength, kBufferNum, device_id_);
  83. std::promise<void> flag;
  84. std::future<void> has_requested = flag.get_future();
  85. res2 = std::async(std::launch::async, [&p, &flag]() {
  86. infer_server::Buffer m;
  87. EXPECT_NO_THROW(m = p.Request());
  88. flag.set_value();
  89. std::this_thread::sleep_for(std::chrono::milliseconds(10));
  90. });
  91. has_requested.get();
  92. }
  93. res2.get();
  94. while (--times) {
  95. try {
  96. size_t a_number = 170 * times;
  97. snprintf(str, kStrLength, "test MluMemory, s: %lu", a_number);
  98. void *in = reinterpret_cast<void *>(str);
  99. void *out = reinterpret_cast<void *>(str_out);
  100. infer_server::Buffer mlu_src, mlu_dst;
  101. EXPECT_FALSE(mlu_src.OwnMemory());
  102. EXPECT_FALSE(mlu_dst.OwnMemory());
  103. EXPECT_NO_THROW(mlu_src = pool.Request(10));
  104. EXPECT_NO_THROW(mlu_dst = pool.Request(10));
  105. EXPECT_TRUE(mlu_src.OwnMemory());
  106. EXPECT_TRUE(mlu_dst.OwnMemory());
  107. mlu_src.CopyFrom(in, kStrLength);
  108. mlu_dst.CopyFrom(mlu_src, kStrLength);
  109. mlu_dst.CopyTo(out, kStrLength);
  110. EXPECT_STREQ(str, str_out);
  111. } catch (edk::Exception &err) {
  112. EXPECT_TRUE(false) << err.what();
  113. }
  114. }
  115. delete[] str;
  116. delete[] str_out;
  117. }