123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- #ifndef _YOLO_LAYER_H
- #define _YOLO_LAYER_H
- #include <vector>
- #include <string>
- #include <NvInfer.h>
- #include "macros.h"
// Compile-time configuration and plain data types for the YOLO layer.
namespace Yolo
{
    // Number of anchor pairs per YoloKernel (sizes YoloKernel::anchors).
    static constexpr int CHECK_COUNT = 3;
    // NOTE(review): presumably the confidence cut-off below which boxes are
    // dropped -- the use site is not in this header; confirm.
    static constexpr float IGNORE_THRESH = 0.1f;

    // One feature-map description: its width/height and CHECK_COUNT anchors
    // stored as 2 floats each.
    struct YoloKernel
    {
        int width;
        int height;
        float anchors[CHECK_COUNT * 2];
    };

    // NOTE(review): presumably the upper bound on boxes written per image -- confirm.
    static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
    // NOTE(review): presumably the number of object classes of the model -- confirm.
    static constexpr int CLASS_NUM = 3;
    static constexpr int INPUT_H = 640;  // yolov5's input height and width must be divisible by 32.
    static constexpr int INPUT_W = 640;
    // Number of floats describing one box (sizes Detection::bbox).
    static constexpr int LOCATIONS = 4;

    // Enforce the documented constraint instead of relying on the comment alone.
    static_assert(INPUT_H % 32 == 0 && INPUT_W % 32 == 0,
                  "yolov5's input height and width must be divisible by 32");

    // A single detection, made only of floats and aligned like a float.
    struct alignas(float) Detection {
        //center_x center_y w h
        float bbox[LOCATIONS];
        float conf;  // bbox_conf * cls_conf
        float class_id;
    };

    // Every member is a float, so the struct must be exactly LOCATIONS + 2
    // floats with no padding.
    // NOTE(review): presumably consumers treat it as a raw float array -- confirm.
    static_assert(sizeof(Detection) == (LOCATIONS + 2) * sizeof(float),
                  "Yolo::Detection must be tightly packed floats");
}
- namespace nvinfer1
- {
- class API YoloLayerPlugin : public IPluginV2IOExt
- {
- public:
- YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
- YoloLayerPlugin(const void* data, size_t length);
- ~YoloLayerPlugin();
- int getNbOutputs() const TRT_NOEXCEPT override
- {
- return 1;
- }
- Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
- int initialize() TRT_NOEXCEPT override;
- virtual void terminate() TRT_NOEXCEPT override {};
- virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
- virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
- virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
- virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
- bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
- return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
- }
- const char* getPluginType() const TRT_NOEXCEPT override;
- const char* getPluginVersion() const TRT_NOEXCEPT override;
- void destroy() TRT_NOEXCEPT override;
- IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
- void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
- const char* getPluginNamespace() const TRT_NOEXCEPT override;
- DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
- bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
- bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
- void attachToContext(
- cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
- void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
- void detachFromContext() TRT_NOEXCEPT override;
- private:
- void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
- int mThreadCount = 256;
- const char* mPluginNamespace;
- int mKernelCount;
- int mClassCount;
- int mYoloV5NetWidth;
- int mYoloV5NetHeight;
- int mMaxOutObject;
- std::vector<Yolo::YoloKernel> mYoloKernel;
- void** mAnchor;
- };
- class API YoloPluginCreator : public IPluginCreator
- {
- public:
- YoloPluginCreator();
- ~YoloPluginCreator() override = default;
- const char* getPluginName() const TRT_NOEXCEPT override;
- const char* getPluginVersion() const TRT_NOEXCEPT override;
- const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
- IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
- IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
- void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
- {
- mNamespace = libNamespace;
- }
- const char* getPluginNamespace() const TRT_NOEXCEPT override
- {
- return mNamespace.c_str();
- }
- private:
- std::string mNamespace;
- static PluginFieldCollection mFC;
- static std::vector<PluginField> mPluginAttributes;
- };
- REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
- };
- #endif // _YOLO_LAYER_H
|