123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
- """
- import numpy as np
- from collections import defaultdict
- from ..matching import jde_matching as matching
- from ..motion import KalmanFilter
- from .base_jde_tracker import TrackState, STrack
- from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
- from ppdet.core.workspace import register, serializable
- from ppdet.utils.logger import setup_logger
- logger = setup_logger(__name__)
- __all__ = ['JDETracker']
@register
@serializable
class JDETracker(object):
    """
    JDE tracker, support single class and multi classes.

    Args:
        use_byte (bool): Whether use ByteTracker, default False
        num_classes (int): the number of classes
        det_thresh (float): threshold of detection score
        track_buffer (int): buffer for tracker
        min_box_area (int): min box area to filter out low quality boxes
        vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
            bad results. If set <= 0 means no need to filter bboxes, usually
            set 1.6 for pedestrian tracking.
        tracked_thresh (float): linear assignment threshold of tracked
            stracks and detections
        r_tracked_thresh (float): linear assignment threshold of
            tracked stracks and unmatched detections
        unconfirmed_thresh (float): linear assignment threshold of
            unconfirmed stracks and unmatched detections
        conf_thres (float): confidence threshold for tracking, also used in
            ByteTracker as higher confidence threshold
        match_thres (float): linear assignment threshold of tracked
            stracks and detections in ByteTracker
        low_conf_thres (float): lower confidence threshold for tracking in
            ByteTracker
        input_size (list): input feature map size to reid model, [h, w] format,
            [64, 192] as default.
        motion (str): motion model, KalmanFilter as default
        metric_type (str): either "euclidean" or "cosine", the distance metric
            used for measurement to track association.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 use_byte=False,
                 num_classes=1,
                 det_thresh=0.3,
                 track_buffer=30,
                 min_box_area=0,
                 vertical_ratio=0,
                 tracked_thresh=0.7,
                 r_tracked_thresh=0.5,
                 unconfirmed_thresh=0.7,
                 conf_thres=0,
                 match_thres=0.8,
                 low_conf_thres=0.2,
                 input_size=[64, 192],
                 motion='KalmanFilter',
                 metric_type='euclidean'):
        self.use_byte = use_byte
        self.num_classes = num_classes
        # With ByteTrack enabled, the threshold for starting a new track is
        # derived from conf_thres and the det_thresh argument is ignored.
        self.det_thresh = det_thresh if not use_byte else conf_thres + 0.1
        self.track_buffer = track_buffer
        self.min_box_area = min_box_area
        self.vertical_ratio = vertical_ratio

        self.tracked_thresh = tracked_thresh
        self.r_tracked_thresh = r_tracked_thresh
        self.unconfirmed_thresh = unconfirmed_thresh
        self.conf_thres = conf_thres
        self.match_thres = match_thres
        self.low_conf_thres = low_conf_thres

        # NOTE: the default list object is shared between instances; it is
        # only stored here and must not be mutated in place.
        self.input_size = input_size
        # NOTE(review): any value other than 'KalmanFilter' leaves
        # `self.motion` unset, and update() reads `self.motion` on every
        # call -- confirm whether other motion models are assigned elsewhere.
        if motion == 'KalmanFilter':
            self.motion = KalmanFilter()
        self.metric_type = metric_type

        self.frame_id = 0
        # Per-class track pools, keyed by class id.
        self.tracked_tracks_dict = defaultdict(list)  # dict(list[STrack])
        self.lost_tracks_dict = defaultdict(list)  # dict(list[STrack])
        self.removed_tracks_dict = defaultdict(list)  # dict(list[STrack])

        self.max_time_lost = 0
        # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
        # (it is not computed in this class; presumably the caller assigns it
        # once the frame rate is known -- TODO confirm)

    @staticmethod
    def _build_detections(dets, embs, cls_id):
        """Wrap the detection rows of one class as new STrack candidates.

        Args:
            dets (np.array): rows of 'cls_id, score, x0, y0, x1, y1'.
            embs (np.array|None): embeddings aligned with `dets`, or None
                when no embeddings are available (original ByteTrack).
            cls_id (int): class id assigned to every created STrack.

        Returns:
            list[STrack]: one STrack per row of `dets`.
        """
        if embs is None:
            # in original ByteTrack
            return [
                STrack(
                    STrack.tlbr_to_tlwh(tlbrs[2:6]),
                    tlbrs[1],
                    cls_id,
                    30,
                    temp_feat=None) for tlbrs in dets
            ]
        return [
            STrack(
                STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id, 30,
                temp_feat) for (tlbrs, temp_feat) in zip(dets, embs)
        ]

    def update(self, pred_dets, pred_embs=None):
        """
        Processes the image frame and finds bounding box(detections).
        Associates the detection with corresponding tracklets and also handles
        lost, removed, refound and active tracklets.

        Args:
            pred_dets (np.array): Detection results of the image, the shape is
                [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
            pred_embs (np.array): Embedding results of the image, the shape is
                [N, 128] or [N, 512]. If None, the first association uses
                IoU distance instead of embedding distance.

        Return:
            output_stracks_dict (dict(list)): maps each class id to the list
                of activated tracks currently held in the tracked pool for
                the received frame.
        """
        self.frame_id += 1
        if self.frame_id == 1:
            STrack.init_count(self.num_classes)

        activated_tracks_dict = defaultdict(list)
        refined_tracks_dict = defaultdict(list)
        lost_tracks_dict = defaultdict(list)
        removed_tracks_dict = defaultdict(list)
        output_tracks_dict = defaultdict(list)

        pred_dets_dict = defaultdict(list)
        pred_embs_dict = defaultdict(list)

        # unify single and multi classes detection and embedding results
        for cls_id in range(self.num_classes):
            cls_idx = (pred_dets[:, 0:1] == cls_id).squeeze(-1)
            pred_dets_dict[cls_id] = pred_dets[cls_idx]
            if pred_embs is not None:
                pred_embs_dict[cls_id] = pred_embs[cls_idx]
            else:
                pred_embs_dict[cls_id] = None

        for cls_id in range(self.num_classes):
            # ---- Step 1: Get detections by class ----
            pred_dets_cls = pred_dets_dict[cls_id]
            pred_embs_cls = pred_embs_dict[cls_id]
            remain_inds = (pred_dets_cls[:, 1:2] > self.conf_thres).squeeze(-1)
            if remain_inds.sum() > 0:
                pred_dets_cls = pred_dets_cls[remain_inds]
                if pred_embs_cls is not None:
                    pred_embs_cls = pred_embs_cls[remain_inds]
                detections = self._build_detections(pred_dets_cls,
                                                    pred_embs_cls, cls_id)
            else:
                detections = []

            # Split the current tracked pool into unconfirmed (not yet
            # activated) tracks and activated tracks.
            unconfirmed = []
            tracked_stracks = []
            for track in self.tracked_tracks_dict[cls_id]:
                if not track.is_activated:
                    unconfirmed.append(track)
                else:
                    tracked_stracks.append(track)

            # ---- Step 2: First association, with embedding ----
            # building tracking pool for the current frame
            track_pool = joint_stracks(tracked_stracks,
                                       self.lost_tracks_dict[cls_id])

            # Predict the current location with KalmanFilter
            STrack.multi_predict(track_pool, self.motion)

            if pred_embs_cls is None:
                # in original ByteTrack
                dists = matching.iou_distance(track_pool, detections)
                matches, u_track, u_detection = matching.linear_assignment(
                    dists, thresh=self.match_thres)  # not self.tracked_thresh
            else:
                dists = matching.embedding_distance(
                    track_pool, detections, metric=self.metric_type)
                dists = matching.fuse_motion(self.motion, dists, track_pool,
                                             detections)
                matches, u_track, u_detection = matching.linear_assignment(
                    dists, thresh=self.tracked_thresh)

            for i_tracked, idet in matches:
                # i_tracked indexes the track pool, idet indexes detections
                track = track_pool[i_tracked]
                det = detections[idet]
                if track.state == TrackState.Tracked:
                    # If the track is active, add the detection to the track
                    track.update(det, self.frame_id)
                    activated_tracks_dict[cls_id].append(track)
                else:
                    # A detection matched a track which is not active,
                    # hence put the track in the refound list
                    track.re_activate(det, self.frame_id, new_id=False)
                    refined_tracks_dict[cls_id].append(track)

            # None of the steps below happen if there are no undetected tracks.
            # ---- Step 3: Second association, with IOU ----
            # Only still-Tracked tracks left unmatched by the first
            # association take part in the second one.
            r_tracked_stracks = [
                track_pool[i] for i in u_track
                if track_pool[i].state == TrackState.Tracked
            ]
            if self.use_byte:
                # Low-score detections: low_conf_thres < score < conf_thres
                scores_cls = pred_dets_dict[cls_id][:, 1:2]
                inds_low = scores_cls > self.low_conf_thres
                inds_high = scores_cls < self.conf_thres
                inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
                pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]

                # associate the unmatched tracks to the low score detections
                if len(pred_dets_cls_second) > 0:
                    if pred_embs_dict[cls_id] is None:
                        pred_embs_cls_second = None
                    else:
                        pred_embs_cls_second = pred_embs_dict[cls_id][
                            inds_second]
                    detections_second = self._build_detections(
                        pred_dets_cls_second, pred_embs_cls_second, cls_id)
                else:
                    detections_second = []

                second_candidates = detections_second
                dists = matching.iou_distance(r_tracked_stracks,
                                              detections_second)
                # `u_detection` is deliberately left untouched here: it still
                # indexes the high-score `detections` from the first stage.
                matches, u_track, u_detection_second = matching.linear_assignment(
                    dists, thresh=0.4)  # not r_tracked_thresh
            else:
                # Re-use the high-score detections left over from stage one.
                detections = [detections[i] for i in u_detection]
                second_candidates = detections
                dists = matching.iou_distance(r_tracked_stracks, detections)
                matches, u_track, u_detection = matching.linear_assignment(
                    dists, thresh=self.r_tracked_thresh)

            for i_tracked, idet in matches:
                track = r_tracked_stracks[i_tracked]
                det = second_candidates[idet]
                if track.state == TrackState.Tracked:
                    track.update(det, self.frame_id)
                    activated_tracks_dict[cls_id].append(track)
                else:
                    track.re_activate(det, self.frame_id, new_id=False)
                    refined_tracks_dict[cls_id].append(track)

            # Tracks still unmatched after both associations become lost.
            for it in u_track:
                track = r_tracked_stracks[it]
                if not track.state == TrackState.Lost:
                    track.mark_lost()
                    lost_tracks_dict[cls_id].append(track)

            # Deal with unconfirmed tracks, usually tracks with only one
            # beginning frame
            detections = [detections[i] for i in u_detection]
            dists = matching.iou_distance(unconfirmed, detections)
            matches, u_unconfirmed, u_detection = matching.linear_assignment(
                dists, thresh=self.unconfirmed_thresh)
            for i_tracked, idet in matches:
                unconfirmed[i_tracked].update(detections[idet], self.frame_id)
                activated_tracks_dict[cls_id].append(unconfirmed[i_tracked])
            for it in u_unconfirmed:
                track = unconfirmed[it]
                track.mark_removed()
                removed_tracks_dict[cls_id].append(track)

            # ---- Step 4: Init new stracks ----
            for inew in u_detection:
                track = detections[inew]
                if track.score < self.det_thresh:
                    continue
                track.activate(self.motion, self.frame_id)
                activated_tracks_dict[cls_id].append(track)

            # ---- Step 5: Update state ----
            # Lost tracks that exceeded max_time_lost are marked removed.
            for track in self.lost_tracks_dict[cls_id]:
                if self.frame_id - track.end_frame > self.max_time_lost:
                    track.mark_removed()
                    removed_tracks_dict[cls_id].append(track)

            self.tracked_tracks_dict[cls_id] = [
                t for t in self.tracked_tracks_dict[cls_id]
                if t.state == TrackState.Tracked
            ]
            self.tracked_tracks_dict[cls_id] = joint_stracks(
                self.tracked_tracks_dict[cls_id], activated_tracks_dict[cls_id])
            self.tracked_tracks_dict[cls_id] = joint_stracks(
                self.tracked_tracks_dict[cls_id], refined_tracks_dict[cls_id])
            self.lost_tracks_dict[cls_id] = sub_stracks(
                self.lost_tracks_dict[cls_id], self.tracked_tracks_dict[cls_id])
            self.lost_tracks_dict[cls_id].extend(lost_tracks_dict[cls_id])
            # Statement order matters: the lost pool is pruned against the
            # removed pool *before* this frame's removals are appended to it.
            self.lost_tracks_dict[cls_id] = sub_stracks(
                self.lost_tracks_dict[cls_id], self.removed_tracks_dict[cls_id])
            self.removed_tracks_dict[cls_id].extend(removed_tracks_dict[cls_id])
            self.tracked_tracks_dict[cls_id], self.lost_tracks_dict[
                cls_id] = remove_duplicate_stracks(
                    self.tracked_tracks_dict[cls_id],
                    self.lost_tracks_dict[cls_id])

            # Report only the activated tracks of the tracked pool.
            output_tracks_dict[cls_id] = [
                track for track in self.tracked_tracks_dict[cls_id]
                if track.is_activated
            ]

            logger.debug('===========Frame {}=========='.format(self.frame_id))
            logger.debug('Activated: {}'.format(
                [track.track_id for track in activated_tracks_dict[cls_id]]))
            logger.debug('Refind: {}'.format(
                [track.track_id for track in refined_tracks_dict[cls_id]]))
            logger.debug('Lost: {}'.format(
                [track.track_id for track in lost_tracks_dict[cls_id]]))
            logger.debug('Removed: {}'.format(
                [track.track_id for track in removed_tracks_dict[cls_id]]))

        return output_tracks_dict
|