/*************************************************************************
 * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/
#include <gtest/gtest.h>

#include <memory>
#include <string>
#include <vector>

#include "cnrt.h"
#include "core/data_type.h"
#include "model/model.h"
#include "test_base.h"
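
// Expect a successful CNRT return code; on failure, log the message and the error code.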
#define CHECK_CNRT_RET(ret, msg)                                       \
  do {                                                                 \
    EXPECT_EQ(ret, CNRT_RET_SUCCESS) << msg << " error code: " << ret; \
  } while (0)
namespace infer_server {

TEST_F(InferServerTest, Model) {
  auto m = std::make_shared<Model>();
#ifdef CNIS_USE_MAGICMIND
  std::string model_uri = "http://video.cambricon.com/models/MLU370/resnet50_nhwc_tfu_0.5_int8_fp16.model";
  // download the model, then release the cached handle so Init loads it from the local file
  auto tmp = InferServer::LoadModel(model_uri);
  InferServer::UnloadModel(tmp);
  tmp.reset();
  ASSERT_TRUE(m->Init("resnet50_nhwc_tfu_0.5_int8_fp16.model"));
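  // cross-check the info reported by the Model wrapper against the underlying model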
  auto* model = m->GetModel();
  size_t i_num = model->GetInputNum();
  size_t o_num = model->GetOutputNum();
  ASSERT_EQ(i_num, m->InputNum());
  ASSERT_EQ(o_num, m->OutputNum());
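  // fetch dims and dtypes straight from the model for the comparisons below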
  std::vector<mm::Dims> in_dims = model->GetInputDimensions();
  std::vector<mm::Dims> out_dims = model->GetOutputDimensions();
  std::vector<mm::DataType> i_dtypes = model->GetInputDataTypes();
  std::vector<mm::DataType> o_dtypes = model->GetOutputDataTypes();
  // TODO(dmh): test layout once reading the layout from the model is supported by mm
  for (size_t idx = 0; idx < i_num; ++idx) {
    EXPECT_EQ(detail::CastDataType(i_dtypes[idx]), m->InputLayout(idx).dtype);
    EXPECT_EQ(Shape(in_dims[idx].GetDims()), m->InputShape(idx));
  }
  for (size_t idx = 0; idx < o_num; ++idx) {
    EXPECT_EQ(detail::CastDataType(o_dtypes[idx]), m->OutputLayout(idx).dtype);
    EXPECT_EQ(Shape(out_dims[idx].GetDims()), m->OutputShape(idx));
  }
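  // the wrapper's BatchSize() should match the first dimension of the first input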
  EXPECT_EQ(in_dims[0].GetDimValue(0), m->BatchSize());
#else
  std::string model_path = GetExePath() + "../../../tests/data/resnet50_270.cambricon";
  ASSERT_TRUE(m->Init(model_path, "subnet0"));
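  // cross-check the Model wrapper against CNRT queries on the loaded model and function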
  cnrtRet_t error_code;
  auto function = m->GetFunction();
  auto model = m->GetModel();
  int batch_size;
  error_code = cnrtQueryModelParallelism(model, &batch_size);
  CHECK_CNRT_RET(error_code, "Query Model Parallelism failed.");
  EXPECT_GE(batch_size, 0);
  EXPECT_EQ(static_cast<uint32_t>(batch_size), m->BatchSize());
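  // the data size queries also report the IO counts, which should match the wrapper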
  int64_t* input_sizes = nullptr;
  int64_t* output_sizes = nullptr;
  int input_num = 0, output_num = 0;
  error_code = cnrtGetInputDataSize(&input_sizes, &input_num, function);
  CHECK_CNRT_RET(error_code, "Get input data size failed.");
  EXPECT_EQ(m->InputNum(), static_cast<uint32_t>(input_num));
  error_code = cnrtGetOutputDataSize(&output_sizes, &output_num, function);
  CHECK_CNRT_RET(error_code, "Get output data size failed.");
  EXPECT_EQ(m->OutputNum(), static_cast<uint32_t>(output_num));
  // get io shapes; CNRT allocates the dim arrays, so the caller frees them
  int* input_dim_values = nullptr;
  int* output_dim_values = nullptr;
  int dim_num = 0;
  for (int i = 0; i < input_num; ++i) {
    error_code = cnrtGetInputDataShape(&input_dim_values, &dim_num, i, function);
    CHECK_CNRT_RET(error_code, "Get input data shape failed.");
    // nhwc shape
    for (int j = 0; j < dim_num; ++j) {
      EXPECT_EQ(m->InputShape(i)[j], input_dim_values[j]);
    }
    free(input_dim_values);
  }
  for (int i = 0; i < output_num; ++i) {
    error_code = cnrtGetOutputDataShape(&output_dim_values, &dim_num, i, function);
    CHECK_CNRT_RET(error_code, "Get output data shape failed.");
    // nhwc shape
    for (int j = 0; j < dim_num; ++j) {
      EXPECT_EQ(m->OutputShape(i)[j], output_dim_values[j]);
    }
    free(output_dim_values);
  }
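  // the model key should combine the model path and the function name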
  EXPECT_EQ(m->GetKey().compare(model_path + "_" + "subnet0"), 0);
#endif
}

}  // namespace infer_server