/************************************************************************* * Copyright (C) [2019] by Cambricon, Inc. All rights reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. *************************************************************************/ /** * @file model_loader.h * * This file contains a declaration of the ModelLoader class and involved struct */ #ifndef EASYINFER_MODEL_LOADER_H_ #define EASYINFER_MODEL_LOADER_H_ #include #include #include #include "cxxutil/edk_attribute.h" #include "cxxutil/exception.h" #include "easyinfer/shape.h" namespace edk { /** * @brief Enumeration to specify data type of model input and output */ enum class DataType { UINT8, FLOAT32, FLOAT16, INT16, INT32 }; /** * @brief Enumeration to specify dim order of model input and output */ enum class DimOrder { NCHW, NHWC, HWCN, TNC, NTC }; /** * @brief Describe data layout on MLU or CPU */ struct DataLayout { DataType dtype; ///< @see edk::DataType DimOrder order; ///< @see edk::DimOrder }; class ModelLoaderPrivate; class ModelLoaderInternalInterface; /** * @brief A helper class to load offline model and get model infomation */ class ModelLoader { public: friend class ModelLoaderInternalInterface; /** * @brief Constructor 1. Construct a new Model Loader object * * @note Delegate to constructor 2 for construct * @param model_path Inference offline model path * @param function_name Name of function in offline model */ ModelLoader(const std::string& model_path, const std::string& function_name); /** * @brief Constructor 2. Construct a new Model Loader object * * @note Delegate to constructor 3 for construct * @param model_path Model path * @param function_name Function name */ ModelLoader(const char* model_path, const char* function_name); /** * @brief Constructor 3. Construct a new Model Loader object * * @param mem_ptr Offline model binary stored in memory * @param function_name Function name */ ModelLoader(void* mem_ptr, const char* function_name); /** * @brief Destroy the Model Loader object */ ~ModelLoader(); /** * @brief Set specified input data layout on CPU * * @param layout Data layout * @param data_index Data index */ void SetCpuInputLayout(DataLayout layout, int data_index); /** * @brief Set specified output data layout on CPU * * @param layout Data layout * @param data_index Data index */ void SetCpuOutputLayout(DataLayout layout, int data_index); /** * @brief Get specified input data layout on CPU * * @param data_index Data index * @return Data layout */ DataLayout GetCpuInputLayout(int data_index) const; /** * @brief Get specified output data layout on CPU * * @param data_index Data index * @return Data layout */ DataLayout GetCpuOutputLayout(int data_index) const; /** * @brief Adjust MLU stack memory according to model size * * @note Adjust MLU stack memory. Do nothing if model size is not larger than current stack memory size. * @return Return true if stack memory is adjusted. */ bool AdjustStackMemory(); /** * @brief Get model output number * * @return Model output number */ uint32_t OutputNum() const; /** * @brief Get model input number * * @return Model input number */ uint32_t InputNum() const; /** * @brief Get model input data shapes * * @deprecated use ModelLoader::InputShape(uint32_t) instead * @return Model input data shapes */ attribute_deprecated const std::vector& InputShapes() const; /** * @brief Get model output data shapes * * @deprecated use ModelLoader::OutputShape(uint32_t) instead * @return Model output data shapes */ attribute_deprecated const std::vector& OutputShapes() const; /** * @brief Get model input data shape * * @param index input index * @return Model input data shape */ const ShapeEx& InputShape(uint32_t index) const; /** * @brief Get model output data shape * * @param index output index * @return Model output data shape */ const ShapeEx& OutputShape(uint32_t index) const; /** * @brief Get model parallelism * * @note Not supported on MLU100, always return 1. * @return Model parallelism */ int ModelParallelism() const; /** * @brief Get the input data batch align size * * @param data_index Data index * @return input data batch align size */ int64_t GetInputDataBatchAlignSize(int data_index) const; /** * @brief Get the output data batch align size * * @param data_index Data index * @return output data batch align size */ int64_t GetOutputDataBatchAlignSize(int data_index) const; private: std::unique_ptr d_ptr_; ModelLoader(const ModelLoader&) = delete; ModelLoader& operator=(const ModelLoader&) = delete; ModelLoader(ModelLoader&&) = delete; ModelLoader& operator=(ModelLoader&&) = delete; }; // class ModelLoader } // namespace edk #endif // EASYINFER_MODEL_LOADER_H_