# -*- coding: utf-8 -*-
# @Time : 2022/6/24 15:18
# @Author : MaochengHu
# @Email : wojiaohumaocheng@gmail.com
# @File : config.py
# @Project : person_monitor
import os
import argparse
# -------------------------- General configuration ------------------------------- #
# Project root path
pro_root = "/data/fengyang/sunwin/code/person_monitor"
# Dependency root path
dependence_root = os.path.join(pro_root, "dependence")
# Model root path
model_root = os.path.join(pro_root, "dev/src/algorithms/save_models")
# Test input (either a local video file or an online video stream URL)
# input_source = "rtsp://admin:sunwin2019@192.168.20.240:554/h265/ch1/main/av_stream"
input_source = "/data/fengyang/sunwin/code/person_monitor/dev/test/demo.mp4"
# GPU index to use
cuda_index = 0
item_max_size = 60  # number of frames to keep for action recognition
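# For illustration only: a fixed-length buffer such as collections.deque(maxlen=item_max_size)
# would drop the oldest frame automatically once full; the actual frame buffer is assumed to be
# implemented in the action recognition code, not in this config.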
# Save results
save_result = True  # whether to save the recognition results
use_keypoint = True  # whether to recognize actions via keypoints; if False, actions are recognized directly from video crops
# ---------------------- Person object detection configuration -------------------------------- #
# YOLOv7 model loading parameters
object_detection_model_config = dict(
    pt_weights=os.path.join(model_root, "object_detection_model", "helmet_fall_phone.engine"),  # object detection weights path
    data=os.path.join(model_root, "object_detection_model", "second_stage_bbox_helmet_fall_phone.yaml"),  # class file for object detection
    imgsz=(640, 640),  # input image size; different YOLO models require different input sizes (should not be modified)
    device=cuda_index,  # GPU index
    confthre=0.001,  # kept very low for tracking, so that lost targets can be re-detected and recovered
    nmsthre=0.7,  # NMS threshold
    max_det=20  # maximum number of detected persons
)
# class_list = ["play_phone", "call_phone", "sleep", "work", "no_helmet", "helmet"]
person_end_index = 3
person_attr = ["play_phone", "call_phone", "sleep", "work"]
match_score_threshold = 0.8
# ----------------------- Person attribute detection configuration ---------------------------- #
# YOLOv7 model loading parameters
person_attribute_model_config = dict(
    pt_weights=os.path.join(model_root, "object_detection_model", "helmet_fall_phone.engine"),  # object detection weights path
    data=os.path.join(model_root, "object_detection_model", "second_stage_bbox_helmet_fall_phone.yaml"),  # class file for object detection
    imgsz=(640, 640),  # input image size; different YOLO models require different input sizes (should not be modified)
    device=cuda_index,  # GPU index
    confthre=0.3,  # object confidence threshold
    nmsthre=0.2,  # NMS threshold
)
# Person state classes
person_class_list = ["play_phone", "call_phone", "sleep", "work"]
# Helmet-wearing state classes
helmet_class_list = ["no_helmet", "helmet"]
# ----------------------- Person tracking configuration ---------------------------- #
tracker_max_id = 100  # recommended to be at least 5x max_det
tracker_model_config = dict(
    track_thresh=0.5,  # confidence threshold for tracked persons
    track_buffer=30,  # number of frames after which a lost track is no longer recovered
    match_thresh=0.8,  # similarity threshold above which a detection counts as matched
    mot20=False,  # whether to use MOT20-style matching
    tracker_max_id=tracker_max_id,  # maximum number of tracked persons; ID counting restarts once exceeded
)
# Build the tracker argument parser
tracker_parser = argparse.ArgumentParser()
for k, v in tracker_model_config.items():
    tracker_parser.add_argument("--{}".format(k), default=v)
tracker_args = tracker_parser.parse_args([])  # parse an empty list so defaults are used and the caller's CLI args are not consumed
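# Note: the parser loop above only converts the dict into an attribute-style namespace;
# an equivalent one-liner sketch, using only the standard library, would be:
#   tracker_args = argparse.Namespace(**tracker_model_config)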
tracker_frame_rate = 30  # fps of the tracked video
min_box_area = 10  # boxes with sides smaller than this are skipped for recognition
output_side_size = 640  # output image size for padded crops when recognizing directly from video rather than from keypoints
tracker_line_size = 90  # length of the drawn person trajectory line
# ----------------------- Keypoint model configuration ---------------------------- #
pose_name = "tiny_pose"
pose_model_platform = "paddle"  # currently only the paddle (PaddlePaddle) and mmpose platforms are supported
pose_trt = True  # whether to use TensorRT acceleration
if pose_model_platform == "paddle":
    if pose_trt:
        run_mode = "trt_fp32"
    else:
        run_mode = "paddle"
    keypoint_model_config = dict(model_dir=os.path.join(model_root, "pose_model/tinypose_256x192"),
                                 device="gpu:{}".format(cuda_index),
                                 trt_calib_mode=True,
                                 run_mode=run_mode,
                                 enable_mkldnn=True,
                                 batch_size=8,
                                 threshold=0.5
                                 )
elif pose_model_platform == "mmpose":
    if pose_trt:
        keypoint_model_config = dict(model_config_path=os.path.join(model_root,
                                                                    "mspn50_coco_256x192_topdown_heatmap/mspn50_coco_256x192.py"),
                                     deploy_config_path=os.path.join(model_root,
                                                                     "mspn50_coco_256x192_topdown_heatmap/pose-detection_tensorrt_static-256x192.py"),
                                     device="cuda:{}".format(cuda_index),
                                     checkpoint=[os.path.join(model_root,
                                                              "mspn50_coco_256x192_topdown_heatmap/end2end.engine")]
                                     )
    else:
        keypoint_model_config = dict(model_config_path=os.path.join(model_root,
                                                                    "mspn50_coco_256x192_topdown_heatmap/mspn50_coco_256x192.py"),
                                     device="cuda:{}".format(cuda_index),
                                     checkpoint=os.path.join(model_root,
                                                             "mspn50_coco_256x192_topdown_heatmap/mspn50_coco_256x192-8fbfb5d0_20201123.pth")
                                     )
# ------------------------- Action recognition model configuration ------------------------ #
if use_keypoint:
    action_config_root = os.path.join(model_root, "action_model/stgcn_80e_ntu60_xsub_keypoint")
    save_kp_npy = False  # whether to save collected skeleton keypoints and render the corresponding keypoint video
    dataset_format = 'TopDownCocoDataset'
    class_name = "fall"  # action class to collect skeleton keypoints for (e.g. run, jump, etc.)
    npy_output_dir = os.path.join(pro_root, "test_npy/{}".format(class_name))
    if save_kp_npy:
        if not os.path.exists(npy_output_dir):
            os.makedirs(npy_output_dir)
    action_model_config = dict(
        model_config_path=os.path.join(action_config_root, "stgcn_80e_ntu60_xsub_keypoint_customer.py"),
        checkpoint=os.path.join(action_config_root, "best_top1_acc_epoch_26.pth"),
        action_label=os.path.join(pro_root, "dev/configs/customer_action.txt"),
        device="cuda:{}".format(cuda_index),
        item_max_size=item_max_size,  # number of frames to keep for action recognition
        save_kp_npy=save_kp_npy,
        dataset_format=dataset_format,
        npy_output_dir=npy_output_dir
    )
# ----------------------- Crowd gathering detection configuration ---------------------------- #
eps = 100  # maximum distance between persons for them to be clustered together
min_samples = 2  # minimum number of persons per cluster
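# Sketch of how these parameters might be consumed, assuming a DBSCAN-style density
# clustering over person box centers (the actual clustering code lives elsewhere):
#   from sklearn.cluster import DBSCAN
#   # centers: (N, 2) array of person box center coordinates in pixels
#   labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(centers)
#   # label -1 marks isolated persons; any other label is a gathering of >= min_samples persons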
# ----------------------- Intrusion detection configuration ------------------------------- #
limited_area = (800, 200, 1000, 600)  # image coordinates of the restricted area
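# Sketch of an intrusion check with this rectangle, assuming (x1, y1, x2, y2) in the usual
# top-left/bottom-right image convention and (cx, cy) as a tracked person's reference point
# (e.g. the bottom center of the bounding box):
#   x1, y1, x2, y2 = limited_area
#   intruded = (x1 <= cx <= x2) and (y1 <= cy <= y2)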
# -------------------------- Result display configuration ---------------------------- #
show_result = True  # whether to display the results
show_config = dict(kps_threshold=0.3, draw_point_num=30)  # keypoint display threshold and number of tracking points to draw