byte_tracker.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import torch
  2. import numpy as np
  3. from utils.kalman_filter import KalmanFilter
  4. from dev.utils.log import logger
  5. from models import *
  6. from tracker import matching
  7. from .basetrack import BaseTrack, TrackState
  8. class STrack(BaseTrack):
  9. def __init__(self, tlwh, score):
  10. # wait activate
  11. self._tlwh = np.asarray(tlwh, dtype=np.float)
  12. self.kalman_filter = None
  13. self.mean, self.covariance = None, None
  14. self.is_activated = False
  15. self.score = score
  16. self.tracklet_len = 0
  17. def predict(self):
  18. mean_state = self.mean.copy()
  19. if self.state != TrackState.Tracked:
  20. mean_state[7] = 0
  21. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  22. @staticmethod
  23. def multi_predict(stracks, kalman_filter):
  24. if len(stracks) > 0:
  25. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  26. multi_covariance = np.asarray([st.covariance for st in stracks])
  27. for i, st in enumerate(stracks):
  28. if st.state != TrackState.Tracked:
  29. multi_mean[i][7] = 0
  30. # multi_mean, multi_covariance = STrack.kalman_filter.multi_predict(multi_mean, multi_covariance)
  31. multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
  32. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  33. stracks[i].mean = mean
  34. stracks[i].covariance = cov
  35. def activate(self, kalman_filter, frame_id):
  36. """Start a new tracklet"""
  37. self.kalman_filter = kalman_filter
  38. self.track_id = self.next_id()
  39. self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
  40. self.tracklet_len = 0
  41. self.state = TrackState.Tracked
  42. #self.is_activated = True
  43. self.frame_id = frame_id
  44. self.start_frame = frame_id
  45. def re_activate(self, new_track, frame_id, new_id=False):
  46. self.mean, self.covariance = self.kalman_filter.update(
  47. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
  48. )
  49. self.tracklet_len = 0
  50. self.state = TrackState.Tracked
  51. self.is_activated = True
  52. self.frame_id = frame_id
  53. if new_id:
  54. self.track_id = self.next_id()
  55. def update(self, new_track, frame_id, update_feature=True):
  56. """
  57. Update a matched track
  58. :type new_track: STrack
  59. :type frame_id: int
  60. :type update_feature: bool
  61. :return:
  62. """
  63. self.frame_id = frame_id
  64. self.tracklet_len += 1
  65. new_tlwh = new_track.tlwh
  66. self.mean, self.covariance = self.kalman_filter.update(
  67. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  68. self.state = TrackState.Tracked
  69. self.is_activated = True
  70. self.score = new_track.score
  71. @property
  72. def tlwh(self):
  73. """Get current position in bounding box format `(top left x, top left y,
  74. width, height)`.
  75. """
  76. if self.mean is None:
  77. return self._tlwh.copy()
  78. ret = self.mean[:4].copy()
  79. ret[2] *= ret[3]
  80. ret[:2] -= ret[2:] / 2
  81. return ret
  82. @property
  83. def tlbr(self):
  84. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  85. `(top left, bottom right)`.
  86. """
  87. ret = self.tlwh.copy()
  88. ret[2:] += ret[:2]
  89. return ret
  90. @staticmethod
  91. def tlwh_to_xyah(tlwh):
  92. """Convert bounding box to format `(center x, center y, aspect ratio,
  93. height)`, where the aspect ratio is `width / height`.
  94. """
  95. ret = np.asarray(tlwh).copy()
  96. ret[:2] += ret[2:] / 2
  97. ret[2] /= ret[3]
  98. return ret
  99. def to_xyah(self):
  100. return self.tlwh_to_xyah(self.tlwh)
  101. @staticmethod
  102. def tlbr_to_tlwh(tlbr):
  103. ret = np.asarray(tlbr).copy()
  104. ret[2:] -= ret[:2]
  105. return ret
  106. @staticmethod
  107. def tlwh_to_tlbr(tlwh):
  108. ret = np.asarray(tlwh).copy()
  109. ret[2:] += ret[:2]
  110. return ret
  111. def __repr__(self):
  112. return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  113. class BYTETracker(object):
  114. def __init__(self, opt, frame_rate=30):
  115. self.opt = opt
  116. self.model = Darknet(opt.cfg, nID=14455)
  117. # load_darknet_weights(self.model, opt.weights)
  118. self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)
  119. self.model.cuda().eval()
  120. self.tracked_stracks = [] # type: list[STrack]
  121. self.lost_stracks = [] # type: list[STrack]
  122. self.removed_stracks = [] # type: list[STrack]
  123. self.frame_id = 0
  124. self.det_thresh = opt.conf_thres
  125. self.init_thresh = self.det_thresh + 0.2
  126. self.low_thresh = 0.3
  127. self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
  128. self.max_time_lost = self.buffer_size
  129. self.kalman_filter = KalmanFilter()
  130. def update(self, im_blob, img0):
  131. """
  132. Processes the image frame and finds bounding box(detections).
  133. Associates the detection with corresponding tracklets and also handles lost, removed, refound and active tracklets
  134. Parameters
  135. ----------
  136. im_blob : torch.float32
  137. Tensor of shape depending upon the size of image. By default, shape of this tensor is [1, 3, 608, 1088]
  138. img0 : ndarray
  139. ndarray of shape depending on the input image sequence. By default, shape is [608, 1080, 3]
  140. Returns
  141. -------
  142. output_stracks : list of Strack(instances)
  143. The list contains information regarding the online_tracklets for the recieved image tensor.
  144. """
  145. self.frame_id += 1
  146. activated_starcks = [] # for storing active tracks, for the current frame
  147. refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
  148. lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
  149. removed_stracks = []
  150. t1 = time.time()
  151. ''' Step 1: Network forward, get detections & embeddings'''
  152. with torch.no_grad():
  153. pred = self.model(im_blob)
  154. # pred is tensor of all the proposals (default number of proposals: 54264). Proposals have information associated with the bounding box and embeddings
  155. pred = pred[pred[:, :, 4] > self.low_thresh]
  156. # pred now has lesser number of proposals. Proposals rejected on basis of object confidence score
  157. if len(pred) > 0:
  158. dets = non_max_suppression(pred.unsqueeze(0), self.low_thresh, self.opt.nms_thres)[0].cpu()
  159. # Final proposals are obtained in dets. Information of bounding box and embeddings also included
  160. # Next step changes the detection scales
  161. scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
  162. '''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)'''
  163. # class_pred is the embeddings.
  164. dets = dets.numpy()
  165. remain_inds = dets[:, 4] > self.det_thresh
  166. inds_low = dets[:, 4] > self.low_thresh
  167. inds_high = dets[:, 4] < self.det_thresh
  168. inds_second = np.logical_and(inds_low, inds_high)
  169. dets_second = dets[inds_second]
  170. dets = dets[remain_inds]
  171. detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  172. tlbrs in dets[:, :5]]
  173. else:
  174. detections = []
  175. dets_second = []
  176. t2 = time.time()
  177. # print('Forward: {} s'.format(t2-t1))
  178. ''' Add newly detected tracklets to tracked_stracks'''
  179. unconfirmed = []
  180. tracked_stracks = [] # type: list[STrack]
  181. for track in self.tracked_stracks:
  182. if not track.is_activated:
  183. # previous tracks which are not active in the current frame are added in unconfirmed list
  184. unconfirmed.append(track)
  185. # print("Should not be here, in unconfirmed")
  186. else:
  187. # Active tracks are added to the local list 'tracked_stracks'
  188. tracked_stracks.append(track)
  189. ''' Step 2: First association, with embedding'''
  190. # Combining currently tracked_stracks and lost_stracks
  191. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  192. # Predict the current location with KF
  193. STrack.multi_predict(strack_pool, self.kalman_filter)
  194. dists = matching.iou_distance(strack_pool, detections)
  195. # The dists is the list of distances of the detection with the tracks in strack_pool
  196. matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.8)
  197. # The matches is the array for corresponding matches of the detection with the corresponding strack_pool
  198. for itracked, idet in matches:
  199. # itracked is the id of the track and idet is the detection
  200. track = strack_pool[itracked]
  201. det = detections[idet]
  202. if track.state == TrackState.Tracked:
  203. # If the track is active, add the detection to the track
  204. track.update(detections[idet], self.frame_id)
  205. activated_starcks.append(track)
  206. else:
  207. # We have obtained a detection from a track which is not active, hence put the track in refind_stracks list
  208. track.re_activate(det, self.frame_id, new_id=False)
  209. refind_stracks.append(track)
  210. # association the untrack to the low score detections
  211. if len(dets_second) > 0:
  212. detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  213. tlbrs in dets_second[:, :5]]
  214. else:
  215. detections_second = []
  216. r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
  217. dists = matching.iou_distance(r_tracked_stracks, detections_second)
  218. matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4)
  219. for itracked, idet in matches:
  220. track = r_tracked_stracks[itracked]
  221. det = detections_second[idet]
  222. if track.state == TrackState.Tracked:
  223. track.update(det, self.frame_id)
  224. activated_starcks.append(track)
  225. else:
  226. track.re_activate(det, self.frame_id, new_id=False)
  227. refind_stracks.append(track)
  228. for it in u_track:
  229. track = r_tracked_stracks[it]
  230. if not track.state == TrackState.Lost:
  231. track.mark_lost()
  232. lost_stracks.append(track)
  233. # If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
  234. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  235. detections = [detections[i] for i in u_detection]
  236. dists = matching.iou_distance(unconfirmed, detections)
  237. matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
  238. for itracked, idet in matches:
  239. unconfirmed[itracked].update(detections[idet], self.frame_id)
  240. activated_starcks.append(unconfirmed[itracked])
  241. # The tracks which are yet not matched
  242. for it in u_unconfirmed:
  243. track = unconfirmed[it]
  244. track.mark_removed()
  245. removed_stracks.append(track)
  246. # after all these confirmation steps, if a new detection is found, it is initialized for a new track
  247. """ Step 4: Init new stracks"""
  248. for inew in u_detection:
  249. track = detections[inew]
  250. if track.score < self.init_thresh:
  251. continue
  252. track.activate(self.kalman_filter, self.frame_id)
  253. activated_starcks.append(track)
  254. """ Step 5: Update state"""
  255. # If the tracks are lost for more frames than the threshold number, the tracks are removed.
  256. for track in self.lost_stracks:
  257. if self.frame_id - track.end_frame > self.max_time_lost:
  258. track.mark_removed()
  259. removed_stracks.append(track)
  260. # print('Remained match {} s'.format(t4-t3))
  261. # Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
  262. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  263. self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
  264. self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
  265. # self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
  266. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  267. self.lost_stracks.extend(lost_stracks)
  268. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  269. self.removed_stracks.extend(removed_stracks)
  270. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
  271. # get scores of lost tracks
  272. output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  273. logger.debug('===========Frame {}=========='.format(self.frame_id))
  274. logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
  275. logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
  276. logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
  277. logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
  278. # print('Final {} s'.format(t5-t4))
  279. return output_stracks
  280. def joint_stracks(tlista, tlistb):
  281. exists = {}
  282. res = []
  283. for t in tlista:
  284. exists[t.track_id] = 1
  285. res.append(t)
  286. for t in tlistb:
  287. tid = t.track_id
  288. if not exists.get(tid, 0):
  289. exists[tid] = 1
  290. res.append(t)
  291. return res
  292. def sub_stracks(tlista, tlistb):
  293. stracks = {}
  294. for t in tlista:
  295. stracks[t.track_id] = t
  296. for t in tlistb:
  297. tid = t.track_id
  298. if stracks.get(tid, 0):
  299. del stracks[tid]
  300. return list(stracks.values())
  301. def remove_duplicate_stracks(stracksa, stracksb):
  302. pdist = matching.iou_distance(stracksa, stracksb)
  303. pairs = np.where(pdist<0.15)
  304. dupa, dupb = list(), list()
  305. for p,q in zip(*pairs):
  306. timep = stracksa[p].frame_id - stracksa[p].start_frame
  307. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  308. if timep > timeq:
  309. dupb.append(q)
  310. else:
  311. dupa.append(p)
  312. resa = [t for i,t in enumerate(stracksa) if not i in dupa]
  313. resb = [t for i,t in enumerate(stracksb) if not i in dupb]
  314. return resa, resb