/************************************************************************* * Copyright (C) [2019] by Cambricon, Inc. All rights reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. *************************************************************************/ #ifndef EASYINFER_MLU_TASK_QUEUE_H_ #define EASYINFER_MLU_TASK_QUEUE_H_ #include #include #include #include #include #include #include "device/mlu_context.h" namespace edk { #define CALL_CNRT_FUNC(func, msg) \ do { \ cnrtRet_t ret = (func); \ if (CNRT_RET_SUCCESS != ret) { \ THROW_EXCEPTION(Exception::INTERNAL, std::string(msg) + ", cnrt error code : " + std::to_string(ret)); \ } \ } while (0) class TimeMark { public: TimeMark() { CALL_CNRT_FUNC(cnrtCreateNotifier(&base_), "Create notifier failed"); } TimeMark(TimeMark&& other) : base_(other.base_) { other.base_ = nullptr; } TimeMark& operator=(TimeMark&& other) { base_ = other.base_; other.base_ = nullptr; return *this; } ~TimeMark() { if (nullptr != base_) cnrtDestroyNotifier(&base_); } void Mark(cnrtQueue_t queue) { CALL_CNRT_FUNC(cnrtPlaceNotifier(base_, queue), "cnrtPlaceNotifier failed"); } void Mark(MluTaskQueue_t queue); cnrtNotifier_t GetNotifier() noexcept { return base_; } // get hardware time in ms static float Count(const TimeMark& start, const TimeMark& end) { float dura; CALL_CNRT_FUNC(cnrtNotifierDuration(start.base_, end.base_, &dura), "Calculate elapsed time failed."); dura /= 1000; return dura; } private: TimeMark(const TimeMark&) = delete; TimeMark& operator=(const TimeMark&) = delete; cnrtNotifier_t base_{nullptr}; }; struct MluTaskQueuePrivate { ~MluTaskQueuePrivate(); cnrtQueue_t queue = nullptr; std::vector marks; std::vector marks_valid; }; inline void MluTaskQueue::_PrivDelete::operator()(MluTaskQueuePrivate* p) { delete p; } class MluTaskQueueProxy { public: static cnrtQueue_t GetCnrtQueue(MluTaskQueue_t q) noexcept { return q->priv_->queue; } static void SetCnrtQueue(MluTaskQueue_t q, cnrtQueue_t cnrt_q) { if (q->priv_->queue) { q->priv_.reset(new MluTaskQueuePrivate); } q->priv_->queue = cnrt_q; } static MluTaskQueue_t Wrap(cnrtQueue_t cnrt_q) { auto q = std::shared_ptr(new MluTaskQueue); q->priv_->queue = cnrt_q; return q; } }; inline void TimeMark::Mark(MluTaskQueue_t queue) { Mark(MluTaskQueueProxy::GetCnrtQueue(std::move(queue))); } } // namespace edk #endif // EASYINFER_MLU_TASK_QUEUE_H_