123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
"""
Copyright (c) https://github.com/xingyizhou/CenterTrack
Modified by Peize Sun, Rufeng Zhang
"""
# coding: utf-8
- import torch
- from scipy.optimize import linear_sum_assignment
- from util import box_ops
- import copy
class Tracker(object):
    """Link per-frame detections to persistent track ids by box overlap.

    Detections are split by score into a primary set (``>= score_thresh``)
    and a secondary set (``>= low_thresh``).  Primary detections are matched
    against all known tracks first; tracks that stay unmatched and were
    active in the previous frame get a second chance against the secondary
    detections.  Unmatched tracks are kept as inactive until their age
    reaches ``max_age``.

    Every track is a plain dict with the keys ``score``, ``bbox``
    (``[x1, y1, x2, y2]``), ``tracking_id``, ``active`` (1 or 0) and ``age``.
    """

    def __init__(self, score_thresh, max_age=32, low_thresh=0.2,
                 first_match_max_cost=1.2, second_match_max_cost=0.8):
        """Configure thresholds and start with an empty track set.

        Args:
            score_thresh: minimum score for a primary detection.
            max_age: an unmatched track is dropped once its age reaches
                this value.
            low_thresh: minimum score for a secondary detection.
            first_match_max_cost: assignments of the first association whose
                cost exceeds this value are rejected.
            second_match_max_cost: same limit for the second association.
        """
        self.score_thresh = score_thresh
        self.low_thresh = low_thresh
        # Only unmatched detections at least this confident start a new track.
        self.high_thresh = score_thresh + 0.1
        self.max_age = max_age
        self.first_match_max_cost = first_match_max_cost
        self.second_match_max_cost = second_match_max_cost
        self.reset_all()

    def reset_all(self):
        """Forget every track and restart id numbering."""
        self.id_count = 0
        # Tracks matched in the latest frame, keyed by their index in the
        # detector output.
        self.tracks_dict = dict()
        # Everything returned by the latest init_track()/step() call.
        self.tracks = list()
        # Tracks that found no detection in the latest frame but are still
        # younger than max_age.
        self.unmatched_tracks = list()

    @staticmethod
    def _make_detection(scores, bboxes, idx):
        """Build the score/bbox dict for output row ``idx``."""
        return {
            "score": float(scores[idx]),
            "bbox": bboxes[idx, :].cpu().numpy().tolist(),
        }

    @staticmethod
    def _activate(obj, tracking_id):
        """Mark ``obj`` as a freshly observed track carrying ``tracking_id``."""
        obj['tracking_id'] = tracking_id
        obj['age'] = 1
        obj['active'] = 1
        return obj

    @staticmethod
    def _associate(detections, tracks, max_cost):
        """Assign detections to tracks one-to-one by minimal box cost.

        The cost of a pair is ``1 - generalized_box_iou`` of the two boxes;
        the assignment is solved with ``linear_sum_assignment``.

        Returns:
            ``(matches, unmatched_dets, unmatched_tracks)``.  ``matches`` is
            a list of ``(det_index, track_index)`` pairs whose cost does not
            exceed ``max_cost``.  Each unmatched list holds the never
            assigned indices in ascending order followed by the indices of
            rejected assignments.
        """
        num_dets = len(detections)
        num_tracks = len(tracks)
        if num_dets == 0 or num_tracks == 0:
            return [], list(range(num_dets)), list(range(num_tracks))

        det_box = torch.stack(
            [torch.tensor(obj['bbox']) for obj in detections], dim=0)  # N x 4
        track_box = torch.stack(
            [torch.tensor(obj['bbox']) for obj in tracks], dim=0)  # M x 4
        cost = 1.0 - box_ops.generalized_box_iou(det_box, track_box)  # N x M

        rows, cols = linear_sum_assignment(cost)
        rows = rows.tolist()
        cols = cols.tolist()
        assigned_dets = set(rows)
        assigned_tracks = set(cols)
        unmatched_dets = [d for d in range(num_dets) if d not in assigned_dets]
        unmatched_tracks = [t for t in range(num_tracks)
                            if t not in assigned_tracks]

        matches = []
        for det_idx, track_idx in zip(rows, cols):
            if cost[det_idx, track_idx] > max_cost:
                # Too dissimilar: treat both sides as unmatched.
                unmatched_dets.append(det_idx)
                unmatched_tracks.append(track_idx)
            else:
                matches.append((det_idx, track_idx))
        return matches, unmatched_dets, unmatched_tracks

    def init_track(self, results):
        """Start a new track for every detection scoring ``>= score_thresh``.

        Args:
            results: mapping with ``"scores"`` (indexable, with ``.shape``)
                and ``"boxes"`` (x1y1x2y2 rows, indexed as ``boxes[idx, :]``).

        Returns:
            A deep copy of the list of created tracks.
        """
        scores = results["scores"]
        bboxes = results["boxes"]  # x1y1x2y2

        ret = list()
        ret_dict = dict()
        for idx in range(scores.shape[0]):
            if scores[idx] >= self.score_thresh:
                self.id_count += 1
                obj = self._make_detection(scores, bboxes, idx)
                obj["tracking_id"] = self.id_count
                obj['active'] = 1
                obj['age'] = 1
                ret.append(obj)
                ret_dict[idx] = obj

        self.tracks = ret
        self.tracks_dict = ret_dict
        return copy.deepcopy(ret)

    def step(self, output_results):
        """Advance the tracker by one frame.

        Args:
            output_results: mapping with ``"scores"`` and ``"boxes"``
                (x1y1x2y2) and optionally ``"track_boxes"`` (x1y1x2y2).  When
                ``"track_boxes"`` is present, row ``idx`` replaces the stored
                box of the track registered under ``idx`` before matching.

        Returns:
            A deep copy of the current track list: matched tracks, newly
            started tracks, and still-alive unmatched tracks (``active == 0``).
        """
        scores = output_results["scores"]
        bboxes = output_results["boxes"]  # x1y1x2y2
        track_bboxes = (output_results["track_boxes"]
                        if "track_boxes" in output_results else None)  # x1y1x2y2

        results = list()         # primary detections
        results_second = list()  # secondary (lower score) detections
        results_dict = dict()    # output index -> detection dict

        for idx in range(scores.shape[0]):
            if idx in self.tracks_dict and track_bboxes is not None:
                self.tracks_dict[idx]["bbox"] = (
                    track_bboxes[idx, :].cpu().numpy().tolist())
            # Compare the raw score element (not a converted float) so the
            # threshold behaviour stays exactly as before.
            if scores[idx] >= self.score_thresh:
                obj = self._make_detection(scores, bboxes, idx)
                results.append(obj)
                results_dict[idx] = obj
            elif scores[idx] >= self.low_thresh:
                obj = self._make_detection(scores, bboxes, idx)
                results_second.append(obj)
                results_dict[idx] = obj

        tracks = list(self.tracks_dict.values()) + self.unmatched_tracks
        ret = list()

        # First association: primary detections against every known track.
        matches, unmatched_dets, unmatched_tracks = self._associate(
            results, tracks, self.first_match_max_cost)
        for det_idx, track_idx in matches:
            ret.append(self._activate(results[det_idx],
                                      tracks[track_idx]['tracking_id']))

        # Second association: tracks that were active last frame and are
        # still unmatched are tried against the secondary detections.
        unmatched_tracks_obj = [tracks[i] for i in unmatched_tracks
                                if tracks[i]['active'] == 1]
        matches_second, _, unmatched_tracks_second = self._associate(
            results_second, unmatched_tracks_obj, self.second_match_max_cost)
        for det_idx, track_idx in matches_second:
            ret.append(self._activate(
                results_second[det_idx],
                unmatched_tracks_obj[track_idx]['tracking_id']))

        # Unmatched primary detections start a new track only when confident.
        for i in unmatched_dets:
            trackd = results[i]
            if trackd["score"] >= self.high_thresh:
                self.id_count += 1
                ret.append(self._activate(trackd, self.id_count))

        # Age the tracks that found no detection; those at max_age are dropped.
        ret_unmatched_tracks = []
        for j in unmatched_tracks:
            track = tracks[j]
            # Previously active tracks are handled through the second
            # association below, so only inactive ones are aged here.
            if track['active'] == 0 and track['age'] < self.max_age:
                track['age'] += 1
                track['active'] = 0
                ret.append(track)
                ret_unmatched_tracks.append(track)
        for i in unmatched_tracks_second:
            track = unmatched_tracks_obj[i]
            if track['age'] < self.max_age:
                track['age'] += 1
                track['active'] = 0
                ret.append(track)
                ret_unmatched_tracks.append(track)

        self.tracks = ret
        # Only detections that ended up with an id are tracks for next frame.
        self.tracks_dict = {idx: obj for idx, obj in results_dict.items()
                            if 'tracking_id' in obj}
        self.unmatched_tracks = ret_unmatched_tracks
        return copy.deepcopy(ret)
|