byte_tracker.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. from collections import deque
  2. import os
  3. import cv2
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from torchsummary import summary
  8. from core.mot.general import non_max_suppression_and_inds, non_max_suppression_jde, non_max_suppression, scale_coords
  9. from core.mot.torch_utils import intersect_dicts
  10. from models.mot.cstrack import Model
  11. from mot_online import matching
  12. from mot_online.kalman_filter import KalmanFilter
  13. from mot_online.log import logger
  14. from mot_online.utils import *
  15. from mot_online.basetrack import BaseTrack, TrackState
  16. class STrack(BaseTrack):
  17. shared_kalman = KalmanFilter()
  18. def __init__(self, tlwh, score):
  19. # wait activate
  20. self._tlwh = np.asarray(tlwh, dtype=np.float)
  21. self.kalman_filter = None
  22. self.mean, self.covariance = None, None
  23. self.is_activated = False
  24. self.score = score
  25. self.tracklet_len = 0
  26. def predict(self):
  27. mean_state = self.mean.copy()
  28. if self.state != TrackState.Tracked:
  29. mean_state[7] = 0
  30. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  31. @staticmethod
  32. def multi_predict(stracks):
  33. if len(stracks) > 0:
  34. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  35. multi_covariance = np.asarray([st.covariance for st in stracks])
  36. for i, st in enumerate(stracks):
  37. if st.state != TrackState.Tracked:
  38. multi_mean[i][7] = 0
  39. multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
  40. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  41. stracks[i].mean = mean
  42. stracks[i].covariance = cov
  43. def activate(self, kalman_filter, frame_id):
  44. """Start a new tracklet"""
  45. self.kalman_filter = kalman_filter
  46. self.track_id = self.next_id()
  47. self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
  48. self.tracklet_len = 0
  49. self.state = TrackState.Tracked
  50. #self.is_activated = True
  51. self.frame_id = frame_id
  52. self.start_frame = frame_id
  53. def re_activate(self, new_track, frame_id, new_id=False):
  54. self.mean, self.covariance = self.kalman_filter.update(
  55. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
  56. )
  57. self.tracklet_len = 0
  58. self.state = TrackState.Tracked
  59. self.is_activated = True
  60. self.frame_id = frame_id
  61. if new_id:
  62. self.track_id = self.next_id()
  63. def update(self, new_track, frame_id):
  64. """
  65. Update a matched track
  66. :type new_track: STrack
  67. :type frame_id: int
  68. :type update_feature: bool
  69. :return:
  70. """
  71. self.frame_id = frame_id
  72. self.tracklet_len += 1
  73. new_tlwh = new_track.tlwh
  74. self.mean, self.covariance = self.kalman_filter.update(
  75. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  76. self.state = TrackState.Tracked
  77. self.is_activated = True
  78. self.score = new_track.score
  79. @property
  80. # @jit(nopython=True)
  81. def tlwh(self):
  82. """Get current position in bounding box format `(top left x, top left y,
  83. width, height)`.
  84. """
  85. if self.mean is None:
  86. return self._tlwh.copy()
  87. ret = self.mean[:4].copy()
  88. ret[2] *= ret[3]
  89. ret[:2] -= ret[2:] / 2
  90. return ret
  91. @property
  92. # @jit(nopython=True)
  93. def tlbr(self):
  94. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  95. `(top left, bottom right)`.
  96. """
  97. ret = self.tlwh.copy()
  98. ret[2:] += ret[:2]
  99. return ret
  100. @staticmethod
  101. # @jit(nopython=True)
  102. def tlwh_to_xyah(tlwh):
  103. """Convert bounding box to format `(center x, center y, aspect ratio,
  104. height)`, where the aspect ratio is `width / height`.
  105. """
  106. ret = np.asarray(tlwh).copy()
  107. ret[:2] += ret[2:] / 2
  108. ret[2] /= ret[3]
  109. return ret
  110. def to_xyah(self):
  111. return self.tlwh_to_xyah(self.tlwh)
  112. @staticmethod
  113. # @jit(nopython=True)
  114. def tlbr_to_tlwh(tlbr):
  115. ret = np.asarray(tlbr).copy()
  116. ret[2:] -= ret[:2]
  117. return ret
  118. @staticmethod
  119. # @jit(nopython=True)
  120. def tlwh_to_tlbr(tlwh):
  121. ret = np.asarray(tlwh).copy()
  122. ret[2:] += ret[:2]
  123. return ret
  124. def __repr__(self):
  125. return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  126. class BYTETracker(object):
  127. def __init__(self, opt, frame_rate=30):
  128. self.opt = opt
  129. if int(opt.gpus[0]) >= 0:
  130. opt.device = torch.device('cuda')
  131. else:
  132. opt.device = torch.device('cpu')
  133. print('Creating model...')
  134. ckpt = torch.load(opt.weights, map_location=opt.device) # load checkpoint
  135. self.model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=1).to(opt.device) # create
  136. exclude = ['anchor'] if opt.cfg else [] # exclude keys
  137. if type(ckpt['model']).__name__ == "OrderedDict":
  138. state_dict = ckpt['model']
  139. else:
  140. state_dict = ckpt['model'].float().state_dict() # to FP32
  141. state_dict = intersect_dicts(state_dict, self.model.state_dict(), exclude=exclude) # intersect
  142. self.model.load_state_dict(state_dict, strict=False) # load
  143. self.model.cuda().eval()
  144. total_params = sum(p.numel() for p in self.model.parameters())
  145. print(f'{total_params:,} total parameters.')
  146. self.tracked_stracks = [] # type: list[STrack]
  147. self.lost_stracks = [] # type: list[STrack]
  148. self.removed_stracks = [] # type: list[STrack]
  149. self.frame_id = 0
  150. self.det_thresh = opt.conf_thres
  151. self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
  152. self.max_time_lost = self.buffer_size
  153. self.mean = np.array(opt.mean, dtype=np.float32).reshape(1, 1, 3)
  154. self.std = np.array(opt.std, dtype=np.float32).reshape(1, 1, 3)
  155. self.kalman_filter = KalmanFilter()
  156. self.low_thres = 0.1
  157. self.high_thres = self.opt.conf_thres + 0.1
  158. def update(self, im_blob, img0,seq_num, save_dir):
  159. self.frame_id += 1
  160. activated_starcks = []
  161. refind_stracks = []
  162. lost_stracks = []
  163. removed_stracks = []
  164. dets = []
  165. ''' Step 1: Network forward, get detections & embeddings'''
  166. with torch.no_grad():
  167. output = self.model(im_blob, augment=False)
  168. pred, train_out = output[1]
  169. pred = pred[pred[:, :, 4] > self.low_thres]
  170. detections = []
  171. if len(pred) > 0:
  172. dets,x_inds,y_inds = non_max_suppression_and_inds(pred[:,:6].unsqueeze(0), 0.1, self.opt.nms_thres,method='cluster_diou')
  173. dets = dets.numpy()
  174. if len(dets) != 0:
  175. scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
  176. remain_inds = dets[:, 4] > self.opt.conf_thres
  177. inds_low = dets[:, 4] > self.low_thres
  178. inds_high = dets[:, 4] < self.opt.conf_thres
  179. inds_second = np.logical_and(inds_low, inds_high)
  180. dets_second = dets[inds_second]
  181. dets = dets[remain_inds]
  182. detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  183. tlbrs in dets[:, :5]]
  184. else:
  185. detections = []
  186. dets_second = []
  187. else:
  188. detections = []
  189. dets_second = []
  190. ''' Add newly detected tracklets to tracked_stracks'''
  191. unconfirmed = []
  192. tracked_stracks = [] # type: list[STrack]
  193. for track in self.tracked_stracks:
  194. if not track.is_activated:
  195. unconfirmed.append(track)
  196. else:
  197. tracked_stracks.append(track)
  198. ''' Step 2: First association, with embedding'''
  199. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  200. # Predict the current location with KF
  201. STrack.multi_predict(strack_pool)
  202. dists = matching.iou_distance(strack_pool, detections)
  203. matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.8)
  204. for itracked, idet in matches:
  205. track = strack_pool[itracked]
  206. det = detections[idet]
  207. if track.state == TrackState.Tracked:
  208. track.update(detections[idet], self.frame_id)
  209. activated_starcks.append(track)
  210. else:
  211. track.re_activate(det, self.frame_id, new_id=False)
  212. refind_stracks.append(track)
  213. # vis
  214. track_features, det_features, cost_matrix, cost_matrix_det, cost_matrix_track = [],[],[],[],[]
  215. if self.opt.vis_state == 1 and self.frame_id % 20 == 0:
  216. if len(dets) != 0:
  217. for i in range(0, dets.shape[0]):
  218. bbox = dets[i][0:4]
  219. cv2.rectangle(img0, (int(bbox[0]), int(bbox[1])),(int(bbox[2]), int(bbox[3])),(0, 255, 0), 2)
  220. track_features, det_features, cost_matrix, cost_matrix_det, cost_matrix_track = matching.vis_id_feature_A_distance(strack_pool, detections)
  221. vis_feature(self.frame_id,seq_num,img0,track_features,
  222. det_features, cost_matrix, cost_matrix_det, cost_matrix_track, max_num=5, out_path=save_dir)
  223. ''' Step 3: Second association, with IOU'''
  224. # association the untrack to the low score detections
  225. if len(dets_second) > 0:
  226. detections_second = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4]) for
  227. tlbrs in dets_second[:, :5]]
  228. else:
  229. detections_second = []
  230. r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
  231. dists = matching.iou_distance(r_tracked_stracks, detections_second)
  232. matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.4)
  233. for itracked, idet in matches:
  234. track = r_tracked_stracks[itracked]
  235. det = detections_second[idet]
  236. if track.state == TrackState.Tracked:
  237. track.update(det, self.frame_id)
  238. activated_starcks.append(track)
  239. else:
  240. track.re_activate(det, self.frame_id, new_id=False)
  241. refind_stracks.append(track)
  242. for it in u_track:
  243. track = r_tracked_stracks[it]
  244. if not track.state == TrackState.Lost:
  245. track.mark_lost()
  246. lost_stracks.append(track)
  247. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  248. detections = [detections[i] for i in u_detection]
  249. dists = matching.iou_distance(unconfirmed, detections)
  250. matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
  251. for itracked, idet in matches:
  252. unconfirmed[itracked].update(detections[idet], self.frame_id)
  253. activated_starcks.append(unconfirmed[itracked])
  254. for it in u_unconfirmed:
  255. track = unconfirmed[it]
  256. track.mark_removed()
  257. removed_stracks.append(track)
  258. """ Step 4: Init new stracks"""
  259. for inew in u_detection:
  260. track = detections[inew]
  261. if track.score < self.high_thres:
  262. continue
  263. track.activate(self.kalman_filter, self.frame_id)
  264. activated_starcks.append(track)
  265. """ Step 5: Update state"""
  266. for track in self.lost_stracks:
  267. if self.frame_id - track.end_frame > self.max_time_lost:
  268. track.mark_removed()
  269. removed_stracks.append(track)
  270. # print('Ramained match {} s'.format(t4-t3))
  271. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  272. self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
  273. self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
  274. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  275. self.lost_stracks.extend(lost_stracks)
  276. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  277. self.removed_stracks.extend(removed_stracks)
  278. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
  279. # get scores of lost tracks
  280. output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  281. logger.debug('===========Frame {}=========='.format(self.frame_id))
  282. logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
  283. logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
  284. logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
  285. logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
  286. return output_stracks
  287. def joint_stracks(tlista, tlistb):
  288. exists = {}
  289. res = []
  290. for t in tlista:
  291. exists[t.track_id] = 1
  292. res.append(t)
  293. for t in tlistb:
  294. tid = t.track_id
  295. if not exists.get(tid, 0):
  296. exists[tid] = 1
  297. res.append(t)
  298. return res
  299. def sub_stracks(tlista, tlistb):
  300. stracks = {}
  301. for t in tlista:
  302. stracks[t.track_id] = t
  303. for t in tlistb:
  304. tid = t.track_id
  305. if stracks.get(tid, 0):
  306. del stracks[tid]
  307. return list(stracks.values())
  308. def remove_duplicate_stracks(stracksa, stracksb):
  309. pdist = matching.iou_distance(stracksa, stracksb)
  310. pairs = np.where(pdist < 0.15)
  311. dupa, dupb = list(), list()
  312. for p, q in zip(*pairs):
  313. timep = stracksa[p].frame_id - stracksa[p].start_frame
  314. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  315. if timep > timeq:
  316. dupb.append(q)
  317. else:
  318. dupa.append(p)
  319. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  320. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  321. return resa, resb
  322. def vis_feature(frame_id,seq_num,img,track_features, det_features, cost_matrix, cost_matrix_det, cost_matrix_track,max_num=5, out_path='/home/XX/'):
  323. num_zero = ["0000","000","00","0"]
  324. img = cv2.resize(img, (778, 435))
  325. if len(det_features) != 0:
  326. max_f = det_features.max()
  327. min_f = det_features.min()
  328. det_features = np.round((det_features - min_f) / (max_f - min_f) * 255)
  329. det_features = det_features.astype(np.uint8)
  330. d_F_M = []
  331. cutpff_line = [40]*512
  332. for d_f in det_features:
  333. for row in range(45):
  334. d_F_M += [[40]*3+d_f.tolist()+[40]*3]
  335. for row in range(3):
  336. d_F_M += [[40]*3+cutpff_line+[40]*3]
  337. d_F_M = np.array(d_F_M)
  338. d_F_M = d_F_M.astype(np.uint8)
  339. det_features_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  340. feature_img2 = cv2.resize(det_features_img, (435, 435))
  341. #cv2.putText(feature_img2, "det_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  342. else:
  343. feature_img2 = np.zeros((435, 435))
  344. feature_img2 = feature_img2.astype(np.uint8)
  345. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  346. #cv2.putText(feature_img2, "det_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  347. feature_img = np.concatenate((img, feature_img2), axis=1)
  348. if len(cost_matrix_det) != 0 and len(cost_matrix_det[0]) != 0:
  349. max_f = cost_matrix_det.max()
  350. min_f = cost_matrix_det.min()
  351. cost_matrix_det = np.round((cost_matrix_det - min_f) / (max_f - min_f) * 255)
  352. d_F_M = []
  353. cutpff_line = [40]*len(cost_matrix_det)*10
  354. for c_m in cost_matrix_det:
  355. add = []
  356. for row in range(len(c_m)):
  357. add += [255-c_m[row]]*10
  358. for row in range(10):
  359. d_F_M += [[40]+add+[40]]
  360. d_F_M = np.array(d_F_M)
  361. d_F_M = d_F_M.astype(np.uint8)
  362. cost_matrix_det_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  363. feature_img2 = cv2.resize(cost_matrix_det_img, (435, 435))
  364. #cv2.putText(feature_img2, "cost_matrix_det", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  365. else:
  366. feature_img2 = np.zeros((435, 435))
  367. feature_img2 = feature_img2.astype(np.uint8)
  368. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  369. #cv2.putText(feature_img2, "cost_matrix_det", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  370. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  371. if len(track_features) != 0:
  372. max_f = track_features.max()
  373. min_f = track_features.min()
  374. track_features = np.round((track_features - min_f) / (max_f - min_f) * 255)
  375. track_features = track_features.astype(np.uint8)
  376. d_F_M = []
  377. cutpff_line = [40]*512
  378. for d_f in track_features:
  379. for row in range(45):
  380. d_F_M += [[40]*3+d_f.tolist()+[40]*3]
  381. for row in range(3):
  382. d_F_M += [[40]*3+cutpff_line+[40]*3]
  383. d_F_M = np.array(d_F_M)
  384. d_F_M = d_F_M.astype(np.uint8)
  385. track_features_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  386. feature_img2 = cv2.resize(track_features_img, (435, 435))
  387. #cv2.putText(feature_img2, "track_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  388. else:
  389. feature_img2 = np.zeros((435, 435))
  390. feature_img2 = feature_img2.astype(np.uint8)
  391. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  392. #cv2.putText(feature_img2, "track_features", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  393. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  394. if len(cost_matrix_track) != 0 and len(cost_matrix_track[0]) != 0:
  395. max_f = cost_matrix_track.max()
  396. min_f = cost_matrix_track.min()
  397. cost_matrix_track = np.round((cost_matrix_track - min_f) / (max_f - min_f) * 255)
  398. d_F_M = []
  399. cutpff_line = [40]*len(cost_matrix_track)*10
  400. for c_m in cost_matrix_track:
  401. add = []
  402. for row in range(len(c_m)):
  403. add += [255-c_m[row]]*10
  404. for row in range(10):
  405. d_F_M += [[40]+add+[40]]
  406. d_F_M = np.array(d_F_M)
  407. d_F_M = d_F_M.astype(np.uint8)
  408. cost_matrix_track_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  409. feature_img2 = cv2.resize(cost_matrix_track_img, (435, 435))
  410. #cv2.putText(feature_img2, "cost_matrix_track", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  411. else:
  412. feature_img2 = np.zeros((435, 435))
  413. feature_img2 = feature_img2.astype(np.uint8)
  414. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  415. #cv2.putText(feature_img2, "cost_matrix_track", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  416. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  417. if len(cost_matrix) != 0 and len(cost_matrix[0]) != 0:
  418. max_f = cost_matrix.max()
  419. min_f = cost_matrix.min()
  420. cost_matrix = np.round((cost_matrix - min_f) / (max_f - min_f) * 255)
  421. d_F_M = []
  422. cutpff_line = [40]*len(cost_matrix[0])*10
  423. for c_m in cost_matrix:
  424. add = []
  425. for row in range(len(c_m)):
  426. add += [255-c_m[row]]*10
  427. for row in range(10):
  428. d_F_M += [[40]+add+[40]]
  429. d_F_M = np.array(d_F_M)
  430. d_F_M = d_F_M.astype(np.uint8)
  431. cost_matrix_img = cv2.applyColorMap(d_F_M, cv2.COLORMAP_JET)
  432. feature_img2 = cv2.resize(cost_matrix_img, (435, 435))
  433. #cv2.putText(feature_img2, "cost_matrix", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  434. else:
  435. feature_img2 = np.zeros((435, 435))
  436. feature_img2 = feature_img2.astype(np.uint8)
  437. feature_img2 = cv2.applyColorMap(feature_img2, cv2.COLORMAP_JET)
  438. #cv2.putText(feature_img2, "cost_matrix", (5, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  439. feature_img = np.concatenate((feature_img, feature_img2), axis=1)
  440. dst_path = out_path + "/" + seq_num + "_" + num_zero[len(str(frame_id))-1] + str(frame_id) + '.png'
  441. cv2.imwrite(dst_path, feature_img)