yololayer.h 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #ifndef _YOLO_LAYER_H
  2. #define _YOLO_LAYER_H
  3. #include <vector>
  4. #include <string>
  5. #include <NvInfer.h>
  6. #include "macros.h"
  /// Compile-time configuration and plain data records shared by the YOLO
  /// TensorRT plugin and its callers.
  namespace Yolo
  {
      /// Sizes YoloKernel::anchors (CHECK_COUNT * 2 floats).
      /// NOTE(review): presumably the number of anchors per detection head — confirm.
      static constexpr int CHECK_COUNT = 3;

      /// NOTE(review): its use is not visible in this header; presumably a
      /// confidence cut-off applied by the detection kernel — confirm.
      static constexpr float IGNORE_THRESH = 0.1f;

      /// Description of one detection head: its feature-map size and anchors.
      struct YoloKernel
      {
          int width;
          int height;
          // CHECK_COUNT * 2 values; presumably (w, h) pairs, one per anchor — TODO confirm.
          float anchors[CHECK_COUNT * 2];
      };

      static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
      static constexpr int CLASS_NUM = 3;

      // yolov5's input height and width must be divisible by 32; the
      // static_asserts below turn that requirement into a compile-time check.
      static constexpr int INPUT_H = 640;
      static constexpr int INPUT_W = 640;
      static_assert(INPUT_H > 0 && INPUT_H % 32 == 0,
                    "yolov5's input height must be a positive multiple of 32");
      static_assert(INPUT_W > 0 && INPUT_W % 32 == 0,
                    "yolov5's input width must be a positive multiple of 32");

      /// Number of box coordinates stored in Detection::bbox.
      static constexpr int LOCATIONS = 4;

      /// One detection result. Every member is a float and the struct is
      /// alignas(float), so it is evidently meant to be read/written as a run
      /// of LOCATIONS + 2 packed floats (presumably the plugin's flat float
      /// output buffer — confirm against the kernel).
      struct alignas(float) Detection {
          // center_x center_y w h
          float bbox[LOCATIONS];
          float conf;      // bbox_conf * cls_conf
          float class_id;  // class index, stored as a float like the other fields
      };
      // Guard the packed-float layout: any padding or an added non-float
      // member would silently break code that treats Detection as raw floats.
      static_assert(sizeof(Detection) == (LOCATIONS + 2) * sizeof(float),
                    "Detection must be exactly LOCATIONS + 2 packed floats");
  }
  29. namespace nvinfer1
  30. {
  31. class API YoloLayerPlugin : public IPluginV2IOExt
  32. {
  33. public:
  34. YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
  35. YoloLayerPlugin(const void* data, size_t length);
  36. ~YoloLayerPlugin();
  37. int getNbOutputs() const TRT_NOEXCEPT override
  38. {
  39. return 1;
  40. }
  41. Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
  42. int initialize() TRT_NOEXCEPT override;
  43. virtual void terminate() TRT_NOEXCEPT override {};
  44. virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
  45. virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
  46. virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
  47. virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
  48. bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
  49. return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
  50. }
  51. const char* getPluginType() const TRT_NOEXCEPT override;
  52. const char* getPluginVersion() const TRT_NOEXCEPT override;
  53. void destroy() TRT_NOEXCEPT override;
  54. IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
  55. void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
  56. const char* getPluginNamespace() const TRT_NOEXCEPT override;
  57. DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
  58. bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
  59. bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
  60. void attachToContext(
  61. cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
  62. void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
  63. void detachFromContext() TRT_NOEXCEPT override;
  64. private:
  65. void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
  66. int mThreadCount = 256;
  67. const char* mPluginNamespace;
  68. int mKernelCount;
  69. int mClassCount;
  70. int mYoloV5NetWidth;
  71. int mYoloV5NetHeight;
  72. int mMaxOutObject;
  73. std::vector<Yolo::YoloKernel> mYoloKernel;
  74. void** mAnchor;
  75. };
  76. class API YoloPluginCreator : public IPluginCreator
  77. {
  78. public:
  79. YoloPluginCreator();
  80. ~YoloPluginCreator() override = default;
  81. const char* getPluginName() const TRT_NOEXCEPT override;
  82. const char* getPluginVersion() const TRT_NOEXCEPT override;
  83. const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
  84. IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
  85. IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
  86. void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
  87. {
  88. mNamespace = libNamespace;
  89. }
  90. const char* getPluginNamespace() const TRT_NOEXCEPT override
  91. {
  92. return mNamespace.c_str();
  93. }
  94. private:
  95. std::string mNamespace;
  96. static PluginFieldCollection mFC;
  97. static std::vector<PluginField> mPluginAttributes;
  98. };
  99. REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
  100. };
  101. #endif // _YOLO_LAYER_H