/************************************************************************* * Copyright (C) [2019] by Cambricon, Inc. All rights reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. *************************************************************************/ #ifndef INFER_SERVER_SHAPE_H_ #define INFER_SERVER_SHAPE_H_ #include #include namespace infer_server { /** * @brief Shape to describe inference model input and output data * @warning No matter how data is placed in memory, dim values always keep in order of NHWC */ class Shape { public: /// stored value type using value_type = int64_t; /** * @brief Construct a new Shape object */ Shape() = default; /** * @brief Construct a new Shape object from shape vector * * @param v vector stored shape value */ explicit Shape(const std::vector& v) noexcept { data_ = v; } /** * @brief Copy assign a Shape object from shape vector * * @param v vector stored shape value */ Shape& operator=(const std::vector& v) noexcept { data_ = v; return *this; } Shape(const Shape&) = default; Shape& operator=(const Shape&) = default; Shape(Shape&&) = default; Shape& operator=(Shape&&) = default; /** * @brief Get value of nth dimension * * @param offset serial number of dimension * @return value_type shape value */ value_type operator[](int offset) const noexcept { return data_[offset]; } /** * @brief Get value of nth dimension * * @param offset serial number of dimension * @return value_type reference to shape value */ value_type& operator[](int offset) noexcept { return data_[offset]; } /** * @brief Returns the dimension size of Shape * * @return size_t The dimension size of Shape */ size_t Size() const noexcept { return data_.size(); }; /** * @brief Returns whether Shape is empty * * @retval true if the Shape doesn't have any value * @retval false otherwise */ bool Empty() const noexcept { return data_.empty(); }; /** * @brief Get vectorized shape value * * @return std::vector vectorized shape value */ std::vector Vectorize() const noexcept { return data_; } /** * @brief Get batchsize * * @return value_type batch size */ value_type BatchSize() const noexcept { return data_[0]; } /** * @brief Get total data count / batch size * * @return Data count */ int64_t DataCount() const noexcept { int64_t cnt = 1; for (size_t i = 1; i < data_.size(); ++i) { cnt *= data_[i]; } return cnt; } /** * @brief Get total data count * * @return Total data count */ int64_t BatchDataCount() const noexcept { int64_t cnt = 1; for (size_t i = 0; i < data_.size(); ++i) { cnt *= data_[i]; } return cnt; } /** * @brief Print shape * * @param os Output stream * @param shape Shape to be printed * @return Output stream */ friend std::ostream& operator<<(std::ostream& os, const Shape& shape) { os << "Shape ("; for (size_t i = 0; i < shape.Size() - 1; ++i) { os << shape[i] << ", "; } if (shape.Size() > 0) os << shape[shape.Size() - 1]; os << ")"; return os; } /** * @brief Judge whether two shapes are equal * * @param other another Shape * @retval true if two shapes are equal * @retval false otherwise */ bool operator==(const Shape& other) const noexcept { if (Size() != other.Size()) return false; for (size_t i = 0; i < Size(); ++i) { if (data_[i] != other[i]) return false; } return true; } /** * @brief Judge whether two shapes are not equal * * @param other another Shape * @retval true if two shapes are not equal * @retval false otherwise */ bool operator!=(const Shape& other) const noexcept { return !(*this == other); } private: std::vector data_; }; // class Shape } // namespace infer_server #endif // INFER_SERVER_SHAPE_H_