/*************************************************************************
 * Copyright (C) [2019] by Cambricon, Inc. All rights reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *************************************************************************/

#ifndef EASYINFER_SHAPE_HPP_
#define EASYINFER_SHAPE_HPP_

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

namespace edk {

/**
 * @brief ShapeEx describes the shape of inference model input and output data
 * @warning Regardless of how data is laid out in memory, dimension values are
 *          always stored in NHWC order
 */
class ShapeEx {
 public:
  /// stored value type
  using value_type = int;

  /**
   * @brief Construct an empty ShapeEx object
   */
  ShapeEx() = default;

  /**
   * @brief Construct a new ShapeEx object from a shape vector
   *
   * @note The vector is taken by value and moved into place, so any allocation
   *       needed to copy an lvalue argument happens at the call site, not inside
   *       this noexcept constructor (where a throw would call std::terminate).
   * @param v vector storing the shape values
   */
  explicit ShapeEx(std::vector<value_type> v) noexcept : data_(std::move(v)) {}

  ShapeEx(const ShapeEx&) = default;
  ShapeEx& operator=(const ShapeEx&) = default;
  ShapeEx(ShapeEx&&) = default;
  ShapeEx& operator=(ShapeEx&&) = default;

  /**
   * @brief Get value of nth dimension
   *
   * @note No bounds checking is performed; offset must be in [0, Size()).
   * @param offset serial number of dimension
   * @return value_type shape value
   */
  value_type operator[](int offset) const { return data_[offset]; }

  /**
   * @brief Get value of nth dimension
   *
   * @note No bounds checking is performed; offset must be in [0, Size()).
   * @param offset serial number of dimension
   * @return value_type reference to shape value
   */
  value_type& operator[](int offset) { return data_[offset]; }

  /**
   * @brief Returns the number of dimensions of ShapeEx
   *
   * @return size_t The number of dimensions
   */
  size_t Size() const noexcept { return data_.size(); }

  /**
   * @brief Returns whether ShapeEx is empty
   *
   * @retval true if the ShapeEx doesn't have any value
   * @retval false otherwise
   */
  bool Empty() const noexcept { return data_.empty(); }

  /**
   * @brief Get vectorized shape value
   *
   * @note Returns a copy, which allocates; therefore this is not noexcept.
   * @return std::vector<value_type> vectorized shape value
   */
  std::vector<value_type> Vectorize() const { return data_; }

  /**
   * @brief Get batchsize, i.e. the first dimension
   *
   * @return value_type batch size, or 0 if the shape is empty
   */
  value_type BatchSize() const { return data_.empty() ? 0 : data_[0]; }

  /**
   * @brief Get n value
   *
   * @note Only works on a ShapeEx of 4 dimensions; returns 0 if Size() != 4.
   * @return value_type n or 0
   */
  value_type N() const noexcept { return Size() == 4 ? data_[0] : 0; }

  /**
   * @brief Get height value
   *
   * @note Only works on a ShapeEx of 4 dimensions; returns 0 if Size() != 4.
   * @return value_type height or 0
   */
  value_type H() const noexcept { return Size() == 4 ? data_[1] : 0; }

  /**
   * @brief Get width value
   *
   * @note Only works on a ShapeEx of 4 dimensions; returns 0 if Size() != 4.
   * @return value_type width or 0
   */
  value_type W() const noexcept { return Size() == 4 ? data_[2] : 0; }

  /**
   * @brief Get channel value
   *
   * @note Only works on a ShapeEx of 4 dimensions; returns 0 if Size() != 4.
   * @return value_type channel or 0
   */
  value_type C() const noexcept { return Size() == 4 ? data_[3] : 0; }

  /**
   * @brief Get total data count / batch size
   *
   * Product of every dimension except the first one. Returns 1 when the shape
   * has fewer than two dimensions.
   *
   * @return Data count
   */
  int64_t DataCount() const noexcept {
    int64_t cnt = 1;
    for (size_t i = 1; i < data_.size(); ++i) {
      cnt *= data_[i];
    }
    return cnt;
  }

  /**
   * @brief Get total data count
   *
   * Product of every dimension. Returns 1 for an empty shape.
   *
   * @return Total data count
   */
  int64_t BatchDataCount() const noexcept {
    int64_t cnt = 1;
    for (const value_type dim : data_) {
      cnt *= dim;
    }
    return cnt;
  }

  /**
   * @brief put shape into ostream, formatted as "ShapeEx (d0, d1, ...)"
   *
   * An empty shape is printed as "ShapeEx ()".
   *
   * @param os Output stream
   * @param shape ShapeEx to be printed
   * @return Output stream
   */
  friend std::ostream& operator<<(std::ostream& os, const ShapeEx& shape) {
    os << "ShapeEx (";
    // Separator is emitted before every element but the first, so no
    // `Size() - 1` arithmetic is needed (that underflowed for empty shapes).
    for (size_t i = 0; i < shape.data_.size(); ++i) {
      if (i != 0) os << ", ";
      os << shape.data_[i];
    }
    os << ")";
    return os;
  }

  /**
   * @brief Judge whether two shapes are equal
   *
   * @param lhs a ShapeEx
   * @param rhs a ShapeEx
   * @retval true if two shapes have the same dimensions with the same values
   * @retval false otherwise
   */
  friend bool operator==(const ShapeEx& lhs, const ShapeEx& rhs) noexcept { return lhs.data_ == rhs.data_; }

  /**
   * @brief Judge whether two shapes are not equal
   *
   * @param lhs a ShapeEx
   * @param rhs a ShapeEx
   * @retval true if two shapes are not equal
   * @retval false otherwise
   */
  friend bool operator!=(const ShapeEx& lhs, const ShapeEx& rhs) noexcept { return !(lhs == rhs); }

 private:
  std::vector<value_type> data_;
};  // class ShapeEx

/**
 * @brief Shape to describe inference model input and output data
 */
class Shape {
 public:
  /**
   * @brief Construct a new Shape object
   *
   * @param n data number
   * @param h height
   * @param w width
   * @param c channel
   * @param stride aligned width
   */
  explicit Shape(uint32_t n = 1, uint32_t h = 1, uint32_t w = 1, uint32_t c = 1, uint32_t stride = 1);

  /**
   * @brief Get stride, which is aligned width
   *
   * @return Stride, the larger of the stored stride and w
   */
  uint32_t Stride() const { return w > stride_ ? w : stride_; }

  /**
   * @brief Set the stride
   *
   * @param s Stride
   */
  void SetStride(uint32_t s) { stride_ = s; }

  /**
   * @brief Get Step, row length, equals to stride multiply c
   *
   * @note Computed in 64-bit arithmetic so the product cannot wrap at 32 bits.
   * @return Step
   */
  uint64_t Step() const { return static_cast<uint64_t>(Stride()) * c; }

  /**
   * @brief Get total data count, equal to memory size
   *
   * @return Data count, n * h * Step()
   */
  uint64_t DataCount() const { return static_cast<uint64_t>(n) * h * Step(); }

  /**
   * @brief Get n * h * w * c, which is unaligned data size
   *
   * @note Computed in 64-bit arithmetic so the product cannot wrap at 32 bits.
   * @return nhwc
   */
  uint64_t nhwc() const { return static_cast<uint64_t>(n) * h * w * c; }

  /**
   * @brief Get h * w * c, which is size of one data part
   *
   * @note Computed in 64-bit arithmetic so the product cannot wrap at 32 bits.
   * @return hwc
   */
  uint64_t hwc() const { return static_cast<uint64_t>(h) * w * c; }

  /**
   * @brief Get h * w, which is size of one channel in one data part
   *
   * @note Computed in 64-bit arithmetic so the product cannot wrap at 32 bits.
   * @return hw
   */
  uint64_t hw() const { return static_cast<uint64_t>(h) * w; }

  /**
   * @brief Print shape
   *
   * @param os Output stream
   * @param shape Shape to be printed
   * @return Output stream
   */
  friend std::ostream &operator<<(std::ostream &os, const Shape &shape);

  /**
   * @brief Judge whether two shapes are equal
   *
   * @param other Another shape
   * @return Return true if two shapes are equal
   */
  bool operator==(const Shape &other) const;

  /**
   * @brief Judge whether two shapes are not equal
   *
   * @param other Another shape
   * @return Return true if two shapes are not equal
   */
  bool operator!=(const Shape &other) const;

  /// data number
  uint32_t n;
  /// height
  uint32_t h;
  /// width
  uint32_t w;
  /// channel
  uint32_t c;

 private:
  uint32_t stride_;
};  // class Shape

}  // namespace edk

#endif  // EASYINFER_SHAPE_HPP_