# postprocess.py
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
  16. """
  17. import os
  18. import re
  19. import cv2
  20. from tqdm import tqdm
  21. import numpy as np
  22. import motmetrics as mm
  23. from functools import reduce
  24. from .utils import parse_pt_gt, parse_pt, compare_dataframes_mtmc
  25. from .utils import get_labels, getData, gen_new_mot
  26. from .camera_utils import get_labels_with_camera
  27. from .zone import Zone
  28. from ..visualize import plot_tracking
# Public API of the MTMCT post-processing module.
__all__ = [
    'trajectory_fusion',
    'sub_cluster',
    'gen_res',
    'print_mtmct_result',
    'get_mtmct_matching_results',
    'save_mtmct_crops',
    'save_mtmct_vis_results',
]
  38. def trajectory_fusion(mot_feature, cid, cid_bias, use_zone=False, zone_path=''):
  39. cur_bias = cid_bias[cid]
  40. mot_list_break = {}
  41. if use_zone:
  42. zones = Zone(zone_path=zone_path)
  43. zones.set_cam(cid)
  44. mot_list = parse_pt(mot_feature, zones)
  45. else:
  46. mot_list = parse_pt(mot_feature)
  47. if use_zone:
  48. mot_list = zones.break_mot(mot_list, cid)
  49. mot_list = zones.filter_mot(mot_list, cid) # filter by zone
  50. mot_list = zones.filter_bbox(mot_list, cid) # filter bbox
  51. mot_list_break = gen_new_mot(mot_list) # save break feature for gen result
  52. tid_data = dict()
  53. for tid in mot_list:
  54. tracklet = mot_list[tid]
  55. if len(tracklet) <= 1:
  56. continue
  57. frame_list = list(tracklet.keys())
  58. frame_list.sort()
  59. # filter area too large
  60. zone_list = [tracklet[f]['zone'] for f in frame_list]
  61. feature_list = [
  62. tracklet[f]['feat'] for f in frame_list
  63. if (tracklet[f]['bbox'][3] - tracklet[f]['bbox'][1]) *
  64. (tracklet[f]['bbox'][2] - tracklet[f]['bbox'][0]) > 2000
  65. ]
  66. if len(feature_list) < 2:
  67. feature_list = [tracklet[f]['feat'] for f in frame_list]
  68. io_time = [
  69. cur_bias + frame_list[0] / 10., cur_bias + frame_list[-1] / 10.
  70. ]
  71. all_feat = np.array([feat for feat in feature_list])
  72. mean_feat = np.mean(all_feat, axis=0)
  73. tid_data[tid] = {
  74. 'cam': cid,
  75. 'tid': tid,
  76. 'mean_feat': mean_feat,
  77. 'zone_list': zone_list,
  78. 'frame_list': frame_list,
  79. 'tracklet': tracklet,
  80. 'io_time': io_time
  81. }
  82. return tid_data, mot_list_break
  83. def sub_cluster(cid_tid_dict,
  84. scene_cluster,
  85. use_ff=True,
  86. use_rerank=True,
  87. use_camera=False,
  88. use_st_filter=False):
  89. '''
  90. cid_tid_dict: all camera_id and track_id
  91. scene_cluster: like [41, 42, 43, 44, 45, 46] in AIC21 MTMCT S06 test videos
  92. '''
  93. assert (len(scene_cluster) != 0), "Error: scene_cluster length equals 0"
  94. cid_tids = sorted(
  95. [key for key in cid_tid_dict.keys() if key[0] in scene_cluster])
  96. if use_camera:
  97. clu = get_labels_with_camera(
  98. cid_tid_dict,
  99. cid_tids,
  100. use_ff=use_ff,
  101. use_rerank=use_rerank,
  102. use_st_filter=use_st_filter)
  103. else:
  104. clu = get_labels(
  105. cid_tid_dict,
  106. cid_tids,
  107. use_ff=use_ff,
  108. use_rerank=use_rerank,
  109. use_st_filter=use_st_filter)
  110. new_clu = list()
  111. for c_list in clu:
  112. if len(c_list) <= 1: continue
  113. cam_list = [cid_tids[c][0] for c in c_list]
  114. if len(cam_list) != len(set(cam_list)): continue
  115. new_clu.append([cid_tids[c] for c in c_list])
  116. all_clu = new_clu
  117. cid_tid_label = dict()
  118. for i, c_list in enumerate(all_clu):
  119. for c in c_list:
  120. cid_tid_label[c] = i + 1
  121. return cid_tid_label
  122. def gen_res(output_dir_filename,
  123. scene_cluster,
  124. map_tid,
  125. mot_list_breaks,
  126. use_roi=False,
  127. roi_dir=''):
  128. f_w = open(output_dir_filename, 'w')
  129. for idx, mot_feature in enumerate(mot_list_breaks):
  130. cid = scene_cluster[idx]
  131. img_rects = parse_pt_gt(mot_feature)
  132. if use_roi:
  133. assert (roi_dir != ''), "Error: roi_dir is not empty!"
  134. roi = cv2.imread(os.path.join(roi_dir, f'c{cid:03d}/roi.jpg'), 0)
  135. height, width = roi.shape
  136. for fid in img_rects:
  137. tid_rects = img_rects[fid]
  138. fid = int(fid) + 1
  139. for tid_rect in tid_rects:
  140. tid = tid_rect[0]
  141. rect = tid_rect[1:]
  142. cx = 0.5 * rect[0] + 0.5 * rect[2]
  143. cy = 0.5 * rect[1] + 0.5 * rect[3]
  144. w = rect[2] - rect[0]
  145. w = min(w * 1.2, w + 40)
  146. h = rect[3] - rect[1]
  147. h = min(h * 1.2, h + 40)
  148. rect[2] -= rect[0]
  149. rect[3] -= rect[1]
  150. rect[0] = max(0, rect[0])
  151. rect[1] = max(0, rect[1])
  152. x1, y1 = max(0, cx - 0.5 * w), max(0, cy - 0.5 * h)
  153. if use_roi:
  154. x2, y2 = min(width, cx + 0.5 * w), min(height, cy + 0.5 * h)
  155. else:
  156. x2, y2 = cx + 0.5 * w, cy + 0.5 * h
  157. w, h = x2 - x1, y2 - y1
  158. new_rect = list(map(int, [x1, y1, w, h]))
  159. rect = list(map(int, rect))
  160. if (cid, tid) in map_tid:
  161. new_tid = map_tid[(cid, tid)]
  162. f_w.write(
  163. str(cid) + ' ' + str(new_tid) + ' ' + str(fid) + ' ' +
  164. ' '.join(map(str, new_rect)) + ' -1 -1'
  165. '\n')
  166. print('gen_res: write file in {}'.format(output_dir_filename))
  167. f_w.close()
  168. def print_mtmct_result(gt_file, pred_file):
  169. names = [
  170. 'CameraId', 'Id', 'FrameId', 'X', 'Y', 'Width', 'Height', 'Xworld',
  171. 'Yworld'
  172. ]
  173. gt = getData(gt_file, names=names)
  174. pred = getData(pred_file, names=names)
  175. summary = compare_dataframes_mtmc(gt, pred)
  176. print('MTMCT summary: ', summary.columns.tolist())
  177. formatters = {
  178. 'idf1': '{:2.2f}'.format,
  179. 'idp': '{:2.2f}'.format,
  180. 'idr': '{:2.2f}'.format,
  181. 'mota': '{:2.2f}'.format
  182. }
  183. summary = summary[['idf1', 'idp', 'idr', 'mota']]
  184. summary.loc[:, 'idp'] *= 100
  185. summary.loc[:, 'idr'] *= 100
  186. summary.loc[:, 'idf1'] *= 100
  187. summary.loc[:, 'mota'] *= 100
  188. print(
  189. mm.io.render_summary(
  190. summary,
  191. formatters=formatters,
  192. namemap=mm.io.motchallenge_metric_names))
  193. def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5,
  194. video_fps=20):
  195. res = np.loadtxt(pred_mtmct_file) # 'cid, tid, fid, x1, y1, w, h, -1, -1'
  196. camera_ids = list(map(int, np.unique(res[:, 0])))
  197. res = res[:, :7]
  198. # each line in res: 'cid, tid, fid, x1, y1, w, h'
  199. camera_tids = []
  200. camera_results = dict()
  201. for c_id in camera_ids:
  202. camera_results[c_id] = res[res[:, 0] == c_id]
  203. tids = np.unique(camera_results[c_id][:, 1])
  204. tids = list(map(int, tids))
  205. camera_tids.append(tids)
  206. # select common tids throughout each video
  207. common_tids = reduce(np.intersect1d, camera_tids)
  208. if len(common_tids) == 0:
  209. print(
  210. 'No common tracked ids in these videos, please check your MOT result or select new videos.'
  211. )
  212. return None, None
  213. # get mtmct matching results by cid_tid_fid_results[c_id][t_id][f_id]
  214. cid_tid_fid_results = dict()
  215. cid_tid_to_fids = dict()
  216. interval = int(secs_interval * video_fps) # preferably less than 10
  217. for c_id in camera_ids:
  218. cid_tid_fid_results[c_id] = dict()
  219. cid_tid_to_fids[c_id] = dict()
  220. for t_id in common_tids:
  221. tid_mask = camera_results[c_id][:, 1] == t_id
  222. cid_tid_fid_results[c_id][t_id] = dict()
  223. camera_trackid_results = camera_results[c_id][tid_mask]
  224. fids = np.unique(camera_trackid_results[:, 2])
  225. fids = fids[fids % interval == 0]
  226. fids = list(map(int, fids))
  227. cid_tid_to_fids[c_id][t_id] = fids
  228. for f_id in fids:
  229. st_frame = f_id
  230. ed_frame = f_id + interval
  231. st_mask = camera_trackid_results[:, 2] >= st_frame
  232. ed_mask = camera_trackid_results[:, 2] < ed_frame
  233. frame_mask = np.logical_and(st_mask, ed_mask)
  234. cid_tid_fid_results[c_id][t_id][f_id] = camera_trackid_results[
  235. frame_mask]
  236. return camera_results, cid_tid_fid_results
  237. def save_mtmct_crops(cid_tid_fid_res,
  238. images_dir,
  239. crops_dir,
  240. width=300,
  241. height=200):
  242. camera_ids = cid_tid_fid_res.keys()
  243. seqs_folder = os.listdir(images_dir)
  244. seqs = []
  245. for x in seqs_folder:
  246. if os.path.isdir(os.path.join(images_dir, x)):
  247. seqs.append(x)
  248. assert len(seqs) == len(camera_ids)
  249. seqs.sort()
  250. if not os.path.exists(crops_dir):
  251. os.makedirs(crops_dir)
  252. common_tids = list(cid_tid_fid_res[list(camera_ids)[0]].keys())
  253. # get crops by name 'tid_cid_fid.jpg
  254. for t_id in common_tids:
  255. for i, c_id in enumerate(camera_ids):
  256. infer_dir = os.path.join(images_dir, seqs[i])
  257. if os.path.exists(os.path.join(infer_dir, 'img1')):
  258. infer_dir = os.path.join(infer_dir, 'img1')
  259. all_images = os.listdir(infer_dir)
  260. all_images.sort()
  261. for f_id in cid_tid_fid_res[c_id][t_id].keys():
  262. frame_idx = f_id - 1 if f_id > 0 else 0
  263. im_path = os.path.join(infer_dir, all_images[frame_idx])
  264. im = cv2.imread(im_path) # (H, W, 3)
  265. # only select one track
  266. track = cid_tid_fid_res[c_id][t_id][f_id][0]
  267. cid, tid, fid, x1, y1, w, h = [int(v) for v in track]
  268. clip = im[y1:(y1 + h), x1:(x1 + w)]
  269. clip = cv2.resize(clip, (width, height))
  270. cv2.imwrite(
  271. os.path.join(crops_dir,
  272. 'tid{:06d}_cid{:06d}_fid{:06d}.jpg'.format(
  273. tid, cid, fid)), clip)
  274. print("Finish cropping image of tracked_id {} in camera: {}".format(
  275. t_id, c_id))
  276. def save_mtmct_vis_results(camera_results,
  277. images_dir,
  278. save_dir,
  279. save_videos=False):
  280. # camera_results: 'cid, tid, fid, x1, y1, w, h'
  281. camera_ids = camera_results.keys()
  282. seqs_folder = os.listdir(images_dir)
  283. seqs = []
  284. for x in seqs_folder:
  285. if os.path.isdir(os.path.join(images_dir, x)):
  286. seqs.append(x)
  287. assert len(seqs) == len(camera_ids)
  288. seqs.sort()
  289. if not os.path.exists(save_dir):
  290. os.makedirs(save_dir)
  291. for i, c_id in enumerate(camera_ids):
  292. print("Start visualization for camera {} of sequence {}.".format(
  293. c_id, seqs[i]))
  294. cid_save_dir = os.path.join(save_dir, '{}'.format(seqs[i]))
  295. if not os.path.exists(cid_save_dir):
  296. os.makedirs(cid_save_dir)
  297. infer_dir = os.path.join(images_dir, seqs[i])
  298. if os.path.exists(os.path.join(infer_dir, 'img1')):
  299. infer_dir = os.path.join(infer_dir, 'img1')
  300. all_images = os.listdir(infer_dir)
  301. all_images.sort()
  302. for f_id, im_path in enumerate(all_images):
  303. img = cv2.imread(os.path.join(infer_dir, im_path))
  304. tracks = camera_results[c_id][camera_results[c_id][:, 2] == f_id]
  305. if tracks.shape[0] > 0:
  306. tracked_ids = tracks[:, 1]
  307. xywhs = tracks[:, 3:]
  308. online_im = plot_tracking(
  309. img, xywhs, tracked_ids, scores=None, frame_id=f_id)
  310. else:
  311. online_im = img
  312. print('Frame {} of seq {} has no tracking results'.format(
  313. f_id, seqs[i]))
  314. cv2.imwrite(
  315. os.path.join(cid_save_dir, '{:05d}.jpg'.format(f_id)),
  316. online_im)
  317. if f_id % 40 == 0:
  318. print('Processing frame {}'.format(f_id))
  319. if save_videos:
  320. output_video_path = os.path.join(cid_save_dir, '..',
  321. '{}_mtmct_vis.mp4'.format(seqs[i]))
  322. cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
  323. cid_save_dir, output_video_path)
  324. os.system(cmd_str)
  325. print('Save camera {} video in {}.'.format(seqs[i],
  326. output_video_path))