/*************************************************************************
 * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/

#ifndef INFER_SERVER_API_H_
#define INFER_SERVER_API_H_

#include <cstdint>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "buffer.h"
#include "shape.h"
#include "util/any.h"
#include "util/base_object.h"

#include "config.h"

#define CNIS_GET_VERSION(major, minor, patch) (((major) << 20) | ((minor) << 10) | (patch))
#define CNIS_VERSION CNIS_GET_VERSION(CNIS_VERSION_MAJOR, CNIS_VERSION_MINOR, CNIS_VERSION_PATCH)

namespace infer_server {

/**
 * @brief Enumeration to specify the data type of model input and output
 */
enum class DataType { UINT8, FLOAT32, FLOAT16, INT16, INT32, INVALID };

/**
 * @brief Enumeration to specify the dimension order of model input and output
 */
enum class DimOrder { NCHW, NHWC, HWCN, TNC, NTC };

/**
 * @brief Describes data layout on MLU or CPU
 */
struct DataLayout {
  DataType dtype;  ///< @see DataType
  DimOrder order;  ///< @see DimOrder
};

/**
 * @brief Get the size in bytes of a data type
 *
 * @param type Data type enumeration
 * @return size_t Size of the specified type
 */
size_t GetTypeSize(DataType type) noexcept;

/**
 * @brief An enum describing InferServer request return values
 */
enum class Status {
  SUCCESS = 0,          ///< The operation was successful
  ERROR_READWRITE = 1,  ///< Read / write file failed
  ERROR_MEMORY = 2,     ///< Memory error, such as out of memory or memcpy failure
  INVALID_PARAM = 3,    ///< Invalid parameters
  WRONG_TYPE = 4,       ///< Invalid data type in `any`
  ERROR_BACKEND = 5,    ///< Error occurred in processor
  NOT_IMPLEMENTED = 6,  ///< Function not implemented
  TIMEOUT = 7,          ///< Time expired
  STATUS_COUNT = 8,     ///< Number of status codes
};

/**
 * @brief An enum describing the batch strategy
 */
enum class BatchStrategy {
  DYNAMIC = 0,         ///< Cross-request batch
  STATIC = 1,          ///< In-request batch
  SEQUENCE = 2,        ///< Sequence model, unsupported for now
  STRATEGY_COUNT = 3,  ///< Number of strategies
};

/**
 * @brief Convert BatchStrategy to string
 *
 * @param strategy Batch strategy
 * @return std::string Stringified batch strategy
 */
std::string ToString(BatchStrategy strategy) noexcept;

/**
 * @brief Put BatchStrategy into an ostream
 *
 * @param os ostream
 * @param s BatchStrategy
 * @return std::ostream& ostream
 */
inline std::ostream& operator<<(std::ostream& os, BatchStrategy s) { return os << ToString(s); }

/**
 * @brief Get the CNIS version string
 *
 * @return std::string Version string
 */
inline std::string Version() {
  // clang-format off
  return std::to_string(CNIS_VERSION_MAJOR) + "." +
         std::to_string(CNIS_VERSION_MINOR) + "." +
         std::to_string(CNIS_VERSION_PATCH);
  // clang-format on
}
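// Illustrative sketch (not part of the API): GetTypeSize() returns the element
// size for a DataType, so the byte size of a dense tensor is its element count
// times GetTypeSize(). The dimensions below are made-up example values.
//
//   DataLayout layout{DataType::FLOAT32, DimOrder::NHWC};
//   size_t n = 1, h = 224, w = 224, c = 3;
//   size_t bytes = n * h * w * c * GetTypeSize(layout.dtype);  // 1*224*224*3*4 = 602112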
/**
 * @brief Set the current device for this thread
 *
 * @param device_id Device id
 *
 * @retval true Success
 * @retval false Set device failed
 */
bool SetCurrentDevice(int device_id) noexcept;

/**
 * @brief Check whether a device is accessible
 *
 * @param device_id Device id
 *
 * @retval true Device is accessible
 * @retval false No such device
 */
bool CheckDevice(int device_id) noexcept;

/**
 * @brief Get the total device count
 *
 * @retval Device count
 */
uint32_t TotalDeviceCount() noexcept;

/**
 * @brief Model interface
 */
class ModelInfo {
 public:
  virtual ~ModelInfo() = default;
  // ----------- Observers -----------

  /**
   * @brief Get input shape
   *
   * @param index Index of input
   * @return const Shape& Shape of the specified input
   */
  virtual const Shape& InputShape(int index) const noexcept = 0;

  /**
   * @brief Get output shape
   *
   * @param index Index of output
   * @return const Shape& Shape of the specified output
   */
  virtual const Shape& OutputShape(int index) const noexcept = 0;

  /**
   * @brief Get input layout on MLU
   *
   * @param index Index of input
   * @return const DataLayout& Data layout of the specified input
   */
  virtual const DataLayout& InputLayout(int index) const noexcept = 0;

  /**
   * @brief Get output layout on MLU
   *
   * @param index Index of output
   * @return const DataLayout& Data layout of the specified output
   */
  virtual const DataLayout& OutputLayout(int index) const noexcept = 0;

  /**
   * @brief Get number of inputs
   *
   * @return uint32_t Number of inputs
   */
  virtual uint32_t InputNum() const noexcept = 0;

  /**
   * @brief Get number of outputs
   *
   * @return uint32_t Number of outputs
   */
  virtual uint32_t OutputNum() const noexcept = 0;

  /**
   * @brief Get model batch size
   *
   * @return uint32_t Batch size
   */
  virtual uint32_t BatchSize() const noexcept = 0;

  /**
   * @brief Get model key
   *
   * @return std::string Model key
   */
  virtual std::string GetKey() const noexcept = 0;

  // ----------- Observers End -----------
};  // class ModelInfo

using ModelPtr = std::shared_ptr<ModelInfo>;
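// Illustrative sketch (not part of the API): inspecting a model through the
// ModelInfo interface. `model` is assumed to be a valid ModelPtr, e.g. one
// returned by InferServer::LoadModel(), declared later in this header.
//
//   for (uint32_t i = 0; i < model->InputNum(); ++i) {
//     const Shape& shape = model->InputShape(i);
//     const DataLayout& layout = model->InputLayout(i);
//     // verify that the preprocessed data matches `shape` and `layout` ...
//   }
//   uint32_t batch = model->BatchSize();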
class RequestControl;

/**
 * @brief Inference data unit
 */
struct InferData {
  /**
   * @brief Set any data into inference data
   *
   * @tparam T Data type
   * @param v Data value
   */
  template <typename T>
  void Set(T&& v) {
    data = std::forward<T>(v);
  }

  /**
   * @brief Get data by value
   *
   * @tparam T Data type
   * @return std::remove_reference<T>::type A copy of the data
   */
  template <typename T>
  typename std::remove_reference<T>::type Get() const {
    return any_cast<typename std::remove_reference<T>::type>(data);
  }

  /**
   * @brief Get data by lvalue reference
   *
   * @tparam T Data type
   * @return std::add_lvalue_reference<T>::type Lvalue reference to the data
   */
  template <typename T>
  typename std::add_lvalue_reference<T>::type GetLref() & {
    return any_cast<typename std::add_lvalue_reference<T>::type>(data);
  }

  /**
   * @brief Get data by const lvalue reference
   *
   * @tparam T Data type
   * @return std::add_lvalue_reference<std::add_const<T>::type>::type Const lvalue reference to the data
   */
  template <typename T>
  typename std::add_lvalue_reference<typename std::add_const<T>::type>::type GetLref() const& {
    return any_cast<typename std::add_lvalue_reference<typename std::add_const<T>::type>::type>(data);
  }

  /**
   * @brief Check if InferData has a value
   *
   * @retval true InferData has a value
   * @retval false InferData does not have a value
   */
  bool HasValue() noexcept { return data.has_value(); }

  /**
   * @brief Set user data for postprocessing
   *
   * @tparam T Data type
   * @param v Data value
   */
  template <typename T>
  void SetUserData(T&& v) {
    user_data = std::forward<T>(v);
  }

  /**
   * @brief Get user data by value
   *
   * @note If T is an lvalue reference, data is returned by lvalue reference.
   *       If T is a bare type, data is returned by value.
   * @tparam T Data type
   * @return Data
   */
  template <typename T>
  T GetUserData() const {
    return any_cast<T>(user_data);
  }

  /// stored data
  any data;
  /// user data passed to postprocessor
  any user_data;
  /// private member
  RequestControl* ctrl{nullptr};
  /// private member
  uint32_t index{0};
};

using InferDataPtr = std::shared_ptr<InferData>;
using BatchData = std::vector<InferDataPtr>;

/**
 * @brief Data package, used in request and response
 */
struct Package {
  /// a batch of data
  BatchData data;

  /// private member, intermediate storage
  InferDataPtr predict_io{nullptr};

  /// tag of this package (such as stream_id, client ip, etc.)
  std::string tag;

  /// perf statistics of one request
  std::map<std::string, float> perf;

  /// private member
  int64_t priority;

  static std::shared_ptr<Package> Create(uint32_t data_num, const std::string& tag = "") noexcept {
    auto ret = std::make_shared<Package>();
    ret->data.reserve(data_num);
    for (uint32_t idx = 0; idx < data_num; ++idx) {
      ret->data.emplace_back(new InferData);
    }
    ret->tag = tag;
    return ret;
  }
};
using PackagePtr = std::shared_ptr<Package>;

/**
 * @brief Processor interface
 */
class Processor : public BaseObject {
 public:
  /**
   * @brief Construct a new Processor object
   *
   * @param type_name Type name of the derived processor
   */
  explicit Processor(const std::string& type_name) noexcept : type_name_(type_name) {}

  /**
   * @brief Get the type name of the processor
   *
   * @return const std::string& Type name
   */
  const std::string& TypeName() const noexcept { return type_name_; }

  /**
   * @brief Destroy the Processor object
   */
  virtual ~Processor() = default;

  /**
   * @brief Initialize the processor
   *
   * @retval Status::SUCCESS Init succeeded
   * @retval other Init failed
   */
  virtual Status Init() noexcept = 0;

  /**
   * @brief Process data in a package
   *
   * @param data Data to be processed
   * @retval Status::SUCCESS Process succeeded
   * @retval other Process failed
   */
  virtual Status Process(PackagePtr data) noexcept = 0;

  /**
   * @brief Fork an initialized processor which has the same params as this one
   *
   * @return std::shared_ptr<Processor> A new processor
   */
  virtual std::shared_ptr<Processor> Fork() = 0;

 private:
  Processor() = delete;
  friend class TaskNode;
  std::unique_lock<std::mutex> Lock() noexcept { return std::unique_lock<std::mutex>(process_lock_); }
  std::string type_name_;
  std::mutex process_lock_;
};  // class Processor

/**
 * @brief A convenient CRTP template providing `Fork` and `Create` functions
 *
 * @tparam T Type of the derived class
 */
template <typename T>
class ProcessorForkable : public Processor {
 public:
  /**
   * @brief Construct a new ProcessorForkable object
   *
   * @param type_name Type name of the derived processor
   */
  explicit ProcessorForkable(const std::string& type_name) noexcept : Processor(type_name) {}

  /**
   * @brief Destroy the ProcessorForkable object
   */
  virtual ~ProcessorForkable() = default;

  /**
   * @brief Fork an initialized processor which has the same params as this one
   *
   * @return std::shared_ptr<Processor> A new processor
   */
  std::shared_ptr<Processor> Fork() noexcept(std::is_nothrow_default_constructible<T>::value) final {
    auto p = std::make_shared<T>();
    p->CopyParamsFrom(*this);
    if (p->Init() != Status::SUCCESS) return nullptr;
    return p;
  }

  /**
   * @brief Create a processor
   *
   * @return std::shared_ptr<T> A new processor
   */
  static std::shared_ptr<T> Create() noexcept(std::is_nothrow_default_constructible<T>::value) {
    return std::make_shared<T>();
  }
};

/**
 * @brief Base class of response observer, only used for an async Session
 */
class Observer {
 public:
  /**
   * @brief Notify the observer of one response
   *
   * @param status Request status code
   * @param data Response data
   * @param user_data User data
   */
  virtual void Response(Status status, PackagePtr data, any user_data) noexcept = 0;

  /**
   * @brief Destroy the Observer object
   */
  virtual ~Observer() = default;
};
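// Illustrative sketch (not part of the API): a user-defined postprocessor
// built on the ProcessorForkable CRTP helper. All names below are
// hypothetical. Fork() (see its body above) produces initialized copies of a
// configured processor, which is how several engines can share one setup;
// Create() only default-constructs the instance handed to the session.
//
//   class MyPostproc : public ProcessorForkable<MyPostproc> {
//    public:
//     MyPostproc() noexcept : ProcessorForkable("MyPostproc") {}
//     Status Init() noexcept override { return Status::SUCCESS; }
//     Status Process(PackagePtr pack) noexcept override {
//       // parse model output in `pack` and fill each InferData ...
//       return Status::SUCCESS;
//     }
//   };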
/**
 * @brief A struct describing the execution graph
 */
struct SessionDesc {
  /// session name, to distinguish sessions in log
  std::string name{};

  /// model pointer
  ModelPtr model{nullptr};

  /// batch strategy
  BatchStrategy strategy{BatchStrategy::DYNAMIC};

  /**
   * @brief host input data layout, takes effect when input data is on CPU
   *
   * @note the built-in processor will transform data from the host input layout into the MLU input layout
   *       ( @see ModelInfo::InputLayout(int index) ) automatically before inference
   */
  DataLayout host_input_layout{DataType::UINT8, DimOrder::NHWC};

  /**
   * @brief host output data layout
   *
   * @note the built-in processor will transform data from the MLU output layout
   *       ( @see ModelInfo::OutputLayout(int index) ) into the host output layout automatically after inference
   */
  DataLayout host_output_layout{DataType::FLOAT32, DimOrder::NHWC};

  /// preprocessor
  std::shared_ptr<Processor> preproc{nullptr};

  /// postprocessor
  std::shared_ptr<Processor> postproc{nullptr};

  /// timeout in milliseconds, only works for BatchStrategy::DYNAMIC
  uint32_t batch_timeout{100};

  /// session request priority
  int priority{0};

  /**
   * @brief engine number
   *
   * @note multiple engines can boost processing, but will take more MLU resources
   */
  uint32_t engine_num{1};

  /// whether to print performance statistics
  bool show_perf{true};
};
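// Illustrative sketch (not part of the API): a minimal synchronous request
// flow assembled from the declarations in this header. The model file name
// and the MyPreproc/MyPostproc processors are hypothetical.
//
//   InferServer server(0);  // MLU device 0
//   SessionDesc desc;
//   desc.name = "demo";
//   desc.model = InferServer::LoadModel("resnet50.model");
//   desc.preproc = MyPreproc::Create();
//   desc.postproc = MyPostproc::Create();
//   Session_t session = server.CreateSyncSession(desc);
//
//   PackagePtr request = Package::Create(1, "stream_0");
//   request->data[0]->Set(/* input, e.g. a Buffer */);
//   Status status;
//   PackagePtr response = std::make_shared<Package>();
//   if (server.RequestSync(session, request, &status, response) && status == Status::SUCCESS) {
//     // consume response->data ...
//   }
//   server.DestroySession(session);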
/**
 * @brief Latency statistics
 */
struct LatencyStatistic {
  /// Total processed unit count
  uint32_t unit_cnt{0};
  /// Total recorded value
  double total{0};
  /// Maximum value of one unit
  float max{0};
  /// Minimum value of one unit
  float min{std::numeric_limits<float>::max()};
};

/**
 * @brief Throughput statistics
 */
struct ThroughoutStatistic {
  /// total request count
  uint32_t request_cnt{0};
  /// total unit count
  uint32_t unit_cnt{0};
  /// requests per second
  float rps{0};
  /// units per second
  float ups{0};
  /// real-time rps
  float rps_rt{0};
  /// real-time ups
  float ups_rt{0};
};

/// A structure describing a linked session of the server
class Session;
/// pointer to Session
using Session_t = Session*;

class InferServerPrivate;

/**
 * @brief Inference server API class
 */
class InferServer {
 public:
  /**
   * @brief Construct a new InferServer object
   *
   * @param device_id Specified MLU device ID
   */
  explicit InferServer(int device_id) noexcept;

  /* ------------------------- Request API -------------------------- */

  /**
   * @brief Create a Session
   *
   * @param desc Session description
   * @param observer Response observer
   * @return Session_t A Session
   */
  Session_t CreateSession(SessionDesc desc, std::shared_ptr<Observer> observer) noexcept;

  /**
   * @brief Create a synchronous Session
   *
   * @param desc Session description
   * @return Session_t A Session
   */
  Session_t CreateSyncSession(SessionDesc desc) noexcept { return CreateSession(desc, nullptr); }

  /**
   * @brief Destroy a session
   *
   * @param session A Session
   * @retval true Destroy succeeded
   * @retval false The session does not belong to this server
   */
  bool DestroySession(Session_t session) noexcept;

  /**
   * @brief Send an inference request
   *
   * @warning Async API, can be invoked with an async Session only.
   *
   * @param session Link handle
   * @param input Input package
   * @param user_data User data
   * @param timeout Timeout threshold (milliseconds), -1 for endless
   */
  bool Request(Session_t session, PackagePtr input, any user_data, int timeout = -1) noexcept;

  /**
   * @brief Send an inference request and wait for the response
   *
   * @warning Synchronous API, can be invoked with a synchronous Session only.
   *
   * @param session Session
   * @param input Input package
   * @param status Execute status
   * @param response Output result
   * @param timeout Timeout threshold (milliseconds), -1 for endless
   */
  bool RequestSync(Session_t session, PackagePtr input, Status* status, PackagePtr response,
                   int timeout = -1) noexcept;

  /**
   * @brief Wait until tasks with the specified tag are done, @see Package::tag
   *
   * @note Usually used at EOS
   *
   * @param session A Session
   * @param tag Specified tag
   */
  void WaitTaskDone(Session_t session, const std::string& tag) noexcept;

  /**
   * @brief Discard tasks with the specified tag, @see Package::tag
   *
   * @note Usually used when you need to stop processing as soon as possible
   *
   * @param session A Session
   * @param tag Specified tag
   */
  void DiscardTask(Session_t session, const std::string& tag) noexcept;

  /**
   * @brief Get the model from a session
   *
   * @param session A Session
   * @return ModelPtr A model
   */
  ModelPtr GetModel(Session_t session) noexcept;

  /* --------------------- Model API ----------------------------- */

  /**
   * @brief Set the directory in which to save downloaded model files
   *
   * @param model_dir Model directory
   * @retval true Succeeded
   * @retval false Model directory does not exist
   */
  static bool SetModelDir(const std::string& model_dir) noexcept;

  /**
   * @brief Load a model from a URI; the model won't be loaded again if it is already in the cache
   *
   * @note Downloading a model from remote by HTTP, HTTPS or FTP is supported when compiled with
   *       the flag `WITH_CURL`; use a URI such as `../../model_file` or "https://someweb/model_file"
   * @param pattern1 Offline model URI
   * @param pattern2 Extracted function name, works only if the backend is cnrt
   * @return ModelPtr A model
   */
  static ModelPtr LoadModel(const std::string& pattern1, const std::string& pattern2 = "subnet0") noexcept;

#ifdef CNIS_USE_MAGICMIND
  /**
   * @brief Load a model from memory; the model won't be loaded again if it is already in the cache
   *
   * @param ptr Serialized model data in memory
   * @param size Size of the model data in memory
   * @return ModelPtr A model
   */
  static ModelPtr LoadModel(void* ptr, size_t size) noexcept;
#else
  /**
   * @brief Load a model from memory; the model won't be loaded again if it is already in the cache
   *
   * @param ptr Serialized model data in memory
   * @param func_name Name of the function to be extracted
   * @return ModelPtr A model
   */
  static ModelPtr LoadModel(void* ptr, const std::string& func_name = "subnet0") noexcept;
#endif

  /**
   * @brief Remove a model from the cache; the model won't be destroyed if it is still in use
   *
   * @param model A model
   * @retval true Succeeded
   * @retval false Model is not in the cache
   */
  static bool UnloadModel(ModelPtr model) noexcept;

  /**
   * @brief Clear all models in the cache; models won't be destroyed if still in use
   */
  static void ClearModelCache() noexcept;

  /* ----------------------- Perf API ---------------------------- */

  /**
   * @brief Get the latency statistics
   *
   * @param session A session
   * @return std::map<std::string, LatencyStatistic> Latency statistics
   */
  std::map<std::string, LatencyStatistic> GetLatency(Session_t session) const noexcept;

  /**
   * @brief Get the throughput statistics
   *
   * @param session A session
   * @return ThroughoutStatistic Throughput statistics
   */
  ThroughoutStatistic GetThroughout(Session_t session) const noexcept;

  /**
   * @brief Get the throughput statistics of a specified tag
   *
   * @param session A session
   * @param tag Tag
   * @return ThroughoutStatistic Throughput statistics
   */
  ThroughoutStatistic GetThroughout(Session_t session, const std::string& tag) const noexcept;

 private:
  InferServer() = delete;
  InferServerPrivate* priv_;
};  // class InferServer

}  // namespace infer_server

#endif  // INFER_SERVER_API_H_