inference.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. #include "inference.h"
  2. #define MAX_DISPLAY_LEN 64
  3. #define PGIE_CLASS_ID_VEHICLE 0
  4. #define PGIE_CLASS_ID_PERSON 2
  5. /* By default, OSD process-mode is set to CPU_MODE. To change mode, set as:
  6. * 1: GPU mode (for Tesla only)
  7. * 2: HW mode (For Jetson only)
  8. */
  9. #define OSD_PROCESS_MODE 0
  10. /* By default, OSD will not display text. To display text, change this to 1 */
  11. #define OSD_DISPLAY_TEXT 1
  12. /* The muxer output resolution must be set if the input streams will be of
  13. * different resolution. The muxer will scale all the input frames to this
  14. * resolution. */
  15. #define MUXER_OUTPUT_WIDTH 1920
  16. #define MUXER_OUTPUT_HEIGHT 1080
  17. /* Muxer batch formation timeout, for e.g. 40 millisec. Should ideally be set
  18. * based on the fastest source's framerate. */
  19. #define MUXER_BATCH_TIMEOUT_USEC 5000
  20. #define TILED_OUTPUT_WIDTH 1280
  21. #define TILED_OUTPUT_HEIGHT 720
  22. /* NVIDIA Decoder source pad memory feature. This feature signifies that source
  23. * pads having this capability will push GstBuffers containing cuda buffers. */
  24. #define GST_CAPS_FEATURES_NVMM "memory:NVMM"
  25. #define MAX_NUM_SOURCES 30
  26. gint frame_number = 0;
  27. gint g_num_sources = 0;
  28. gint g_source_id_list[MAX_NUM_SOURCES];
  29. gboolean g_eos_list[MAX_NUM_SOURCES];
  30. gboolean g_source_enabled[MAX_NUM_SOURCES];
  31. GMutex eos_lock;
  32. GstElement *g_streammux = NULL;
  33. namespace MIVA{
  34. std::shared_ptr<Inference> infer = NULL;
  35. ThreadPool pool(3,ThreadPool::PRIORITY_HIGHEST, false);
  36. std::shared_ptr<Inference> Inference::CreateNew()
  37. {
  38. if(infer == NULL) infer = std::make_shared<Inference>();
  39. return infer;
  40. }
  41. Inference::Inference()
  42. {
  43. }
  44. Inference::~Inference()
  45. {
  46. Destory();
  47. }
  48. // Init 初始化
  49. int32_t Inference::Init(vector<DataSource> DataList)
  50. {
  51. // init
  52. this->loop = g_main_loop_new (NULL, FALSE);
  53. // 创建管道
  54. this->pipeline = gst_pipeline_new("dstest3-pipeline");
  55. // 创建批处理器
  56. this->streammux = gst_element_factory_make ("nvstreammux", "stream-muxer");
  57. g_streammux = this->streammux;
  58. if(this->pipeline == NULL || this->streammux == NULL){
  59. ErrorL << "One element could not be created. Exiting.";
  60. return ERR;
  61. }
  62. gst_bin_add (GST_BIN (this->pipeline), streammux);
  63. this->m_DataList = DataList;
  64. // 创建数据源
  65. std::vector<DataSource>::iterator iter;
  66. g_num_sources = 0;
  67. for(iter = m_DataList.begin(); iter != m_DataList.end(); iter++){
  68. GstElement *source_bin = create_uridecode_bin (g_num_sources, (gchar*)((*iter).uri).c_str());
  69. if (!source_bin) {
  70. ErrorL << "Failed to create source bin. Exiting.";
  71. return ERR;
  72. }
  73. gst_bin_add(GST_BIN (this->pipeline), source_bin);
  74. iter->source_bin = source_bin;
  75. iter->Play = true;
  76. g_num_sources++;
  77. }
  78. /* Use nvinfer to infer on batched frame. */
  79. this->pgie = gst_element_factory_make("nvinfer", "primary-nvinference-engine");
  80. /* Add queue elements between every two elements */
  81. this->queue1 = gst_element_factory_make ("queue", "queue1");
  82. this->queue2 = gst_element_factory_make ("queue", "queue2");
  83. this->queue3 = gst_element_factory_make ("queue", "queue3");
  84. this->queue4 = gst_element_factory_make ("queue", "queue4");
  85. this->queue5 = gst_element_factory_make ("queue", "queue5");
  86. /* Use nvtiler to composite the batched frames into a 2D tiled array based
  87. * on the source of the frames. */
  88. this->tiler = gst_element_factory_make ("nvmultistreamtiler", "nvtiler");
  89. /* Use convertor to convert from NV12 to RGBA as required by nvosd */
  90. this->nvvidconv = gst_element_factory_make ("nvvideoconvert", "nvvideo-converter");
  91. this->nvosd = gst_element_factory_make ("nvdsosd", "nv-onscreendisplay");
  92. #ifdef PLATFORM_TEGRA
  93. this->transform = gst_element_factory_make ("nvegltransform", "nvegl-transform");
  94. #endif
  95. this->sink = gst_element_factory_make ("nveglglessink", "nvvideo-renderer");
  96. if (!this->pgie || !this->tiler || !this->nvvidconv || !this->nvosd || !this->sink) {
  97. ErrorL << "One element could not be created. Exiting.";
  98. return -1;
  99. }
  100. #ifdef PLATFORM_TEGRA
  101. if(!this->transform) {
  102. ErrorL << "One tegra element could not be created. Exiting.";
  103. return -1;
  104. }
  105. #endif
  106. g_object_set(G_OBJECT(streammux), "batch-size", g_num_sources, NULL);
  107. g_object_set (G_OBJECT (streammux), "width", MUXER_OUTPUT_WIDTH, "height",MUXER_OUTPUT_HEIGHT,
  108. "batched-push-timeout", MUXER_BATCH_TIMEOUT_USEC, NULL);
  109. /* Configure the nvinfer element using the nvinfer config file. */
  110. g_object_set (G_OBJECT (this->pgie),
  111. "config-file-path", "config_infer_primary_yoloV5.txt", NULL);
  112. /* Override the batch-size set in the config file with the number of sources. */
  113. g_object_get (G_OBJECT (this->pgie), "batch-size", &(this->pgie_batch_size), NULL);
  114. if (this->pgie_batch_size != g_num_sources) {
  115. WarnL << "WARNING: Overriding infer-config batch-size:" << this->pgie_batch_size << "with number of sources ("<< g_num_sources << ")";
  116. g_object_set (G_OBJECT (this->pgie), "batch-size", g_num_sources, NULL);
  117. }
  118. this->tiler_rows = (guint) sqrt (g_num_sources);
  119. this->tiler_columns = (guint) ceil (1.0 * g_num_sources / this->tiler_rows);
  120. /* we set the tiler properties here */
  121. g_object_set (G_OBJECT (this->tiler), "rows", this->tiler_rows, "columns", this->tiler_columns,
  122. "width", TILED_OUTPUT_WIDTH, "height", TILED_OUTPUT_HEIGHT, NULL);
  123. g_object_set (G_OBJECT (this->nvosd), "process-mode", OSD_PROCESS_MODE,
  124. "display-text", OSD_DISPLAY_TEXT, NULL);
  125. g_object_set (G_OBJECT (this->sink), "qos", 0, NULL);
  126. this->bus = gst_pipeline_get_bus (GST_PIPELINE (this->pipeline));
  127. this->bus_watch_id = gst_bus_add_watch (this->bus, bus_call, this->loop);
  128. gst_object_unref (this->bus);
  129. gst_bin_add_many (GST_BIN (this->pipeline), this->queue1, this->pgie, this->queue2, this->tiler, this->queue3,
  130. this->nvvidconv, this->queue4, this->nvosd, this->queue5, this->transform, this->sink, NULL);
  131. if (!gst_element_link_many (streammux, this->queue1, this->pgie, this->queue2, this->tiler, this->queue3,
  132. this->nvvidconv, this->queue4, this->nvosd, this->queue5, this->transform, this->sink, NULL)) {
  133. ErrorL << "Elements could not be linked. Exiting.";
  134. return -1;
  135. }
  136. this->tiler_src_pad = gst_element_get_static_pad(this->pgie, "src");
  137. if (!this->tiler_src_pad)
  138. InfoL << "Unable to get src pad";
  139. else
  140. gst_pad_add_probe (this->tiler_src_pad, GST_PAD_PROBE_TYPE_BUFFER,
  141. tiler_src_pad_buffer_probe, NULL, NULL);
  142. gst_object_unref (this->tiler_src_pad);
  143. return OK;
  144. }
  145. void Inference::ReadyTask()
  146. {
  147. InfoL << "Now ReadyTask";
  148. gst_element_set_state(this->pipeline, GST_STATE_READY);
  149. g_main_loop_run(this->loop);
  150. }
  151. // 启动任务
  152. void Inference::StartTask()
  153. {
  154. static int ret = 0;
  155. InfoL << "Now palying";
  156. if(ret != 0){
  157. this->RestartTask();
  158. }else{
  159. ret++;
  160. gst_element_set_state(this->pipeline, GST_STATE_PLAYING);
  161. }
  162. }
  163. // 暂停任务
  164. void Inference::PauseTask()
  165. {
  166. InfoL << "Now Pause";
  167. std::vector<DataSource>::iterator iter;
  168. for(iter = this->m_DataList.begin(); iter != this->m_DataList.end(); iter++){
  169. if(iter->Play){
  170. gst_element_set_state(iter->source_bin, GST_STATE_PAUSED);
  171. }
  172. }
  173. }
  174. void Inference::StopTask()
  175. {
  176. pool.async([&](){
  177. int sourceId = 0;
  178. g_mutex_lock (&eos_lock);
  179. std::vector<DataSource>::iterator iter;
  180. for (iter = this->m_DataList.begin(); iter != this->m_DataList.end(); iter++){
  181. if(iter->Play){
  182. this->stop_release_source(sourceId);
  183. iter->Play = false;
  184. }
  185. sourceId++;
  186. }
  187. g_mutex_unlock (&eos_lock);
  188. NoticeCenter::Instance().emitEvent(NOTICE_RELEASE);
  189. });
  190. pool.start();
  191. }
  192. void Inference::RestartTask()
  193. {
  194. pool.async([&](){
  195. int sourceId = 0;
  196. std::vector<DataSource>::iterator iter;
  197. for (iter = this->m_DataList.begin(); iter != this->m_DataList.end(); iter++){
  198. if(iter->Play == false){
  199. this->add_sources(sourceId, iter->uri);
  200. iter->Play = true;
  201. }
  202. sourceId++;
  203. }
  204. gst_element_set_state(this->pipeline, GST_STATE_PLAYING);
  205. });
  206. pool.start();
  207. }
  208. // 销毁对象
  209. void Inference::Destory()
  210. {
  211. InfoL << "Returned, stopping playback";
  212. gst_element_set_state(this->pipeline, GST_STATE_NULL);
  213. InfoL << "Deleting pipeline";
  214. gst_object_unref(GST_OBJECT(this->pipeline));
  215. g_source_remove(this->bus_watch_id);
  216. g_main_loop_unref(this->loop);
  217. infer = NULL;
  218. }
  219. GstPadProbeReturn
  220. Inference::tiler_src_pad_buffer_probe(GstPad * pad, GstPadProbeInfo * info, gpointer u_data)
  221. {
  222. //获取从管道中获取推理结果
  223. GstBuffer *buf = (GstBuffer *) info->data;
  224. NvDsBatchMeta *batch_meta = gst_buffer_get_nvds_batch_meta (buf);
  225. //初始化要使用的数据结构
  226. NvDsObjectMeta *obj_meta = NULL; //目标检测元数据类型变量
  227. NvDsMetaList * l_frame = NULL;
  228. NvDsMetaList * l_obj = NULL;
  229. NvDsDisplayMeta *display_meta = NULL;
  230. for (l_frame = batch_meta->frame_meta_list; l_frame != NULL;l_frame = l_frame->next) //从批量中获取某一帧图
  231. {
  232. NvDsFrameMeta *frame_meta = (NvDsFrameMeta *) (l_frame->data);
  233. int num = 0;
  234. for (l_obj = frame_meta->obj_meta_list; l_obj != NULL;l_obj = l_obj->next)
  235. {
  236. obj_meta = (NvDsObjectMeta *) (l_obj->data);
  237. if (obj_meta->class_id == 0) // Person
  238. {
  239. num++;
  240. }
  241. }
  242. //画左上角的统计信息
  243. display_meta = nvds_acquire_display_meta_from_pool(batch_meta);
  244. NvOSD_TextParams *txt_params = &display_meta->text_params[0];
  245. display_meta->num_labels = 1;
  246. txt_params->display_text = (char *)g_malloc0 (MAX_DISPLAY_LEN);
  247. snprintf(txt_params->display_text, MAX_DISPLAY_LEN, "Number of people: %d \n", num);
  248. // 推理广播
  249. NoticeCenter::Instance().emitEvent(NOTICE_INFER,frame_meta->source_id, num);
  250. txt_params->x_offset = 30;
  251. txt_params->y_offset = 30;
  252. /* Font , font-color and font-size */
  253. txt_params->font_params.font_name = (char *)"Serif";
  254. txt_params->font_params.font_size = 10;
  255. txt_params->font_params.font_color.red = 1.0;
  256. txt_params->font_params.font_color.green = 1.0;
  257. txt_params->font_params.font_color.blue = 1.0;
  258. txt_params->font_params.font_color.alpha = 1.0;
  259. /* Text background color */
  260. txt_params->set_bg_clr = 1;
  261. txt_params->text_bg_clr.red = 0.0;
  262. txt_params->text_bg_clr.green = 0.0;
  263. txt_params->text_bg_clr.blue = 0.0;
  264. txt_params->text_bg_clr.alpha = 1.0;
  265. //nvds_add_display_meta_to_frame(frame_meta, display_meta);
  266. }
  267. return GST_PAD_PROBE_OK;
  268. }
  269. gboolean Inference::bus_call (GstBus * bus, GstMessage * msg, gpointer data)
  270. {
  271. GMainLoop *loop = (GMainLoop *) data;
  272. switch (GST_MESSAGE_TYPE (msg)) {
  273. case GST_MESSAGE_EOS:
  274. InfoL << "End of stream";
  275. g_main_loop_quit (loop);
  276. break;
  277. case GST_MESSAGE_WARNING:
  278. {
  279. gchar *debug;
  280. GError *error;
  281. gst_message_parse_warning (msg, &error, &debug);
  282. WarnL << "WARNING from element " << GST_OBJECT_NAME (msg->src) << ": " << error->message;
  283. g_free (debug);
  284. ErrorL << "Warning: " << error->message;
  285. g_error_free (error);
  286. break;
  287. }
  288. case GST_MESSAGE_ERROR:
  289. {
  290. gchar *debug;
  291. GError *error;
  292. gst_message_parse_error (msg, &error, &debug);
  293. ErrorL << "ERROR from element" << GST_OBJECT_NAME (msg->src) << ":" << error->message;
  294. if (debug)
  295. ErrorL << "Error details:" << debug;
  296. g_free (debug);
  297. g_error_free (error);
  298. g_main_loop_quit (loop);
  299. break;
  300. }
  301. #ifndef PLATFORM_TEGRA
  302. case GST_MESSAGE_ELEMENT:
  303. {
  304. if (gst_nvmessage_is_stream_eos (msg)) {
  305. guint stream_id;
  306. if (gst_nvmessage_parse_stream_eos (msg, &stream_id)) {
  307. InfoL << "Got EOS from stream " << stream_id;
  308. }
  309. }
  310. break;
  311. }
  312. #endif
  313. default:
  314. break;
  315. }
  316. return TRUE;
  317. }
  318. void Inference::decodebin_child_added (GstChildProxy * child_proxy, GObject * object,
  319. gchar * name, gpointer user_data)
  320. {
  321. InfoL << "Decodebin child added: " << name;
  322. if (g_strrstr (name, "decodebin") == name) {
  323. g_signal_connect (G_OBJECT (object), "child-added",
  324. G_CALLBACK (decodebin_child_added), user_data);
  325. }
  326. }
  327. void Inference::stop_release_source (gint source_id)
  328. {
  329. GstStateChangeReturn state_return;
  330. gchar pad_name[16];
  331. GstPad *sinkpad = NULL;
  332. state_return = gst_element_set_state (m_DataList[source_id].source_bin, GST_STATE_NULL);
  333. switch (state_return) {
  334. case GST_STATE_CHANGE_SUCCESS:
  335. InfoL << "STATE CHANGE SUCCESS";
  336. g_snprintf (pad_name, 15, "sink_%u", source_id);
  337. sinkpad = gst_element_get_static_pad (streammux, pad_name);
  338. gst_pad_send_event (sinkpad, gst_event_new_flush_stop (FALSE));
  339. gst_element_release_request_pad (streammux, sinkpad);
  340. InfoL << "STATE CHANGE SUCCESS:" << source_id;
  341. gst_object_unref (sinkpad);
  342. gst_bin_remove (GST_BIN (this->pipeline), m_DataList[source_id].source_bin);
  343. break;
  344. case GST_STATE_CHANGE_FAILURE:
  345. ErrorL << "STATE CHANGE FAILURE";
  346. break;
  347. case GST_STATE_CHANGE_ASYNC:
  348. InfoL << "STATE CHANGE ASYNC";
  349. state_return =
  350. gst_element_get_state (m_DataList[source_id].source_bin, NULL, NULL,
  351. GST_CLOCK_TIME_NONE);
  352. g_snprintf (pad_name, 15, "sink_%u", source_id);
  353. sinkpad = gst_element_get_static_pad (streammux, pad_name);
  354. gst_pad_send_event (sinkpad, gst_event_new_flush_stop (FALSE));
  355. gst_element_release_request_pad (streammux, sinkpad);
  356. g_print ("STATE CHANGE ASYNC %p\n\n", sinkpad);
  357. gst_object_unref (sinkpad);
  358. gst_bin_remove (GST_BIN (this->pipeline), m_DataList[source_id].source_bin);
  359. break;
  360. case GST_STATE_CHANGE_NO_PREROLL:
  361. InfoL << "STATE CHANGE NO PREROLL";
  362. break;
  363. default:
  364. break;
  365. }
  366. }
  367. void Inference::cb_newpad (GstElement * decodebin, GstPad * pad, gpointer data)
  368. {
  369. GstCaps *caps = gst_pad_query_caps (pad, NULL);
  370. const GstStructure *str = gst_caps_get_structure (caps, 0);
  371. const gchar *name = gst_structure_get_name (str);
  372. g_print ("decodebin new pad %s\n", name);
  373. if (!strncmp (name, "video", 5)) {
  374. gint source_id = (*(gint *) data);
  375. gchar pad_name[16] = { 0 };
  376. GstPad *sinkpad = NULL;
  377. g_snprintf (pad_name, 15, "sink_%u", source_id);
  378. sinkpad = gst_element_get_request_pad (g_streammux, pad_name);
  379. if (gst_pad_link (pad, sinkpad) != GST_PAD_LINK_OK) {
  380. g_print ("Failed to link decodebin to pipeline\n");
  381. } else {
  382. g_print ("Decodebin linked to pipeline\n");
  383. }
  384. gst_object_unref (sinkpad);
  385. }
  386. }
  387. GstElement* Inference::create_uridecode_bin (guint index, gchar * filename)
  388. {
  389. GstElement *bin = NULL;
  390. gchar bin_name[16] = { };
  391. g_print ("creating uridecodebin for [%s]\n", filename);
  392. g_source_id_list[index] = index;
  393. g_snprintf (bin_name, 15, "source-bin-%02d", index);
  394. bin = gst_element_factory_make ("uridecodebin", bin_name);
  395. g_object_set (G_OBJECT (bin), "uri", filename, NULL);
  396. g_signal_connect (G_OBJECT (bin), "pad-added",
  397. G_CALLBACK (cb_newpad), &g_source_id_list[index]);
  398. g_signal_connect (G_OBJECT (bin), "child-added",
  399. G_CALLBACK (decodebin_child_added), &g_source_id_list[index]);
  400. g_source_enabled[index] = TRUE;
  401. return bin;
  402. }
  403. // 增加数据源
  404. gboolean Inference::add_sources (int source_Id, std::string uri)
  405. {
  406. g_mutex_lock (&eos_lock);
  407. GstElement *source_bin;
  408. GstStateChangeReturn state_return;
  409. InfoL << "Calling Start " << source_Id;
  410. source_bin = create_uridecode_bin (source_Id, (gchar *)uri.c_str());
  411. if (!source_bin) {
  412. ErrorL << "Failed to create source bin. Exiting.";
  413. return -1;
  414. }
  415. m_DataList[source_Id].source_bin = source_bin;
  416. gst_bin_add (GST_BIN (this->pipeline), source_bin);
  417. state_return = gst_element_set_state (source_bin, GST_STATE_PLAYING);
  418. switch (state_return) {
  419. case GST_STATE_CHANGE_SUCCESS:
  420. InfoL << "STATE CHANGE SUCCESS.";
  421. break;
  422. case GST_STATE_CHANGE_FAILURE:
  423. InfoL << "STATE CHANGE FAILURE";
  424. break;
  425. case GST_STATE_CHANGE_ASYNC:
  426. InfoL << "STATE CHANGE ASYNC";
  427. state_return = gst_element_get_state (source_bin, NULL, NULL,GST_CLOCK_TIME_NONE);
  428. break;
  429. case GST_STATE_CHANGE_NO_PREROLL:
  430. InfoL << "STATE CHANGE NO PREROLL";
  431. break;
  432. default:
  433. break;
  434. }
  435. g_mutex_unlock (&eos_lock);
  436. return TRUE;
  437. }
  438. }