123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- /*************************************************************************
- * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
- * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
- * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- *************************************************************************/
- /**
- * @file model_loader.h
- *
- * This file contains a declaration of the ModelLoader class and involved struct
- */
- #ifndef EASYINFER_MODEL_LOADER_H_
- #define EASYINFER_MODEL_LOADER_H_
- #include <memory>
- #include <string>
- #include <vector>
- #include "cxxutil/edk_attribute.h"
- #include "cxxutil/exception.h"
- #include "easyinfer/shape.h"
- namespace edk {
- /**
- * @brief Enumeration to specify data type of model input and output
- */
- enum class DataType { UINT8, FLOAT32, FLOAT16, INT16, INT32 };
- /**
- * @brief Enumeration to specify dim order of model input and output
- */
- enum class DimOrder { NCHW, NHWC, HWCN, TNC, NTC };
- /**
- * @brief Describe data layout on MLU or CPU
- */
- struct DataLayout {
- DataType dtype; ///< @see edk::DataType
- DimOrder order; ///< @see edk::DimOrder
- };
- class ModelLoaderPrivate;
- class ModelLoaderInternalInterface;
- /**
- * @brief A helper class to load offline model and get model infomation
- */
- class ModelLoader {
- public:
- friend class ModelLoaderInternalInterface;
- /**
- * @brief Constructor 1. Construct a new Model Loader object
- *
- * @note Delegate to constructor 2 for construct
- * @param model_path Inference offline model path
- * @param function_name Name of function in offline model
- */
- ModelLoader(const std::string& model_path, const std::string& function_name);
- /**
- * @brief Constructor 2. Construct a new Model Loader object
- *
- * @note Delegate to constructor 3 for construct
- * @param model_path Model path
- * @param function_name Function name
- */
- ModelLoader(const char* model_path, const char* function_name);
- /**
- * @brief Constructor 3. Construct a new Model Loader object
- *
- * @param mem_ptr Offline model binary stored in memory
- * @param function_name Function name
- */
- ModelLoader(void* mem_ptr, const char* function_name);
- /**
- * @brief Destroy the Model Loader object
- */
- ~ModelLoader();
- /**
- * @brief Set specified input data layout on CPU
- *
- * @param layout Data layout
- * @param data_index Data index
- */
- void SetCpuInputLayout(DataLayout layout, int data_index);
- /**
- * @brief Set specified output data layout on CPU
- *
- * @param layout Data layout
- * @param data_index Data index
- */
- void SetCpuOutputLayout(DataLayout layout, int data_index);
- /**
- * @brief Get specified input data layout on CPU
- *
- * @param data_index Data index
- * @return Data layout
- */
- DataLayout GetCpuInputLayout(int data_index) const;
- /**
- * @brief Get specified output data layout on CPU
- *
- * @param data_index Data index
- * @return Data layout
- */
- DataLayout GetCpuOutputLayout(int data_index) const;
- /**
- * @brief Adjust MLU stack memory according to model size
- *
- * @note Adjust MLU stack memory. Do nothing if model size is not larger than current stack memory size.
- * @return Return true if stack memory is adjusted.
- */
- bool AdjustStackMemory();
- /**
- * @brief Get model output number
- *
- * @return Model output number
- */
- uint32_t OutputNum() const;
- /**
- * @brief Get model input number
- *
- * @return Model input number
- */
- uint32_t InputNum() const;
- /**
- * @brief Get model input data shapes
- *
- * @deprecated use ModelLoader::InputShape(uint32_t) instead
- * @return Model input data shapes
- */
- attribute_deprecated const std::vector<Shape>& InputShapes() const;
- /**
- * @brief Get model output data shapes
- *
- * @deprecated use ModelLoader::OutputShape(uint32_t) instead
- * @return Model output data shapes
- */
- attribute_deprecated const std::vector<Shape>& OutputShapes() const;
- /**
- * @brief Get model input data shape
- *
- * @param index input index
- * @return Model input data shape
- */
- const ShapeEx& InputShape(uint32_t index) const;
- /**
- * @brief Get model output data shape
- *
- * @param index output index
- * @return Model output data shape
- */
- const ShapeEx& OutputShape(uint32_t index) const;
- /**
- * @brief Get model parallelism
- *
- * @note Not supported on MLU100, always return 1.
- * @return Model parallelism
- */
- int ModelParallelism() const;
- /**
- * @brief Get the input data batch align size
- *
- * @param data_index Data index
- * @return input data batch align size
- */
- int64_t GetInputDataBatchAlignSize(int data_index) const;
- /**
- * @brief Get the output data batch align size
- *
- * @param data_index Data index
- * @return output data batch align size
- */
- int64_t GetOutputDataBatchAlignSize(int data_index) const;
- private:
- std::unique_ptr<ModelLoaderPrivate> d_ptr_;
- ModelLoader(const ModelLoader&) = delete;
- ModelLoader& operator=(const ModelLoader&) = delete;
- ModelLoader(ModelLoader&&) = delete;
- ModelLoader& operator=(ModelLoader&&) = delete;
- }; // class ModelLoader
- } // namespace edk
- #endif // EASYINFER_MODEL_LOADER_H_
|