// infer_server.h — InferServer public API
  1. /*************************************************************************
  2. * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #ifndef INFER_SERVER_API_H_
  21. #define INFER_SERVER_API_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "buffer.h"
#include "config.h"
#include "shape.h"
#include "util/any.h"
#include "util/base_object.h"
  35. #define CNIS_GET_VERSION(major, minor, patch) (((major) << 20) | ((minor) << 10) | (patch))
  36. #define CNIS_VERSION CNIS_GET_VERSION(CNIS_VERSION_MAJOR, CNIS_VERSION_MINOR, CNIS_VERSION_PATCH)
  37. namespace infer_server {
  38. /**
  39. * @brief Enumeration to specify data type of model input and output
  40. */
  41. enum class DataType { UINT8, FLOAT32, FLOAT16, INT16, INT32, INVALID };
  42. /**
  43. * @brief Enumeration to specify dim order of model input and output
  44. */
  45. enum class DimOrder { NCHW, NHWC, HWCN, TNC, NTC };
  46. /**
  47. * @brief Describe data layout on MLU or CPU
  48. */
  49. struct DataLayout {
  50. DataType dtype; ///< @see DataType
  51. DimOrder order; ///< @see DimOrder
  52. };
  53. /**
  54. * @brief Get size in bytes of type
  55. *
  56. * @param type Data type enumeration
  57. * @return size_t size of specified type
  58. */
  59. size_t GetTypeSize(DataType type) noexcept;
  60. /**
  61. * @brief An enum describes InferServer request return values.
  62. */
  63. enum class Status {
  64. SUCCESS = 0, ///< The operation was successful
  65. ERROR_READWRITE = 1, ///< Read / Write file failed
  66. ERROR_MEMORY = 2, ///< Memory error, such as out of memory, memcpy failed
  67. INVALID_PARAM = 3, ///< Invalid parameters
  68. WRONG_TYPE = 4, ///< Invalid data type in `any`
  69. ERROR_BACKEND = 5, ///< Error occured in processor
  70. NOT_IMPLEMENTED = 6, ///< Function not implemented
  71. TIMEOUT = 7, ///< Time expired
  72. STATUS_COUNT = 8, ///< Number of status
  73. };
  74. /**
  75. * @brief An enum describes batch strategy
  76. */
  77. enum class BatchStrategy {
  78. DYNAMIC = 0, ///< Cross-request batch
  79. STATIC = 1, ///< In-request batch
  80. SEQUENCE = 2, ///< Sequence model, unsupported for now
  81. STRATEGY_COUNT = 3, ///< Number of strategy
  82. };
  83. /**
  84. * @brief Convert BatchStrategy to string
  85. *
  86. * @param strategy batch strategy
  87. * @return std::string Stringified batch strategy
  88. */
  89. std::string ToString(BatchStrategy strategy) noexcept;
  90. /**
  91. * @brief Put BatchStrategy into ostream
  92. *
  93. * @param os ostream
  94. * @param s BatchStrategy
  95. * @return std::ostream& ostream
  96. */
  97. inline std::ostream& operator<<(std::ostream& os, BatchStrategy s) { return os << ToString(s); }
  98. /**
  99. * @brief Get CNIS version string
  100. *
  101. * @return std::string version string
  102. */
  103. inline std::string Version() {
  104. // clang-format off
  105. return std::to_string(CNIS_VERSION_MAJOR) + "." +
  106. std::to_string(CNIS_VERSION_MINOR) + "." +
  107. std::to_string(CNIS_VERSION_PATCH);
  108. // clang-format on
  109. }
  110. /**
  111. * @brief Set current deivce for this thread
  112. *
  113. * @param device_id device id
  114. *
  115. * @retval true success
  116. * @retval false set device failed
  117. */
  118. bool SetCurrentDevice(int device_id) noexcept;
  119. /**
  120. * @brief Check whether device is accessible
  121. *
  122. * @param device_id device id
  123. *
  124. * @retval true device is accessible
  125. * @retval false no such device
  126. */
  127. bool CheckDevice(int device_id) noexcept;
  128. /**
  129. * @brief Get total device count
  130. *
  131. * @retval device count
  132. */
  133. uint32_t TotalDeviceCount() noexcept;
  134. /**
  135. * @brief Model interface
  136. */
  137. class ModelInfo {
  138. public:
  139. virtual ~ModelInfo() = default;
  140. // ----------- Observers -----------
  141. /**
  142. * @brief Get input shape
  143. *
  144. * @param index index of input
  145. * @return const Shape& shape of specified input
  146. */
  147. virtual const Shape& InputShape(int index) const noexcept = 0;
  148. /**
  149. * @brief Get output shape
  150. *
  151. * @param index index of output
  152. * @return const Shape& shape of specified output
  153. */
  154. virtual const Shape& OutputShape(int index) const noexcept = 0;
  155. /**
  156. * @brief Get input layout on MLU
  157. *
  158. * @param index index of input
  159. * @return const DataLayout& data layout of specified input
  160. */
  161. virtual const DataLayout& InputLayout(int index) const noexcept = 0;
  162. /**
  163. * @brief Get output layout on MLU
  164. *
  165. * @param index index of output
  166. * @return const DataLayout& data layout of specified output
  167. */
  168. virtual const DataLayout& OutputLayout(int index) const noexcept = 0;
  169. /**
  170. * @brief Get number of input
  171. *
  172. * @return uint32_t number of input
  173. */
  174. virtual uint32_t InputNum() const noexcept = 0;
  175. /**
  176. * @brief Get number of output
  177. *
  178. * @return uint32_t number of output
  179. */
  180. virtual uint32_t OutputNum() const noexcept = 0;
  181. /**
  182. * @brief Get model batch size
  183. *
  184. * @return uint32_t batch size
  185. */
  186. virtual uint32_t BatchSize() const noexcept = 0;
  187. /**
  188. * @brief Get model key
  189. *
  190. * @return const std::string& model key
  191. */
  192. virtual std::string GetKey() const noexcept = 0;
  193. // ----------- Observers End -----------
  194. }; // class ModelInfo
  195. using ModelPtr = std::shared_ptr<ModelInfo>;
  196. class RequestControl;
  197. /**
  198. * @brief Inference data unit
  199. */
  200. struct InferData {
  201. /**
  202. * @brief Set any data into inference data
  203. *
  204. * @tparam T data type
  205. * @param v data value
  206. */
  207. template <typename T>
  208. void Set(T&& v) {
  209. data = std::forward<T>(v);
  210. }
  211. /**
  212. * @brief Get data by value
  213. *
  214. * @tparam T data type
  215. * @return std::remove_reference<T>::type a copy of data
  216. */
  217. template <typename T>
  218. typename std::remove_reference<T>::type Get() const {
  219. return any_cast<typename std::remove_reference<T>::type>(data);
  220. }
  221. /**
  222. * @brief Get data by lvalue reference
  223. *
  224. * @tparam T data type
  225. * @return std::add_lvalue_reference<T>::type lvalue reference to data
  226. */
  227. template <typename T>
  228. typename std::add_lvalue_reference<T>::type GetLref() & {
  229. return any_cast<typename std::add_lvalue_reference<T>::type>(data);
  230. }
  231. /**
  232. * @brief Get data by const lvalue reference
  233. *
  234. * @tparam T data type
  235. * @return std::add_lvalue_reference<typename std::add_const<T>::type>::type const lvalue reference to data
  236. */
  237. template <typename T>
  238. typename std::add_lvalue_reference<typename std::add_const<T>::type>::type GetLref() const& {
  239. return any_cast<typename std::add_lvalue_reference<typename std::add_const<T>::type>::type>(data);
  240. }
  241. /**
  242. * @brief Check if InferData has value
  243. *
  244. * @retval true InferData has value
  245. * @retval false InferData does not have value
  246. */
  247. bool HasValue() noexcept {
  248. return data.has_value();
  249. }
  250. /**
  251. * @brief Set user data for postprocess
  252. *
  253. * @tparam T data type
  254. * @param v data value
  255. */
  256. template <typename T>
  257. void SetUserData(T&& v) {
  258. user_data = std::forward<T>(v);
  259. }
  260. /**
  261. * @brief Get user data by value
  262. *
  263. * @note if T is lvalue reference, data is returned by lvalue reference.
  264. * if T is bare type, data is returned by value.
  265. * @tparam T data type
  266. * @return data
  267. */
  268. template <typename T>
  269. T GetUserData() const {
  270. return any_cast<T>(user_data);
  271. }
  272. /// stored data
  273. any data;
  274. /// user data passed to postprocessor
  275. any user_data;
  276. /// private member
  277. RequestControl* ctrl{nullptr};
  278. /// private member
  279. uint32_t index{0};
  280. };
  281. using InferDataPtr = std::shared_ptr<InferData>;
  282. using BatchData = std::vector<InferDataPtr>;
  283. /**
  284. * @brief Data package, used in request and response
  285. */
  286. struct Package {
  287. /// a batch of data
  288. BatchData data;
  289. /// private member, intermediate storage
  290. InferDataPtr predict_io{nullptr};
  291. /// tag of this package (such as stream_id, client ip, etc.)
  292. std::string tag;
  293. /// perf statistics of one request
  294. std::map<std::string, float> perf;
  295. /// private member
  296. int64_t priority;
  297. static std::shared_ptr<Package> Create(uint32_t data_num, const std::string& tag = "") noexcept {
  298. auto ret = std::make_shared<Package>();
  299. ret->data.reserve(data_num);
  300. for (uint32_t idx = 0; idx < data_num; ++idx) {
  301. ret->data.emplace_back(new InferData);
  302. }
  303. ret->tag = tag;
  304. return ret;
  305. }
  306. };
  307. using PackagePtr = std::shared_ptr<Package>;
  308. /**
  309. * @brief Processor interface
  310. */
  311. class Processor : public BaseObject {
  312. public:
  313. /**
  314. * @brief Construct a new Processor object
  315. *
  316. * @param type_name type name of derived processor
  317. */
  318. explicit Processor(const std::string& type_name) noexcept : type_name_(type_name) {}
  319. /**
  320. * @brief Get type name of processor
  321. *
  322. * @return const std::string& type name
  323. */
  324. const std::string& TypeName() const noexcept { return type_name_; }
  325. /**
  326. * @brief Destroy the Processor object
  327. */
  328. virtual ~Processor() = default;
  329. /**
  330. * @brief Initialize processor
  331. *
  332. * @retval Status::SUCCESS Init succeeded
  333. * @retval other Init failed
  334. */
  335. virtual Status Init() noexcept = 0;
  336. /**
  337. * @brief Process data in package
  338. *
  339. * @param data Processed data
  340. * @retval Status::SUCCESS Process succeeded
  341. * @retval other Process failed
  342. */
  343. virtual Status Process(PackagePtr data) noexcept = 0;
  344. /**
  345. * @brief Fork an initialized processor which have the same params as this
  346. *
  347. * @return std::shared_ptr<Processor> A new processor
  348. */
  349. virtual std::shared_ptr<Processor> Fork() = 0;
  350. private:
  351. Processor() = delete;
  352. friend class TaskNode;
  353. std::unique_lock<std::mutex> Lock() noexcept { return std::unique_lock<std::mutex>(process_lock_); }
  354. std::string type_name_;
  355. std::mutex process_lock_;
  356. }; // class Processor
  357. /**
  358. * @brief A convenient CRTP template provided `Fork` and `Create` function
  359. *
  360. * @tparam T Type of derived class
  361. */
  362. template <typename T>
  363. class ProcessorForkable : public Processor {
  364. public:
  365. /**
  366. * @brief Construct a new Processor Forkable object
  367. *
  368. * @param type_name type name of derived processor
  369. */
  370. explicit ProcessorForkable(const std::string& type_name) noexcept : Processor(type_name) {}
  371. /**
  372. * @brief Destroy the Processor Forkable object
  373. */
  374. virtual ~ProcessorForkable() = default;
  375. /**
  376. * @brief Fork an initialized processor which have the same params as this
  377. *
  378. * @return std::shared_ptr<Processor> A new processor
  379. */
  380. std::shared_ptr<Processor> Fork() noexcept(std::is_nothrow_default_constructible<T>::value) final {
  381. auto p = std::make_shared<T>();
  382. p->CopyParamsFrom(*this);
  383. if (p->Init() != Status::SUCCESS) return nullptr;
  384. return p;
  385. }
  386. /**
  387. * @brief Create a processor
  388. *
  389. * @return std::shared_ptr<T> A new processor
  390. */
  391. static std::shared_ptr<T> Create() noexcept(std::is_nothrow_default_constructible<T>::value) {
  392. return std::make_shared<T>();
  393. }
  394. };
  395. /**
  396. * @brief Base class of response observer, only used for async Session
  397. */
  398. class Observer {
  399. public:
  400. /**
  401. * @brief Notify the observer one response
  402. *
  403. * @param status Request status code
  404. * @param data Response data
  405. * @param user_data User data
  406. */
  407. virtual void Response(Status status, PackagePtr data, any user_data) noexcept = 0;
  408. /**
  409. * @brief Destroy the Observer object
  410. */
  411. virtual ~Observer() = default;
  412. };
  413. /**
  414. * @brief A struct to describe execution graph
  415. */
  416. struct SessionDesc {
  417. /// session name, distinct session in log
  418. std::string name{};
  419. /// model pointer
  420. ModelPtr model{nullptr};
  421. /// batch strategy
  422. BatchStrategy strategy{BatchStrategy::DYNAMIC};
  423. /**
  424. * @brief host input data layout, work when input data is on cpu
  425. *
  426. * @note built-in processor will transform data from host input layout into MLU input layout
  427. * ( @see ModelInfo::InputLayout(int index) ) automatically before infer
  428. */
  429. DataLayout host_input_layout{DataType::UINT8, DimOrder::NHWC};
  430. /**
  431. * @brief host output data layout
  432. *
  433. * @note built-in processor will transform from MLU output layout ( @see ModelInfo::OutputLayout(int index) )
  434. * into host output layout automatically after infer
  435. */
  436. DataLayout host_output_layout{DataType::FLOAT32, DimOrder::NHWC};
  437. /// preprocessor
  438. std::shared_ptr<Processor> preproc{nullptr};
  439. /// postprocessor
  440. std::shared_ptr<Processor> postproc{nullptr};
  441. /// timeout in milliseconds, only work for BatchStrategy::DYNAMIC
  442. uint32_t batch_timeout{100};
  443. /// Session request priority
  444. int priority{0};
  445. /**
  446. * @brief engine number
  447. *
  448. * @note multi engine can boost process, but will take more MLU resources
  449. */
  450. uint32_t engine_num{1};
  451. /// whether print performance
  452. bool show_perf{true};
  453. };
  454. /**
  455. * @brief Latency statistics
  456. */
  457. struct LatencyStatistic {
  458. /// Total processed unit count
  459. uint32_t unit_cnt{0};
  460. /// Total recorded value
  461. double total{0};
  462. /// Maximum value of one unit
  463. float max{0};
  464. /// Minimum value of one unit
  465. float min{std::numeric_limits<float>::max()};
  466. };
  467. /**
  468. * @brief Throughout statistics
  469. */
  470. struct ThroughoutStatistic {
  471. /// total request count
  472. uint32_t request_cnt{0};
  473. /// total unit cnt
  474. uint32_t unit_cnt{0};
  475. /// request per second
  476. float rps{0};
  477. /// unit per second
  478. float ups{0};
  479. /// real time rps
  480. float rps_rt{0};
  481. /// real time ups
  482. float ups_rt{0};
  483. };
  484. /// A structure describes linked session of server
  485. class Session;
  486. /// pointer to Session
  487. using Session_t = Session*;
  488. class InferServerPrivate;
  489. /**
  490. * @brief Inference server api class
  491. */
  492. class InferServer {
  493. public:
  494. /**
  495. * @brief Construct a new Infer Server object
  496. *
  497. * @param device_id Specified MLU device ID
  498. */
  499. explicit InferServer(int device_id) noexcept;
  500. /* ------------------------- Request API -------------------------- */
  501. /**
  502. * @brief Create a Session
  503. *
  504. * @param desc Session description
  505. * @param observer Response observer
  506. * @return Session_t a Session
  507. */
  508. Session_t CreateSession(SessionDesc desc, std::shared_ptr<Observer> observer) noexcept;
  509. /**
  510. * @brief Create a synchronous Session
  511. *
  512. * @param desc Session description
  513. * @return Session_t a Session
  514. */
  515. Session_t CreateSyncSession(SessionDesc desc) noexcept { return CreateSession(desc, nullptr); }
  516. /**
  517. * @brief Destroy session
  518. *
  519. * @param session a Session
  520. * @retval true Destroy succeeded
  521. * @retval false session does not belong to this server
  522. */
  523. bool DestroySession(Session_t session) noexcept;
  524. /**
  525. * @brief send a inference request
  526. *
  527. * @warning async api, can be invoked with async Session only.
  528. *
  529. * @param session link handle
  530. * @param input input package
  531. * @param user_data user data
  532. * @param timeout timeout threshold (milliseconds), -1 for endless
  533. */
  534. bool Request(Session_t session, PackagePtr input, any user_data, int timeout = -1) noexcept;
  535. /**
  536. * @brief send a inference request and wait for response
  537. *
  538. * @warning synchronous api, can be invoked with synchronous Session only.
  539. *
  540. * @param session session
  541. * @param input input package
  542. * @param status execute status
  543. * @param response output result
  544. * @param timeout timeout threshold (milliseconds), -1 for endless
  545. */
  546. bool RequestSync(Session_t session, PackagePtr input, Status* status, PackagePtr response, int timeout = -1) noexcept;
  547. /**
  548. * @brief Wait task with specified tag done, @see Package::tag
  549. *
  550. * @note Usually used at EOS
  551. *
  552. * @param session a Session
  553. * @param tag specified tag
  554. */
  555. void WaitTaskDone(Session_t session, const std::string& tag) noexcept;
  556. /**
  557. * @brief Discard task with specified tag done, @see Package::tag
  558. *
  559. * @note Usually used when you need to stop the process as soon as possible
  560. * @param session a Session
  561. * @param tag specified tag
  562. */
  563. void DiscardTask(Session_t session, const std::string& tag) noexcept;
  564. /**
  565. * @brief Get model from session
  566. *
  567. * @param session a Session
  568. * @return ModelPtr A model
  569. */
  570. ModelPtr GetModel(Session_t session) noexcept;
  571. /* --------------------- Model API ----------------------------- */
  572. /**
  573. * @brief Set directory to save downloaded model file
  574. *
  575. * @param model_dir model directory
  576. * @retval true Succeeded
  577. * @retval false Model not exist
  578. */
  579. static bool SetModelDir(const std::string& model_dir) noexcept;
  580. /**
  581. * @brief Load model from uri, model won't be loaded again if it is already in cache
  582. *
  583. * @note support download model from remote by HTTP, HTTPS, FTP, while compiled with flag `WITH_CURL`,
  584. * use uri such as `../../model_file`, or "https://someweb/model_file"
  585. * @param pattern1 offline model uri
  586. * @param pattern2 extracted function name, work only if backend is cnrt
  587. * @return ModelPtr A model
  588. */
  589. static ModelPtr LoadModel(const std::string& pattern1, const std::string& pattern2 = "subnet0") noexcept;
  590. #ifdef CNIS_USE_MAGICMIND
  591. /**
  592. * @brief Load model from memory, model won't be loaded again if it is already in cache
  593. *
  594. * @param ptr serialized model data in memory
  595. * @param size size of model data in memory
  596. * @return ModelPtr A model
  597. */
  598. static ModelPtr LoadModel(void* ptr, size_t size) noexcept;
  599. #else
  600. /**
  601. * @brief Load model from memory, model won't be loaded again if it is already in cache
  602. *
  603. * @param ptr serialized model data in memory
  604. * @param func_name name of function to be extracted
  605. * @return ModelPtr A model
  606. */
  607. static ModelPtr LoadModel(void* ptr, const std::string& func_name = "subnet0") noexcept;
  608. #endif
  609. /**
  610. * @brief Remove model from cache, model won't be destroyed if still in use
  611. *
  612. * @param model a model
  613. * @return true Succeed
  614. * @return false Model is not in cache
  615. */
  616. static bool UnloadModel(ModelPtr model) noexcept;
  617. /**
  618. * @brief Clear all the models in cache, model won't be destroyed if still in use
  619. */
  620. static void ClearModelCache() noexcept;
  621. /* ----------------------- Perf API ---------------------------- */
  622. /**
  623. * @brief Get the latency statistics
  624. *
  625. * @param session a session
  626. * @return std::map<std::string, PerfStatistic> latency statistics
  627. */
  628. std::map<std::string, LatencyStatistic> GetLatency(Session_t session) const noexcept;
  629. /**
  630. * @brief Get the performance statistics
  631. *
  632. * @param session a session
  633. * @return ThroughoutStatistic throughout statistic
  634. */
  635. ThroughoutStatistic GetThroughout(Session_t session) const noexcept;
  636. /**
  637. * @brief Get the throughout statistics of specified tag
  638. *
  639. * @param session a session
  640. * @param tag tag
  641. * @return ThroughoutStatistic throughout statistic
  642. */
  643. ThroughoutStatistic GetThroughout(Session_t session, const std::string& tag) const noexcept;
  644. private:
  645. InferServer() = delete;
  646. InferServerPrivate* priv_;
  647. }; // class InferServer
  648. } // namespace infer_server
  649. #endif // INFER_SERVER_API_H_