Jelajahi Sumber

add multi-thread for detect keypoint and person attribute

MaochengHu 2 tahun lalu
induk
melakukan
8e6e8062a2
1 mengubah file dengan 390 tambahan dan 0 penghapusan
  1. 390 0
      dev/src/run_multi_thread.py

+ 390 - 0
dev/src/run_multi_thread.py

@@ -0,0 +1,390 @@
+# -*- coding: utf-8 -*-
+# @Time : 2022/6/24 14:26
+# @Author : MaochengHu
+# @Email : wojiaohumaocheng@gmail.com
+# @File : run_multi_thread.py
+# @Project : person_monitor
+import threading
+
+from all_packages import *
+
+# Absolute root directory of the person_monitor project; parse_args() joins it
+# with "dev/configs/config.py" to build the default --config_path.
+# NOTE(review): hard-coded to one machine's layout - adjust per deployment.
+PRO_DIR = "/data2/humaocheng/person_monitor"
+
+
+def parse_args():
+    """
+    Parse the command-line options of the video monitor.
+
+    The only option is --config_path; when omitted it falls back to
+    "dev/configs/config.py" under PRO_DIR.
+    :return: the argparse namespace holding config_path
+    """
+    default_config_path = os.path.join(PRO_DIR, "dev/configs/config.py")
+    arg_parser = argparse.ArgumentParser(description="video monitor")
+    arg_parser.add_argument(
+        "--config_path",
+        default=default_config_path,
+        help="config_files file",
+    )
+    return arg_parser.parse_args()
+
+
+class MyThread(threading.Thread):
+    """
+    Thread that calls ``target(*args)`` and keeps the return value.
+
+    The call happens in run(), i.e. on the worker thread after start(), not in
+    the constructor. Call join() before get_result() to be sure the call has
+    finished.
+    """
+
+    def __init__(self, target, args, name=''):
+        """
+        :param target: callable executed by the thread
+        :param args: positional arguments passed to target
+        :param name: optional thread name; an empty name lets threading pick
+                     its default one
+        """
+        threading.Thread.__init__(self, name=name or None)
+        self.target = target
+        self.args = args
+        self.result = None  # filled in by run()
+        self._error = None  # exception raised by target, if any
+
+    def run(self):
+        """Execute the target and remember either its result or its exception."""
+        try:
+            self.result = self.target(*self.args)
+        except Exception as error:
+            # keep the failure so get_result() can surface it to the caller
+            # instead of it being lost on the worker thread
+            self._error = error
+
+    def get_result(self):
+        """
+        :return: the value returned by target, or None if run() has not
+                 completed yet
+        :raises Exception: re-raises whatever target raised
+        """
+        if self._error is not None:
+            raise self._error
+        return self.result
+
+
+
+class Monitor(object):
+    """
+    Video monitor for person behaviour. It currently supports these functions:
+    1. Tracking Algorithm
+        [1] person number count
+        [2] person tracking
+        [3] person move direction
+    2. Object Detection
+        [4] helmet detection
+        [5] sleep detection
+        [6] play/call phone detection
+        [7] person gathering
+    3. Action Recognition Algorithm
+        [8] fall down action recognition
+        [9] jump action recognition
+        [10] walk action recognition
+        [11] run action recognition
+        [12] stand action recognition
+    """
+
+    def __init__(self, opt):
+        """
+        Load the config file and build every predictor and helper used by run().
+
+        :param opt: parsed command-line options; opt.config_path is read and
+                    vars(opt) is merged into the loaded config
+        """
+        # load config file
+        self.opt = opt
+        self.config_path = opt.config_path
+        self.cfg = Config(self.config_path)
+        self.cfg.config_from_dict(vars(self.opt))
+        self.cfg.show_info()
+
+        # init predictor
+        self.kp_predictor = None
+        self.od_predictor = None
+        self.action_predictor = None
+
+        # init object detection predictor model
+        self.od_model_cfg = self.cfg.get("object_detection_model_config")
+        self.max_det = self.od_model_cfg["max_det"]
+        self.od_predictor = YoloV5Predictor(
+            data=self.od_model_cfg.get("data"),
+            pt_weigths=self.od_model_cfg["pt_weights"],
+            imgsz=self.od_model_cfg["imgsz"],
+            device=self.od_model_cfg["device"],
+            confthre=self.od_model_cfg["confthre"],
+            nmsthre=self.od_model_cfg["nmsthre"],
+            max_det=self.max_det
+        )
+        self.min_box_area = self.cfg.get("min_box_area")
+
+        # init person attribute detection predictor model
+        self.pa_model_cfg = self.cfg.get("person_attribute_model_config")
+        self.pa_predictor = YoloV5Predictor(
+            data=self.pa_model_cfg.get("data"),
+            pt_weigths=self.pa_model_cfg["pt_weights"],
+            imgsz=self.pa_model_cfg["imgsz"],
+            device=self.pa_model_cfg["device"],
+            confthre=self.pa_model_cfg["confthre"],
+            nmsthre=self.pa_model_cfg["nmsthre"]
+        )
+        self.attr_names = self.pa_predictor.opt.get("names")
+        # map attribute index -> attribute name
+        self.attr_map = dict(enumerate(self.attr_names))
+        self.person_attr = self.cfg.get("person_attr")
+
+        # init tracker
+        self.timer = Timer()
+        self.tracker_frame_rate = self.cfg.get("tracker_frame_rate")
+        self.track_model_cfg = self.cfg.get("tracker_model_config")
+        self.tracker_args = self.cfg.get("tracker_args")
+        self.tracker = BYTETracker(args=self.tracker_args,
+                                   frame_rate=self.tracker_frame_rate)
+        self.output_side_size = self.cfg.get("output_side_size")
+        self.max_time_lost = self.tracker.max_time_lost
+        self.tracker_line_size = self.cfg.get("tracker_line_size")
+        self.tracker_max_id = self.cfg.get("tracker_max_id")
+
+        # init pose predictor model
+        # NOTE(review): any other pose_model_platform value leaves kp_predictor as None
+        self.pose_name = self.cfg.get("pose_name")
+        self.pose_model_platform = self.cfg.get("pose_model_platform")
+        self.keypoint_model_config = self.cfg.get("keypoint_model_config")
+        if self.pose_model_platform == "paddle":
+            self.kp_predictor = PaddlePosePredictor(
+                model_dir=self.keypoint_model_config.get("model_dir"),
+                device=self.keypoint_model_config.get("device"),
+                trt_calib_mode=self.keypoint_model_config.get("trt_calib_mode"),
+                run_mode=self.keypoint_model_config.get("run_mode"),
+                enable_mkldnn=self.keypoint_model_config.get("enable_mkldnn"),
+                batch_size=self.keypoint_model_config.get("batch_size"),
+                threshold=self.keypoint_model_config.get("threshold")
+            )
+        elif self.pose_model_platform == "mmpose":
+            self.kp_predictor = MMPosePredictor(
+                pose_config=self.keypoint_model_config.get("model_config_path"),
+                deploy_config=self.keypoint_model_config.get("deploy_config"),
+                checkpoint=self.keypoint_model_config.get("checkpoint"),
+                device=self.keypoint_model_config.get("device")
+            )
+
+        # init action predictor model
+        self.action_model_config = self.cfg.get("action_model_config")
+        self.save_kp_npy = self.action_model_config.get("save_kp_npy")
+        self.npy_output_dir = self.action_model_config.get("npy_output_dir")
+        # keypoint dicts are only collected when they are going to be saved
+        self.pkl_list = [] if self.save_kp_npy else None
+        self.action_predictor = SkeletonActionPredictor(
+            model_config_path=self.action_model_config.get("model_config_path"),
+            label_path=self.action_model_config.get("action_label"),
+            checkpoint=self.action_model_config.get("checkpoint"),
+            device=self.action_model_config.get("device"),
+            item_max_size=self.action_model_config.get("item_max_size"),
+            save_kp_npy=self.save_kp_npy,
+            npy_outptut_dir=self.npy_output_dir,
+            dataset_format=self.action_model_config.get("dataset_format")
+        )
+        self.crop = Crop()
+
+        # init cluster predictor model
+        self.eps = self.cfg.get("eps")
+        self.min_samples = self.cfg.get("min_samples")
+        self.cluster_predictor = ClusterPredictor(eps=self.eps, min_samples=self.min_samples)
+
+        # init recorder: frames are cropped only when keypoints are not used
+        self.use_keypoint = self.cfg.get("use_keypoint")
+        self.crop_frame = not self.use_keypoint
+        self.recoder = Recoder(output_side_size=self.output_side_size, item_max_size=self.cfg.get("item_max_size"),
+                               crop_frame=self.crop_frame, tracker_line_size=self.tracker_line_size,
+                               max_det=self.max_det, tracker_max_id=self.tracker_max_id)
+
+        # init input
+        self.input_source = self.cfg.get("input_source")
+        # init opencv video
+        self.save_result = self.cfg.get("save_result")
+        self.opencv_component = OpencvComponent(input_source=self.input_source, save_video=self.save_result)
+
+        # limited area warning
+        self.limited_area = self.cfg.get("limited_area")
+        self.limited_area_predictor = LimitedAreaPredictor(self.limited_area)
+
+        # init visualize
+        self.save_id_times = dict()
+        self.save_id_action_dict = dict()
+        self.show_result = self.cfg.get("show_result")
+        self.show_config = self.cfg.get("show_config")
+        self.draw_point_num = self.show_config.get("draw_point_num")  # length of the tracking line to draw
+        self.visualize = Visualize(id_max=self.tracker_max_id, person_attr=self.person_attr,
+                                   attr_map=self.attr_map, draw_point_num=self.draw_point_num)
+        self.kps_threshold = self.show_config.get("kps_threshold")
+
+    def _predict_attrs_and_keypoints(self, frame, online_person_xyxy, online_ids):
+        """
+        Run person-attribute and (optionally) keypoint prediction through MyThread.
+
+        :param frame: current video frame
+        :param online_person_xyxy: boxes of the tracked persons kept for this frame
+        :param online_ids: track ids matching online_person_xyxy
+        :return: (attrs, online_kps); either is an empty list when its stage
+                 did not run (no person in the frame, or use_keypoint disabled)
+        """
+        attrs = []
+        online_kps = []
+        if not online_person_xyxy:
+            return attrs, online_kps
+
+        # The Second stage: attribute prediction(can detect person stage and helmet state)
+        padding_crop_images, resize_info_list = self.crop.yolo_input_resize(frame, online_person_xyxy,
+                                                                            online_ids)
+        # NOTE(review): positional args presumably map to (image, scaleFill,
+        # resize_info_list) - confirm against YoloV5Predictor.predict
+        attr_thread = MyThread(target=self.pa_predictor.predict,
+                               args=(padding_crop_images, True, resize_info_list))
+
+        # pose recognition prediction
+        kp_thread = None
+        if self.use_keypoint:
+            kp_thread = MyThread(target=self.kp_predictor.predict, args=(frame, online_person_xyxy))
+
+        workers = [worker for worker in (attr_thread, kp_thread) if worker is not None]
+        for worker in workers:
+            worker.start()
+        for worker in workers:
+            worker.join()
+
+        # read results through the named threads, never by list position, so a
+        # stage that did not run cannot cause an IndexError
+        attrs, _, _ = attr_thread.get_result()
+        if kp_thread is not None:
+            online_kps = kp_thread.get_result()
+        return attrs, online_kps
+
+    def _resolve_actions(self, online_ids, raw_actions):
+        """
+        Build the per-id action labels to display, reusing cached labels for a while.
+
+        :param online_ids: track ids of the current frame
+        :param raw_actions: action (or None) per id, aligned with online_ids
+        :return: list of actions aligned with online_ids
+        """
+        # NOTE(review): the first None seen for an id resets its cached label to
+        # None, and the counter is never reset when a new label arrives - confirm
+        # this is the intended display behaviour
+        actions = []
+        for track_id, action in zip(online_ids, raw_actions):
+            if action is None:
+                if track_id not in self.save_id_times:
+                    self.save_id_times[track_id] = 1
+                    self.save_id_action_dict[track_id] = None
+                else:
+                    # keep the label for some time; once the counter exceeds
+                    # three times the tracker lost time it is dropped
+                    if self.save_id_times[track_id] > self.max_time_lost * 3:
+                        self.save_id_times.pop(track_id)
+                        self.save_id_action_dict.pop(track_id)
+                        actions.append(None)
+                        continue
+                    else:
+                        self.save_id_times[track_id] += 1
+                        if track_id in self.save_id_action_dict:
+                            action = self.save_id_action_dict[track_id]
+            else:
+                self.save_id_action_dict[track_id] = action
+            actions.append(action)
+        return actions
+
+    def _finalize(self):
+        """Release the video writer and dump the collected keypoints, when enabled."""
+        if self.save_result:
+            self.opencv_component.video_writer.release()
+
+        if self.save_kp_npy:
+            # the file is named result.npy but its content is written with pickle
+            with open("{}/{}".format(self.npy_output_dir, "result.npy"), 'wb') as pkl_file:
+                pickle.dump(self.pkl_list, pkl_file)
+
+    def run(self):
+        """
+        Run the video monitor until the input ends or 'q' is pressed.
+
+        Per frame: person detection, tracking, clustering / limited-area check,
+        attribute and keypoint prediction, recording, action recognition and
+        optional visualisation. Both ways of leaving the loop go through
+        _finalize().
+        :return:
+        """
+        frame_id = 0
+        cap = self.opencv_component.cap
+        while True:
+            ret, frame = cap.read()
+            frame_id += 1
+            if not ret:
+                break
+
+            online_person_xyxy = []
+            online_ids = []
+            tracker_center_dict = None
+            cluster_bbox = None
+            limited_area_bool = False
+            self.timer.tic()
+
+            # The First stage: object detection prediction, the first stage for person detection
+            person_coor, frame_info, img_shape = self.od_predictor.predict(frame)
+
+            # tracker prediction
+            online_targets = self.tracker.update(person_coor, [frame_info['height'], frame_info['width']],
+                                                 img_shape[1:])
+
+            for target in online_targets:
+                ltwh = target.tlwh
+                # ltwh[2] / ltwh[3] is presumably width / height - confirm against
+                # the tracker; boxes above the 1.6 ratio or not above
+                # min_box_area are skipped
+                vertical = ltwh[2] / ltwh[3] > 1.6
+                if ltwh[2] * ltwh[3] > self.min_box_area and not vertical:
+                    online_person_xyxy.append(bbox_xywh2xyxy(ltwh))
+                    online_ids.append(target.track_id)
+
+            if len(online_person_xyxy) > 0:
+                tracker_center_dict = boxes_center_coor(online_person_xyxy, online_ids)  # get bboxes center coordinate
+                # doing dbscan cluster by bounding box center coordinate
+                cluster_bbox = self.cluster_predictor.predict(online_person_xyxy=online_person_xyxy,
+                                                              tracker_center_dict=tracker_center_dict)
+                limited_area_bool = self.limited_area_predictor.predict(tracker_center_dict)
+
+            attrs, online_kps = self._predict_attrs_and_keypoints(frame, online_person_xyxy, online_ids)
+
+            # align the predicted attributes with online_ids (None when an id has none)
+            id_attrs = {attr.get("id"): attr for attr in attrs}
+            attrs_list = [id_attrs.get(online_id) for online_id in online_ids]
+
+            # record
+            state, return_items = self.recoder.update(frame=frame,
+                                                      ids=online_ids,
+                                                      person_bboxes=online_person_xyxy,
+                                                      kps=online_kps,
+                                                      attrs=attrs_list,
+                                                      trackers=tracker_center_dict
+                                                      )
+
+            if state:  # state True -> start action recognition
+                # action recognition prediction
+                action_results, kp_dict_list = self.action_predictor.predict(return_items, frame_id)
+
+                for action_result in action_results:
+                    self.recoder.saver[action_result.id].action = action_result
+                    self.recoder.saver[action_result.id].frame.delete()
+
+                if self.pkl_list is not None and self.save_kp_npy:
+                    self.pkl_list.extend(kp_dict_list)
+
+            # get action label
+            _actions = self.action_predictor.get_actions(recoder=self.recoder, ids=online_ids)
+
+            # delete old action label
+            actions = self._resolve_actions(online_ids, _actions)
+
+            for return_item in return_items:
+                self.recoder.saver.pop(return_item.id)
+
+            self.timer.toc()
+            # clamp to avoid a division by zero when the measured time is 0
+            fps = 1. / max(1e-5, self.timer.average_time)
+
+            if self.show_result:
+                result = self.visualize.plot_result(frame=frame,
+                                                    frame_id=frame_id,
+                                                    fps=fps,
+                                                    attrs=attrs,
+                                                    ids=online_ids,
+                                                    person_bboxes=online_person_xyxy,
+                                                    kps=online_kps,
+                                                    actions=actions,
+                                                    kps_threshold=self.kps_threshold,
+                                                    return_image=True,
+                                                    tracker_center_dict_list=self.recoder.tracker_queue.tracker_list,
+                                                    tracker_direction=self.recoder.tracker_queue.direction,
+                                                    cluster_bbox=cluster_bbox,
+                                                    limited_area=self.limited_area,
+                                                    limited_area_bool=limited_area_bool
+                                                    )
+
+                if self.save_result:
+                    self.opencv_component.video_writer.write(result)
+
+            k = cv2.waitKey(1)
+            if k & 0xff == ord('q'):
+                break
+
+        # reached on end of stream and on 'q', so the writer is always released
+        self._finalize()
+
+
+def main():
+    """
+    Script entry point: parse the command-line options, build a Monitor from
+    them and run it.
+    """
+    Monitor(parse_args()).run()
+
+
+# Start the monitor only when this file is executed as a script.
+if __name__ == "__main__":
+    main()