shape.h 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. /*************************************************************************
  2. * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #ifndef INFER_SERVER_SHAPE_H_
  21. #define INFER_SERVER_SHAPE_H_
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>
  24. namespace infer_server {
  /**
   * @brief Shape to describe inference model input and output data
   * @warning No matter how data is laid out in memory, dim values are always kept in NHWC order
   */
  class Shape {
   public:
    /// stored value type
    using value_type = std::int64_t;

    /**
     * @brief Construct an empty Shape object
     */
    Shape() = default;

    /**
     * @brief Construct a new Shape object from shape vector
     *
     * @param v vector stored shape value (copied)
     * @note This constructor is noexcept, so an allocation failure while
     *       copying @p v calls std::terminate instead of throwing.
     */
    explicit Shape(const std::vector<value_type>& v) noexcept : data_(v) {}

    /**
     * @brief Copy assign a Shape object from shape vector
     *
     * @param v vector stored shape value (copied)
     * @return reference to this Shape
     * @note noexcept: an allocation failure while copying @p v calls std::terminate.
     */
    Shape& operator=(const std::vector<value_type>& v) noexcept {
      data_ = v;
      return *this;
    }

    Shape(const Shape&) = default;
    Shape& operator=(const Shape&) = default;
    Shape(Shape&&) = default;
    Shape& operator=(Shape&&) = default;

    /**
     * @brief Get value of nth dimension
     *
     * @param offset serial number of dimension; must be in [0, Size()), not bounds-checked
     * @return value_type shape value
     */
    value_type operator[](int offset) const noexcept { return data_[offset]; }

    /**
     * @brief Get value of nth dimension
     *
     * @param offset serial number of dimension; must be in [0, Size()), not bounds-checked
     * @return value_type reference to shape value
     */
    value_type& operator[](int offset) noexcept { return data_[offset]; }

    /**
     * @brief Returns the dimension size of Shape
     *
     * @return std::size_t The number of dimensions stored in Shape
     */
    std::size_t Size() const noexcept { return data_.size(); }

    /**
     * @brief Returns whether Shape is empty
     *
     * @retval true if the Shape doesn't have any value
     * @retval false otherwise
     */
    bool Empty() const noexcept { return data_.empty(); }

    /**
     * @brief Get vectorized shape value
     *
     * @return std::vector<value_type> a copy of the stored shape values
     * @note noexcept: an allocation failure while copying calls std::terminate.
     */
    std::vector<value_type> Vectorize() const noexcept { return data_; }

    /**
     * @brief Get batchsize, i.e. the first dimension
     *
     * @pre The Shape must not be empty; this is not checked.
     * @return value_type batch size
     */
    value_type BatchSize() const noexcept { return data_[0]; }

    /**
     * @brief Get total data count / batch size
     *
     * Product of every dimension except the first. Returns 1 when the Shape
     * has fewer than two dimensions.
     *
     * @return Data count
     */
    std::int64_t DataCount() const noexcept {
      std::int64_t cnt = 1;
      for (std::size_t i = 1; i < data_.size(); ++i) {
        cnt *= data_[i];
      }
      return cnt;
    }

    /**
     * @brief Get total data count
     *
     * Product of every dimension. Returns 1 for an empty Shape.
     *
     * @return Total data count
     */
    std::int64_t BatchDataCount() const noexcept {
      std::int64_t cnt = 1;
      for (const value_type dim : data_) {
        cnt *= dim;
      }
      return cnt;
    }

    /**
     * @brief Print shape as "Shape (d0, d1, ...)"; an empty Shape prints "Shape ()"
     *
     * @param os Output stream
     * @param shape Shape to be printed
     * @return Output stream
     */
    friend std::ostream& operator<<(std::ostream& os, const Shape& shape) {
      os << "Shape (";
      // Emit the separator before every element except the first. This avoids
      // computing `Size() - 1`, which underflows for an empty Shape.
      const char* sep = "";
      for (const value_type dim : shape.data_) {
        os << sep << dim;
        sep = ", ";
      }
      os << ")";
      return os;
    }

    /**
     * @brief Judge whether two shapes are equal
     *
     * @param other another Shape
     * @retval true if both shapes have the same dimensions in the same order
     * @retval false otherwise
     */
    bool operator==(const Shape& other) const noexcept { return data_ == other.data_; }

    /**
     * @brief Judge whether two shapes are not equal
     *
     * @param other another Shape
     * @retval true if two shapes are not equal
     * @retval false otherwise
     */
    bool operator!=(const Shape& other) const noexcept { return !(*this == other); }

   private:
    std::vector<value_type> data_;
  };  // class Shape
  160. } // namespace infer_server
  161. #endif // INFER_SERVER_SHAPE_H_