mtmct.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import motmetrics as mm
from pptracking.python.mot.visualize import plot_tracking
import os
import re
import cv2
import gc
import numpy as np
from sklearn import preprocessing
from sklearn.cluster import AgglomerativeClustering
import pandas as pd
from tqdm import tqdm
from functools import reduce
import warnings

warnings.filterwarnings("ignore")
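
# High-level flow (see mtmct_process at the bottom of this file):
#   1. res2dict() condenses per-camera MOT results into a cid_tid_dict keyed
#      by "c{cid}_t{tid}" and attaches a mean ReID feature to each track.
#   2. sub_cluster() merges tracks across cameras by agglomerative clustering
#      of those features and assigns a global id to each cluster.
#   3. gen_restxt() writes the merged results to mtmct_result.txt.
#   4. get_mtmct_matching_results() and save_mtmct_vis_results() optionally
#      re-read that file and draw the matched tracks onto the input videos.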


def gen_restxt(output_dir_filename, map_tid, cid_tid_dict):
    # keys look like 'c{cid}_t{tid}'; '\d+' keeps multi-digit ids intact
    pattern = re.compile(r'c(\d+)_t(\d+)')
    f_w = open(output_dir_filename, 'w')
    for key, res in cid_tid_dict.items():
        cid, tid = pattern.search(key).groups()
        cid = int(cid) + 1
        rects = res["rects"]
        frames = res["frames"]
        for idx, bbox in enumerate(rects):
            bbox[0][3:] -= bbox[0][1:3]  # convert [x1, y1, x2, y2] to [x1, y1, w, h]
            fid = frames[idx] + 1
            rect = [max(int(x), 0) for x in bbox[0][1:]]
            if key in map_tid:
                new_tid = map_tid[key]
                f_w.write(
                    str(cid) + ' ' + str(new_tid) + ' ' + str(fid) + ' ' +
                    ' '.join(map(str, rect)) + '\n')
    print('gen_res: write file in {}'.format(output_dir_filename))
    f_w.close()
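
# Each line written above follows 'cid tid fid x1 y1 w h', e.g. (values are
# illustrative only): '1 2 10 100 200 50 120'.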


def get_mtmct_matching_results(pred_mtmct_file, secs_interval=0.5,
                               video_fps=20):
    res = np.loadtxt(pred_mtmct_file)  # 'cid, tid, fid, x1, y1, w, h, -1, -1'
    camera_ids = list(map(int, np.unique(res[:, 0])))

    res = res[:, :7]
    # each line in res: 'cid, tid, fid, x1, y1, w, h'

    camera_tids = []
    camera_results = dict()
    for c_id in camera_ids:
        camera_results[c_id] = res[res[:, 0] == c_id]
        tids = np.unique(camera_results[c_id][:, 1])
        tids = list(map(int, tids))
        camera_tids.append(tids)

    # select common tids throughout each video
    common_tids = reduce(np.intersect1d, camera_tids)

    # get mtmct matching results by cid_tid_fid_results[c_id][t_id][f_id]
    cid_tid_fid_results = dict()
    cid_tid_to_fids = dict()
    interval = int(secs_interval * video_fps)  # preferably less than 10
    for c_id in camera_ids:
        cid_tid_fid_results[c_id] = dict()
        cid_tid_to_fids[c_id] = dict()
        for t_id in common_tids:
            tid_mask = camera_results[c_id][:, 1] == t_id
            cid_tid_fid_results[c_id][t_id] = dict()

            camera_trackid_results = camera_results[c_id][tid_mask]
            fids = np.unique(camera_trackid_results[:, 2])
            fids = fids[fids % interval == 0]
            fids = list(map(int, fids))
            cid_tid_to_fids[c_id][t_id] = fids

            for f_id in fids:
                st_frame = f_id
                ed_frame = f_id + interval

                st_mask = camera_trackid_results[:, 2] >= st_frame
                ed_mask = camera_trackid_results[:, 2] < ed_frame
                frame_mask = np.logical_and(st_mask, ed_mask)
                cid_tid_fid_results[c_id][t_id][f_id] = camera_trackid_results[
                    frame_mask]

    return camera_results, cid_tid_fid_results


def save_mtmct_vis_results(camera_results, captures, output_dir):
    # camera_results: 'cid, tid, fid, x1, y1, w, h'
    camera_ids = list(camera_results.keys())

    import shutil
    save_dir = os.path.join(output_dir, 'mtmct_vis')
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.makedirs(save_dir)

    for idx, video_file in enumerate(captures):
        capture = cv2.VideoCapture(video_file)
        cid = camera_ids[idx]
        basename = os.path.basename(video_file)
        video_out_name = "vis_" + basename
        out_path = os.path.join(save_dir, video_out_name)
        print("Start visualizing output video: {}".format(out_path))

        # Get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        frame_id = 0
        while True:
            if frame_id % 50 == 0:
                print('frame id: ', frame_id)
            ret, frame = capture.read()
            frame_id += 1
            if not ret:
                if frame_id == 1:
                    print("video read failed!")
                break
            frame_results = camera_results[cid][camera_results[cid][:, 2] ==
                                                frame_id]
            boxes = frame_results[:, -4:]
            ids = frame_results[:, 1]
            image = plot_tracking(frame, boxes, ids, frame_id=frame_id, fps=fps)
            writer.write(image)
        writer.release()


def get_euclidean(x, y, **kwargs):
    m = x.shape[0]
    n = y.shape[0]
    distmat = (np.power(x, 2).sum(axis=1, keepdims=True).repeat(
        n, axis=1) + np.power(y, 2).sum(axis=1, keepdims=True).repeat(
            m, axis=1).T)
    distmat -= np.dot(2 * x, y.T)
    return distmat


def cosine_similarity(x, y, eps=1e-12):
    """
    Computes cosine similarity between two sets of features.
    Value == 1 means the same vector
    Value == 0 means perpendicular vectors
    """
    x_n, y_n = np.linalg.norm(
        x, axis=1, keepdims=True), np.linalg.norm(
            y, axis=1, keepdims=True)
    x_norm = x / np.maximum(x_n, eps * np.ones_like(x_n))
    y_norm = y / np.maximum(y_n, eps * np.ones_like(y_n))
    sim_mt = np.dot(x_norm, y_norm.T)
    return sim_mt


def get_cosine(x, y, eps=1e-12):
    """
    Returns the cosine similarity matrix between two sets of features.
    Note: despite the name, no distance conversion happens here; similarity is
    turned into a distance downstream in get_sim_matrix_new via `1. - sim`.
    """
    sim_mt = cosine_similarity(x, y, eps)
    return sim_mt


def get_dist_mat(x, y, func_name="euclidean"):
    if func_name == "cosine":
        dist_mat = get_cosine(x, y)
    elif func_name == "euclidean":
        dist_mat = get_euclidean(x, y)
    else:
        raise ValueError("Unsupported distance function: {}".format(func_name))
    print("Using {} as distance function during evaluation".format(func_name))
    return dist_mat
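
# A quick sanity check for the distance helpers above (illustrative only, not
# used by the pipeline): for two orthogonal unit features,
#
#   x = np.eye(2, dtype=np.float32)
#   get_dist_mat(x, x, func_name="euclidean")
#   # -> squared L2 distances: 0 on the diagonal, 2 off-diagonal
#   get_dist_mat(x, x, func_name="cosine")
#   # -> cosine similarities: 1 on the diagonal, 0 off-diagonal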


def intracam_ignore(st_mask, cid_tids):
    # cid_tids entries are keys like 'c{cid}_t{tid}', so index 1 is the camera
    # id character (assuming single-digit camera ids); pairs of tracks from
    # the same camera are masked out.
    count = len(cid_tids)
    for i in range(count):
        for j in range(count):
            if cid_tids[i][1] == cid_tids[j][1]:
                st_mask[i, j] = 0.
    return st_mask


def get_sim_matrix_new(cid_tid_dict, cid_tids):
    # Note: camera-independent get_sim_matrix function,
    # which is different from the one in camera_utils.py.
    count = len(cid_tids)

    q_arr = np.array(
        [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
    g_arr = np.array(
        [cid_tid_dict[cid_tids[i]]['mean_feat'] for i in range(count)])
    # compute distmat
    distmat = get_dist_mat(q_arr, g_arr, func_name="cosine")

    # mask the elements which belong to the same video
    st_mask = np.ones((count, count), dtype=np.float32)
    st_mask = intracam_ignore(st_mask, cid_tids)

    sim_matrix = distmat * st_mask
    np.fill_diagonal(sim_matrix, 0.)
    # returned cost matrix: 1 - cosine similarity, with intra-camera pairs and
    # the diagonal forced to the maximum cost of 1.
    return 1. - sim_matrix


def get_match(cluster_labels):
    cluster_dict = dict()
    cluster = list()
    for i, l in enumerate(cluster_labels):
        if l in cluster_dict:
            cluster_dict[l].append(i)
        else:
            cluster_dict[l] = [i]
    for idx in cluster_dict:
        cluster.append(cluster_dict[idx])
    return cluster


def get_cid_tid(cluster_labels, cid_tids):
    cluster = list()
    for labels in cluster_labels:
        cid_tid_list = list()
        for label in labels:
            cid_tid_list.append(cid_tids[label])
        cluster.append(cid_tid_list)
    return cluster


def get_labels(cid_tid_dict, cid_tids):
    # compute cost matrix between features
    cost_matrix = get_sim_matrix_new(cid_tid_dict, cid_tids)

    # cluster all the features
    # (note: newer scikit-learn releases use `metric='precomputed'` in place
    # of the deprecated `affinity` argument)
    cluster1 = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=0.5,
        affinity='precomputed',
        linkage='complete')
    cluster_labels1 = cluster1.fit_predict(cost_matrix)
    labels = get_match(cluster_labels1)
    return labels


def sub_cluster(cid_tid_dict):
    '''
    cid_tid_dict: all camera_id and track_id
    '''
    # get all keys
    cid_tids = sorted([key for key in cid_tid_dict.keys()])

    # cluster all the track ids
    clu = get_labels(cid_tid_dict, cid_tids)

    # relabel every cluster group
    new_clu = list()
    for c_list in clu:
        new_clu.append([cid_tids[c] for c in c_list])

    cid_tid_label = dict()
    for i, c_list in enumerate(new_clu):
        for c in c_list:
            cid_tid_label[c] = i + 1
    return cid_tid_label
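
# For illustration (made-up values): sub_cluster() returns a mapping from
# per-camera keys to a shared global id, e.g.
#   {'c0_t1': 1, 'c1_t4': 1, 'c0_t2': 2, 'c1_t2': 2}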


def distill_idfeat(mot_res):
    qualities_list = mot_res["qualities"]
    feature_list = mot_res["features"]
    rects = mot_res["rects"]

    qualities_new = []
    feature_new = []
    # filter out rects smaller than roughly 100*20 pixels
    for idx, rect in enumerate(rects):
        conf, xmin, ymin, xmax, ymax = rect[0]
        if (xmax - xmin) * (ymax - ymin) > 2000:
            qualities_new.append(qualities_list[idx])
            feature_new.append(feature_list[idx])

    # take all features if fewer than 2 rects pass the filter
    if len(qualities_new) < 2:
        qualities_new = qualities_list
        feature_new = feature_list

    # if more than 20 samples are available, keep every other one
    skipf = 1
    if len(qualities_new) > 20:
        skipf = 2
    quality_skip = np.array(qualities_new[::skipf])
    feature_skip = np.array(feature_new[::skipf])

    # sort features by image quality and keep the most trustworthy ones
    # (those with quality above 0.6, when more than one exists)
    topk_argq = np.argsort(quality_skip)[::-1]
    if (quality_skip > 0.6).sum() > 1:
        topk_feat = feature_skip[topk_argq[:int((quality_skip > 0.6).sum())]]
    else:
        topk_feat = feature_skip[topk_argq]

    # get the final feature by averaging at most the top five
    mean_feat = np.mean(topk_feat[:5], axis=0)
    return mean_feat


def res2dict(multi_res):
    cid_tid_dict = {}
    for cid, c_res in enumerate(multi_res):
        for tid, res in c_res.items():
            key = "c" + str(cid) + "_t" + str(tid)
            if key not in cid_tid_dict:
                if len(res["rects"]) < 10:
                    continue
                cid_tid_dict[key] = res
                cid_tid_dict[key]['mean_feat'] = distill_idfeat(res)
    return cid_tid_dict


def mtmct_process(multi_res, captures, mtmct_vis=True, output_dir="output"):
    cid_tid_dict = res2dict(multi_res)
    map_tid = sub_cluster(cid_tid_dict)

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    pred_mtmct_file = os.path.join(output_dir, 'mtmct_result.txt')
    gen_restxt(pred_mtmct_file, map_tid, cid_tid_dict)

    if mtmct_vis:
        camera_results, cid_tid_fid_res = get_mtmct_matching_results(
            pred_mtmct_file)
        save_mtmct_vis_results(camera_results, captures, output_dir=output_dir)
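
# A minimal usage sketch (assumptions only): `multi_res` is expected to be a
# list with one dict per camera, mapping each track id to a result dict that
# carries "rects", "frames", "features" and "qualities" from the upstream
# per-camera MOT + ReID stage; `captures` is the matching list of video paths.
# `run_per_camera_mot` below is a hypothetical helper, not part of this module.
#
#   video_files = ["cam0.mp4", "cam1.mp4"]
#   multi_res = run_per_camera_mot(video_files)   # hypothetical upstream step
#   mtmct_process(multi_res, video_files, mtmct_vis=True, output_dir="output")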