// shape.h
/*************************************************************************
 * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/
  20. #ifndef EASYINFER_SHAPE_HPP_
  21. #define EASYINFER_SHAPE_HPP_
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>
  24. namespace edk {
  /**
   * @brief ShapeEx to describe inference model input and output data
   * @warning No matter how data is placed in memory, dim values always keep in order of NHWC
   */
  class ShapeEx {
   public:
    /// stored value type
    using value_type = int;

    /**
     * @brief Construct an empty ShapeEx (no dimensions)
     */
    ShapeEx() = default;

    /**
     * @brief Construct a new ShapeEx object from shape vector
     *
     * @param v vector stored shape value, copied into the new object
     */
    explicit ShapeEx(const std::vector<value_type>& v) noexcept : data_(v) {}

    ShapeEx(const ShapeEx&) = default;
    ShapeEx& operator=(const ShapeEx&) = default;
    ShapeEx(ShapeEx&&) = default;
    ShapeEx& operator=(ShapeEx&&) = default;

    /**
     * @brief Get value of nth dimension
     *
     * @note No bounds check is performed; offset must be in [0, Size()).
     * @param offset serial number of dimension
     * @return value_type shape value
     */
    value_type operator[](int offset) const { return data_[offset]; }

    /**
     * @brief Get value of nth dimension
     *
     * @note No bounds check is performed; offset must be in [0, Size()).
     * @param offset serial number of dimension
     * @return value_type reference to shape value
     */
    value_type& operator[](int offset) { return data_[offset]; }

    /**
     * @brief Returns the dimension size of ShapeEx
     *
     * @return size_t The dimension size of ShapeEx
     */
    size_t Size() const noexcept { return data_.size(); }

    /**
     * @brief Returns whether ShapeEx is empty
     *
     * @retval true if the ShapeEx doesn't have any value
     * @retval false otherwise
     */
    bool Empty() const noexcept { return data_.empty(); }

    /**
     * @brief Get vectorized shape value
     *
     * @return std::vector<value_type> copy of the stored shape values
     */
    std::vector<value_type> Vectorize() const noexcept { return data_; }

    /**
     * @brief Get batchsize
     *
     * @note Reads the first dimension without checking; the shape must not be empty.
     * @return value_type batch size
     */
    value_type BatchSize() const { return data_[0]; }

    /**
     * @brief Get n value
     *
     * @note Only work on ShapeEx of 4 dimension, will return 0 if size() != 4.
     * @return value_type n or 0
     */
    value_type N() const noexcept { return Size() == 4 ? data_[0] : 0; }

    /**
     * @brief Get height value
     *
     * @note Only work on ShapeEx of 4 dimension, will return 0 if size() != 4.
     * @return value_type height or 0
     */
    value_type H() const noexcept { return Size() == 4 ? data_[1] : 0; }

    /**
     * @brief Get width value
     *
     * @note Only work on ShapeEx of 4 dimension, will return 0 if size() != 4.
     * @return value_type width or 0
     */
    value_type W() const noexcept { return Size() == 4 ? data_[2] : 0; }

    /**
     * @brief Get channel value
     *
     * @note Only work on ShapeEx of 4 dimension, will return 0 if size() != 4.
     * @return value_type channel or 0
     */
    value_type C() const noexcept { return Size() == 4 ? data_[3] : 0; }

    /**
     * @brief Get total data count / batch size
     *
     * Product of every dimension except the first one. Returns 1 when the
     * shape has fewer than two dimensions.
     *
     * @return Data count
     */
    int64_t DataCount() const noexcept {
      int64_t cnt = 1;
      for (size_t i = 1; i < data_.size(); ++i) {
        cnt *= data_[i];
      }
      return cnt;
    }

    /**
     * @brief Get total data count
     *
     * Product of all dimensions. Returns 1 for an empty shape.
     *
     * @return Total data count
     */
    int64_t BatchDataCount() const noexcept {
      int64_t cnt = 1;
      for (size_t i = 0; i < data_.size(); ++i) {
        cnt *= data_[i];
      }
      return cnt;
    }

    /**
     * @brief put shape into ostream
     *
     * Output format is "ShapeEx (d0, d1, ..., dk)"; an empty shape prints "ShapeEx ()".
     *
     * @param os Output stream
     * @param shape ShapeEx to be printed
     * @return Output stream
     */
    friend std::ostream& operator<<(std::ostream& os, const ShapeEx& shape) {
      const std::vector<value_type>& dims = shape.data_;
      os << "ShapeEx (";
      // Test `i + 1 < size` rather than `i < size - 1`: the size is unsigned, so
      // `size - 1` wraps around for an empty shape and the loop would read out of bounds.
      for (size_t i = 0; i + 1 < dims.size(); ++i) {
        os << dims[i] << ", ";
      }
      if (!dims.empty()) os << dims.back();
      os << ")";
      return os;
    }

    /**
     * @brief Judge whether two shapes are equal
     *
     * @param lhs a ShapeEx
     * @param rhs a ShapeEx
     * @retval true if two shapes have the same number of dimensions and equal values
     * @retval false otherwise
     */
    friend bool operator==(const ShapeEx& lhs, const ShapeEx& rhs) noexcept { return lhs.data_ == rhs.data_; }

    /**
     * @brief Judge whether two shapes are not equal
     *
     * @param lhs a ShapeEx
     * @param rhs a ShapeEx
     * @retval true if two shapes are not equal
     * @retval false otherwise
     */
    friend bool operator!=(const ShapeEx& lhs, const ShapeEx& rhs) noexcept { return !(lhs == rhs); }

   private:
    /// dimension values, in the order they were given at construction
    std::vector<value_type> data_;
  };  // class ShapeEx
  /**
   * @brief Shape to describe inference model input and output data
   */
  class Shape {
   public:
    /**
     * @brief Construct a new Shape object
     *
     * @param n data number
     * @param h height
     * @param w width
     * @param c channel
     * @param stride aligned width
     */
    explicit Shape(uint32_t n = 1, uint32_t h = 1, uint32_t w = 1, uint32_t c = 1, uint32_t stride = 1);

    /**
     * @brief Get stride, which is aligned width
     *
     * @return The larger of the width and the stored stride
     */
    inline uint32_t Stride() const { return w > stride_ ? w : stride_; }

    /**
     * @brief Set the stride
     *
     * @param s Stride
     */
    inline void SetStride(uint32_t s) { stride_ = s; }

    // The size helpers below widen the first operand to 64 bits before multiplying.
    // All members are uint32_t, so without the cast the product would be computed
    // in 32 bits and wrap around before being converted to the uint64_t return type.

    /**
     * @brief Get Step, row length, equals to stride multiply c
     *
     * @return Step
     */
    inline uint64_t Step() const { return static_cast<uint64_t>(Stride()) * c; }

    /**
     * @brief Get total data count, computed as n * h * Step()
     *
     * @return Data count
     */
    inline uint64_t DataCount() const { return static_cast<uint64_t>(n) * h * Step(); }

    /**
     * @brief Get n * h * w * c, which is unaligned data size
     *
     * @return nhwc
     */
    inline uint64_t nhwc() const { return static_cast<uint64_t>(n) * h * w * c; }

    /**
     * @brief Get h * w * c, which is size of one data part
     *
     * @return hwc
     */
    inline uint64_t hwc() const { return static_cast<uint64_t>(h) * w * c; }

    /**
     * @brief Get h * w, which is size of one channel in one data part
     *
     * @return hw
     */
    inline uint64_t hw() const { return static_cast<uint64_t>(h) * w; }

    /**
     * @brief Print shape
     *
     * @param os Output stream
     * @param shape Shape to be printed
     * @return Output stream
     */
    friend std::ostream &operator<<(std::ostream &os, const Shape &shape);

    /**
     * @brief Judge whether two shapes are equal
     *
     * @param other Another shape
     * @return Return true if two shapes are equal
     */
    bool operator==(const Shape &other) const;

    /**
     * @brief Judge whether two shapes are not equal
     *
     * @param other Another shape
     * @return Return true if two shapes are not equal
     */
    bool operator!=(const Shape &other) const;

    /// data number
    uint32_t n;
    /// height
    uint32_t h;
    /// width
    uint32_t w;
    /// channel
    uint32_t c;

   private:
    /// stored stride; Stride() reports max(w, stride_)
    uint32_t stride_;
  };  // class Shape
  291. } // namespace edk
  292. #endif // EASYINFER_SHAPE_HPP_