zone.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
"""
This code is based on https://github.com/LCFractal/AIC21-MTMC/tree/main/reid/reid-matching/tools
Note: The following code is tightly coupled to the zone layout of the AIC21
test set S06, so it can only be used on S06 and cannot be used for other
MTMCT datasets.
"""
  19. import os
  20. import cv2
  21. import numpy as np
  22. from sklearn.cluster import AgglomerativeClustering
  23. BBOX_B = 10 / 15
  24. class Zone(object):
  25. def __init__(self, zone_path='datasets/zone'):
  26. # 0: b 1: g 3: r 123:w
  27. # w r not high speed
  28. # b g high speed
  29. assert zone_path != '', "Error: zone_path is not empty!"
  30. zones = {}
  31. for img_name in os.listdir(zone_path):
  32. camnum = int(img_name.split('.')[0][-3:])
  33. zone_img = cv2.imread(os.path.join(zone_path, img_name))
  34. zones[camnum] = zone_img
  35. self.zones = zones
  36. self.current_cam = 0
  37. def set_cam(self, cam):
  38. self.current_cam = cam
  39. def get_zone(self, bbox):
  40. cx = int((bbox[0] + bbox[2]) / 2)
  41. cy = int((bbox[1] + bbox[3]) / 2)
  42. pix = self.zones[self.current_cam][max(cy - 1, 0), max(cx - 1, 0), :]
  43. zone_num = 0
  44. if pix[0] > 50 and pix[1] > 50 and pix[2] > 50: # w
  45. zone_num = 1
  46. if pix[0] < 50 and pix[1] < 50 and pix[2] > 50: # r
  47. zone_num = 2
  48. if pix[0] < 50 and pix[1] > 50 and pix[2] < 50: # g
  49. zone_num = 3
  50. if pix[0] > 50 and pix[1] < 50 and pix[2] < 50: # b
  51. zone_num = 4
  52. return zone_num
  53. def is_ignore(self, zone_list, frame_list, cid):
  54. # 0 not in any corssroad, 1 white 2 red 3 green 4 bule
  55. zs, ze = zone_list[0], zone_list[-1]
  56. fs, fe = frame_list[0], frame_list[-1]
  57. if zs == ze:
  58. # if always on one section, excluding
  59. if ze in [1, 2]:
  60. return 2
  61. if zs != 0 and 0 in zone_list:
  62. return 0
  63. if fe - fs > 1500:
  64. return 2
  65. if fs < 2:
  66. if cid in [45]:
  67. if ze in [3, 4]:
  68. return 1
  69. else:
  70. return 2
  71. if fe > 1999:
  72. if cid in [41]:
  73. if ze not in [3]:
  74. return 2
  75. else:
  76. return 0
  77. if fs < 2 or fe > 1999:
  78. if ze in [3, 4]:
  79. return 0
  80. if ze in [3, 4]:
  81. return 1
  82. return 2
  83. else:
  84. # if camera section change
  85. if cid in [41, 42, 43, 44, 45, 46]:
  86. # come from road extension, exclusing
  87. if zs == 1 and ze == 2:
  88. return 2
  89. if zs == 2 and ze == 1:
  90. return 2
  91. if cid in [41]:
  92. # On 41 camera, no vehicle come into 42 camera
  93. if (zs in [1, 2]) and ze == 4:
  94. return 2
  95. if zs == 4 and (ze in [1, 2]):
  96. return 2
  97. if cid in [46]:
  98. # On 46 camera,no vehicle come into 45
  99. if (zs in [1, 2]) and ze == 3:
  100. return 2
  101. if zs == 3 and (ze in [1, 2]):
  102. return 2
  103. return 0
  104. def filter_mot(self, mot_list, cid):
  105. new_mot_list = dict()
  106. sub_mot_list = dict()
  107. for tracklet in mot_list:
  108. tracklet_dict = mot_list[tracklet]
  109. frame_list = list(tracklet_dict.keys())
  110. frame_list.sort()
  111. zone_list = []
  112. for f in frame_list:
  113. zone_list.append(tracklet_dict[f]['zone'])
  114. if self.is_ignore(zone_list, frame_list, cid) == 0:
  115. new_mot_list[tracklet] = tracklet_dict
  116. if self.is_ignore(zone_list, frame_list, cid) == 1:
  117. sub_mot_list[tracklet] = tracklet_dict
  118. return new_mot_list
  119. def filter_bbox(self, mot_list, cid):
  120. new_mot_list = dict()
  121. yh = self.zones[cid].shape[0]
  122. for tracklet in mot_list:
  123. tracklet_dict = mot_list[tracklet]
  124. frame_list = list(tracklet_dict.keys())
  125. frame_list.sort()
  126. bbox_list = []
  127. for f in frame_list:
  128. bbox_list.append(tracklet_dict[f]['bbox'])
  129. bbox_x = [b[0] for b in bbox_list]
  130. bbox_y = [b[1] for b in bbox_list]
  131. bbox_w = [b[2] - b[0] for b in bbox_list]
  132. bbox_h = [b[3] - b[1] for b in bbox_list]
  133. new_frame_list = list()
  134. if 0 in bbox_x or 0 in bbox_y:
  135. b0 = [
  136. i for i, f in enumerate(frame_list)
  137. if bbox_x[i] < 5 or bbox_y[i] + bbox_h[i] > yh - 5
  138. ]
  139. if len(b0) == len(frame_list):
  140. if cid in [41, 42, 44, 45, 46]:
  141. continue
  142. max_w = max(bbox_w)
  143. max_h = max(bbox_h)
  144. for i, f in enumerate(frame_list):
  145. if bbox_w[i] > max_w * BBOX_B and bbox_h[
  146. i] > max_h * BBOX_B:
  147. new_frame_list.append(f)
  148. else:
  149. l_i, r_i = 0, len(frame_list) - 1
  150. if len(b0) == 0:
  151. continue
  152. if b0[0] == 0:
  153. for i in range(len(b0) - 1):
  154. if b0[i] + 1 == b0[i + 1]:
  155. l_i = b0[i + 1]
  156. else:
  157. break
  158. if b0[-1] == len(frame_list) - 1:
  159. for i in range(len(b0) - 1):
  160. i = len(b0) - 1 - i
  161. if b0[i] - 1 == b0[i - 1]:
  162. r_i = b0[i - 1]
  163. else:
  164. break
  165. max_lw, max_lh = bbox_w[l_i], bbox_h[l_i]
  166. max_rw, max_rh = bbox_w[r_i], bbox_h[r_i]
  167. for i, f in enumerate(frame_list):
  168. if i < l_i:
  169. if bbox_w[i] > max_lw * BBOX_B and bbox_h[
  170. i] > max_lh * BBOX_B:
  171. new_frame_list.append(f)
  172. elif i > r_i:
  173. if bbox_w[i] > max_rw * BBOX_B and bbox_h[
  174. i] > max_rh * BBOX_B:
  175. new_frame_list.append(f)
  176. else:
  177. new_frame_list.append(f)
  178. new_tracklet_dict = dict()
  179. for f in new_frame_list:
  180. new_tracklet_dict[f] = tracklet_dict[f]
  181. new_mot_list[tracklet] = new_tracklet_dict
  182. else:
  183. new_mot_list[tracklet] = tracklet_dict
  184. return new_mot_list
  185. def break_mot(self, mot_list, cid):
  186. new_mot_list = dict()
  187. new_num_tracklets = max(mot_list) + 1
  188. for tracklet in mot_list:
  189. tracklet_dict = mot_list[tracklet]
  190. frame_list = list(tracklet_dict.keys())
  191. frame_list.sort()
  192. zone_list = []
  193. back_tracklet = False
  194. new_zone_f = 0
  195. pre_frame = frame_list[0]
  196. time_break = False
  197. for f in frame_list:
  198. if f - pre_frame > 100:
  199. if cid in [44, 45]:
  200. time_break = True
  201. break
  202. if not cid in [41, 44, 45, 46]:
  203. break
  204. pre_frame = f
  205. new_zone = tracklet_dict[f]['zone']
  206. if len(zone_list) > 0 and zone_list[-1] == new_zone:
  207. continue
  208. if new_zone_f > 1:
  209. if len(zone_list) > 1 and new_zone in zone_list:
  210. back_tracklet = True
  211. zone_list.append(new_zone)
  212. new_zone_f = 0
  213. else:
  214. new_zone_f += 1
  215. if back_tracklet:
  216. new_tracklet_dict = dict()
  217. pre_bbox = -1
  218. pre_arrow = 0
  219. have_break = False
  220. for f in frame_list:
  221. now_bbox = tracklet_dict[f]['bbox']
  222. if type(pre_bbox) == int:
  223. if pre_bbox == -1:
  224. pre_bbox = now_bbox
  225. now_arrow = now_bbox[0] - pre_bbox[0]
  226. if pre_arrow * now_arrow < 0 and len(
  227. new_tracklet_dict) > 15 and not have_break:
  228. new_mot_list[tracklet] = new_tracklet_dict
  229. new_tracklet_dict = dict()
  230. have_break = True
  231. if have_break:
  232. tracklet_dict[f]['id'] = new_num_tracklets
  233. new_tracklet_dict[f] = tracklet_dict[f]
  234. pre_bbox, pre_arrow = now_bbox, now_arrow
  235. if have_break:
  236. new_mot_list[new_num_tracklets] = new_tracklet_dict
  237. new_num_tracklets += 1
  238. else:
  239. new_mot_list[tracklet] = new_tracklet_dict
  240. elif time_break:
  241. new_tracklet_dict = dict()
  242. have_break = False
  243. pre_frame = frame_list[0]
  244. for f in frame_list:
  245. if f - pre_frame > 100:
  246. new_mot_list[tracklet] = new_tracklet_dict
  247. new_tracklet_dict = dict()
  248. have_break = True
  249. new_tracklet_dict[f] = tracklet_dict[f]
  250. pre_frame = f
  251. if have_break:
  252. new_mot_list[new_num_tracklets] = new_tracklet_dict
  253. new_num_tracklets += 1
  254. else:
  255. new_mot_list[tracklet] = new_tracklet_dict
  256. else:
  257. new_mot_list[tracklet] = tracklet_dict
  258. return new_mot_list
  259. def intra_matching(self, mot_list, sub_mot_list):
  260. sub_zone_dict = dict()
  261. new_mot_list = dict()
  262. new_mot_list, new_sub_mot_list = self.do_intra_matching2(mot_list,
  263. sub_mot_list)
  264. return new_mot_list
  265. def do_intra_matching2(self, mot_list, sub_list):
  266. new_zone_dict = dict()
  267. def get_trac_info(tracklet1):
  268. t1_f = list(tracklet1)
  269. t1_f.sort()
  270. t1_fs = t1_f[0]
  271. t1_fe = t1_f[-1]
  272. t1_zs = tracklet1[t1_fs]['zone']
  273. t1_ze = tracklet1[t1_fe]['zone']
  274. t1_boxs = tracklet1[t1_fs]['bbox']
  275. t1_boxe = tracklet1[t1_fe]['bbox']
  276. t1_boxs = [(t1_boxs[2] + t1_boxs[0]) / 2,
  277. (t1_boxs[3] + t1_boxs[1]) / 2]
  278. t1_boxe = [(t1_boxe[2] + t1_boxe[0]) / 2,
  279. (t1_boxe[3] + t1_boxe[1]) / 2]
  280. return t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe
  281. for t1id in sub_list:
  282. tracklet1 = sub_list[t1id]
  283. if tracklet1 == -1:
  284. continue
  285. t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe = get_trac_info(
  286. tracklet1)
  287. sim_dict = dict()
  288. for t2id in mot_list:
  289. tracklet2 = mot_list[t2id]
  290. t2_fs, t2_fe, t2_zs, t2_ze, t2_boxs, t2_boxe = get_trac_info(
  291. tracklet2)
  292. if t1_ze == t2_zs:
  293. if abs(t2_fs - t1_fe) < 5 and abs(t2_boxe[0] - t1_boxs[
  294. 0]) < 50 and abs(t2_boxe[1] - t1_boxs[1]) < 50:
  295. t1_feat = tracklet1[t1_fe]['feat']
  296. t2_feat = tracklet2[t2_fs]['feat']
  297. sim_dict[t2id] = np.matmul(t1_feat, t2_feat)
  298. if t1_zs == t2_ze:
  299. if abs(t2_fe - t1_fs) < 5 and abs(t2_boxs[0] - t1_boxe[
  300. 0]) < 50 and abs(t2_boxs[1] - t1_boxe[1]) < 50:
  301. t1_feat = tracklet1[t1_fs]['feat']
  302. t2_feat = tracklet2[t2_fe]['feat']
  303. sim_dict[t2id] = np.matmul(t1_feat, t2_feat)
  304. if len(sim_dict) > 0:
  305. max_sim = 0
  306. max_id = 0
  307. for t2id in sim_dict:
  308. if sim_dict[t2id] > max_sim:
  309. sim_dict[t2id] = max_sim
  310. max_id = t2id
  311. if max_sim > 0.5:
  312. t2 = mot_list[max_id]
  313. for t1f in tracklet1:
  314. if t1f not in t2:
  315. tracklet1[t1f]['id'] = max_id
  316. t2[t1f] = tracklet1[t1f]
  317. mot_list[max_id] = t2
  318. sub_list[t1id] = -1
  319. return mot_list, sub_list
  320. def do_intra_matching(self, sub_zone_dict, sub_zone):
  321. new_zone_dict = dict()
  322. id_list = list(sub_zone_dict)
  323. id2index = dict()
  324. for index, id in enumerate(id_list):
  325. id2index[id] = index
  326. def get_trac_info(tracklet1):
  327. t1_f = list(tracklet1)
  328. t1_f.sort()
  329. t1_fs = t1_f[0]
  330. t1_fe = t1_f[-1]
  331. t1_zs = tracklet1[t1_fs]['zone']
  332. t1_ze = tracklet1[t1_fe]['zone']
  333. t1_boxs = tracklet1[t1_fs]['bbox']
  334. t1_boxe = tracklet1[t1_fe]['bbox']
  335. t1_boxs = [(t1_boxs[2] + t1_boxs[0]) / 2,
  336. (t1_boxs[3] + t1_boxs[1]) / 2]
  337. t1_boxe = [(t1_boxe[2] + t1_boxe[0]) / 2,
  338. (t1_boxe[3] + t1_boxe[1]) / 2]
  339. return t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe
  340. sim_matrix = np.zeros([len(id_list), len(id_list)])
  341. for t1id in sub_zone_dict:
  342. tracklet1 = sub_zone_dict[t1id]
  343. t1_fs, t1_fe, t1_zs, t1_ze, t1_boxs, t1_boxe = get_trac_info(
  344. tracklet1)
  345. t1_feat = tracklet1[t1_fe]['feat']
  346. for t2id in sub_zone_dict:
  347. if t1id == t2id:
  348. continue
  349. tracklet2 = sub_zone_dict[t2id]
  350. t2_fs, t2_fe, t2_zs, t2_ze, t2_boxs, t2_boxe = get_trac_info(
  351. tracklet2)
  352. if t1_zs != t1_ze and t2_ze != t2_zs or t1_fe > t2_fs:
  353. continue
  354. if abs(t1_boxe[0] - t2_boxs[0]) > 50 or abs(t1_boxe[1] -
  355. t2_boxs[1]) > 50:
  356. continue
  357. if t2_fs - t1_fe > 5:
  358. continue
  359. t2_feat = tracklet2[t2_fs]['feat']
  360. sim_matrix[id2index[t1id], id2index[t2id]] = np.matmul(t1_feat,
  361. t2_feat)
  362. sim_matrix[id2index[t2id], id2index[t1id]] = np.matmul(t1_feat,
  363. t2_feat)
  364. sim_matrix = 1 - sim_matrix
  365. cluster_labels = AgglomerativeClustering(
  366. n_clusters=None,
  367. distance_threshold=0.7,
  368. affinity='precomputed',
  369. linkage='complete').fit_predict(sim_matrix)
  370. new_zone_dict = dict()
  371. label2id = dict()
  372. for index, label in enumerate(cluster_labels):
  373. tracklet = sub_zone_dict[id_list[index]]
  374. if label not in label2id:
  375. new_id = tracklet[list(tracklet)[0]]
  376. new_tracklet = dict()
  377. else:
  378. new_id = label2id[label]
  379. new_tracklet = new_zone_dict[label2id[label]]
  380. for tf in tracklet:
  381. tracklet[tf]['id'] = new_id
  382. new_tracklet[tf] = tracklet[tf]
  383. new_zone_dict[label] = new_tracklet
  384. return new_zone_dict