// test_user_define.cpp
/*************************************************************************
 * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/
  20. #include <glog/logging.h>
  21. #include <gtest/gtest.h>
  22. #include <memory>
  23. #include <utility>
  24. #include "cnis/buffer.h"
  25. #include "cnis/infer_server.h"
  26. #include "cnis/processor.h"
  27. #include "fixture.h"
  28. namespace infer_server {
  29. #ifdef CNIS_USE_MAGICMIND
  30. static const char* model_url = "http://video.cambricon.com/models/MLU370/resnet50_nhwc_tfu_0.5_int8_fp16.model";
  31. #else
  32. constexpr const char* model_url =
  33. "http://video.cambricon.com/models/MLU270/Primary_Detector/ssd/resnet34_ssd.cambricon";
  34. #endif
// Payload type carried through the pipeline by the UserDefine test.
// Holds a single device buffer; its contents are never inspected —
// MyProcessor::Process discards all input data it receives.
struct MyData {
  Buffer data;
};
  38. class MyProcessor : public ProcessorForkable<MyProcessor> {
  39. public:
  40. MyProcessor() noexcept : ProcessorForkable<MyProcessor>("MyProcessor") {}
  41. ~MyProcessor() {}
  42. Status Process(PackagePtr pack) noexcept override {
  43. if (!SetCurrentDevice(dev_id_)) return Status::ERROR_BACKEND;
  44. // discard all input and pass empty data to next processor
  45. for (auto& it : pack->data) {
  46. it->data.reset();
  47. }
  48. auto preproc_output = pool_->Request();
  49. ModelIO model_input;
  50. model_input.buffers.emplace_back(std::move(preproc_output));
  51. model_input.shapes.emplace_back(model_->InputShape(0));
  52. pack->predict_io.reset(new InferData);
  53. pack->predict_io->Set(std::move(model_input));
  54. return Status::SUCCESS;
  55. }
  56. Status Init() noexcept override {
  57. constexpr const char* params[] = {"model_info", "device_id"};
  58. for (auto p : params) {
  59. if (!HaveParam(p)) {
  60. LOG(ERROR) << p << " has not been set!";
  61. return Status::INVALID_PARAM;
  62. }
  63. }
  64. try {
  65. model_ = GetParam<ModelPtr>("model_info");
  66. dev_id_ = GetParam<int>("device_id");
  67. if (!SetCurrentDevice(dev_id_)) return Status::ERROR_BACKEND;
  68. auto shape = model_->InputShape(0);
  69. auto layout = model_->InputLayout(0);
  70. pool_.reset(new MluMemoryPool(shape.BatchDataCount() * GetTypeSize(layout.dtype), 3, dev_id_));
  71. } catch (bad_any_cast&) {
  72. LOG(ERROR) << "Unmatched data type";
  73. return Status::WRONG_TYPE;
  74. }
  75. return Status::SUCCESS;
  76. }
  77. private:
  78. std::unique_ptr<MluMemoryPool> pool_{nullptr};
  79. ModelPtr model_;
  80. int dev_id_;
  81. };
  82. TEST_F(InferServerTestAPI, UserDefine) {
  83. auto model = server_->LoadModel(model_url);
  84. ASSERT_TRUE(model) << "load model failed";
  85. auto preproc = MyProcessor::Create();
  86. SessionDesc desc;
  87. desc.name = "test user define";
  88. desc.model = model;
  89. desc.strategy = BatchStrategy::DYNAMIC;
  90. desc.preproc = std::move(preproc);
  91. desc.batch_timeout = 10;
  92. desc.engine_num = 1;
  93. desc.show_perf = true;
  94. desc.priority = 0;
  95. Session_t session = server_->CreateSyncSession(desc);
  96. ASSERT_TRUE(session);
  97. auto input = std::make_shared<Package>();
  98. input->data.reserve(32);
  99. for (uint32_t idx = 0; idx < 32; ++idx) {
  100. input->data.emplace_back(new InferData);
  101. input->data[idx]->Set(MyData());
  102. }
  103. auto output = std::make_shared<Package>();
  104. Status status;
  105. ASSERT_TRUE(server_->RequestSync(session, input, &status, output));
  106. ASSERT_EQ(status, Status::SUCCESS);
  107. EXPECT_EQ(output->data.size(), 32u);
  108. EXPECT_NO_THROW(output->data[0]->Get<ModelIO>());
  109. server_->DestroySession(session);
  110. }
  111. } // namespace infer_server