utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import time
  17. import numpy as np
  18. import collections
# Names exported by this module when it is star-imported.
__all__ = [
    'MOTTimer', 'Detection', 'write_mot_results', 'load_det_results',
    'preprocess_reid', 'get_crops', 'clip_box', 'scale_coords', 'flow_statistic'
]
  23. class MOTTimer(object):
  24. """
  25. This class used to compute and print the current FPS while evaling.
  26. """
  27. def __init__(self, window_size=20):
  28. self.start_time = 0.
  29. self.diff = 0.
  30. self.duration = 0.
  31. self.deque = collections.deque(maxlen=window_size)
  32. def tic(self):
  33. # using time.time instead of time.clock because time time.clock
  34. # does not normalize for multithreading
  35. self.start_time = time.time()
  36. def toc(self, average=True):
  37. self.diff = time.time() - self.start_time
  38. self.deque.append(self.diff)
  39. if average:
  40. self.duration = np.mean(self.deque)
  41. else:
  42. self.duration = np.sum(self.deque)
  43. return self.duration
  44. def clear(self):
  45. self.start_time = 0.
  46. self.diff = 0.
  47. self.duration = 0.
  48. class Detection(object):
  49. """
  50. This class represents a bounding box detection in a single image.
  51. Args:
  52. tlwh (Tensor): Bounding box in format `(top left x, top left y,
  53. width, height)`.
  54. score (Tensor): Bounding box confidence score.
  55. feature (Tensor): A feature vector that describes the object
  56. contained in this image.
  57. cls_id (Tensor): Bounding box category id.
  58. """
  59. def __init__(self, tlwh, score, feature, cls_id):
  60. self.tlwh = np.asarray(tlwh, dtype=np.float32)
  61. self.score = float(score)
  62. self.feature = np.asarray(feature, dtype=np.float32)
  63. self.cls_id = int(cls_id)
  64. def to_tlbr(self):
  65. """
  66. Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  67. `(top left, bottom right)`.
  68. """
  69. ret = self.tlwh.copy()
  70. ret[2:] += ret[:2]
  71. return ret
  72. def to_xyah(self):
  73. """
  74. Convert bounding box to format `(center x, center y, aspect ratio,
  75. height)`, where the aspect ratio is `width / height`.
  76. """
  77. ret = self.tlwh.copy()
  78. ret[:2] += ret[2:] / 2
  79. ret[2] /= ret[3]
  80. return ret
  81. def write_mot_results(filename, results, data_type='mot', num_classes=1):
  82. # support single and multi classes
  83. if data_type in ['mot', 'mcmot']:
  84. save_format = '{frame},{id},{x1},{y1},{w},{h},{score},{cls_id},-1,-1\n'
  85. elif data_type == 'kitti':
  86. save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
  87. else:
  88. raise ValueError(data_type)
  89. f = open(filename, 'w')
  90. for cls_id in range(num_classes):
  91. for frame_id, tlwhs, tscores, track_ids in results[cls_id]:
  92. if data_type == 'kitti':
  93. frame_id -= 1
  94. for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
  95. if track_id < 0: continue
  96. if data_type == 'mot':
  97. cls_id = -1
  98. x1, y1, w, h = tlwh
  99. x2, y2 = x1 + w, y1 + h
  100. line = save_format.format(
  101. frame=frame_id,
  102. id=track_id,
  103. x1=x1,
  104. y1=y1,
  105. x2=x2,
  106. y2=y2,
  107. w=w,
  108. h=h,
  109. score=score,
  110. cls_id=cls_id)
  111. f.write(line)
  112. print('MOT results save in {}'.format(filename))
  113. def load_det_results(det_file, num_frames):
  114. assert os.path.exists(det_file) and os.path.isfile(det_file), \
  115. '{} is not exist or not a file.'.format(det_file)
  116. labels = np.loadtxt(det_file, dtype='float32', delimiter=',')
  117. assert labels.shape[1] == 7, \
  118. "Each line of {} should have 7 items: '[frame_id],[x0],[y0],[w],[h],[score],[class_id]'.".format(det_file)
  119. results_list = []
  120. for frame_i in range(num_frames):
  121. results = {'bbox': [], 'score': [], 'cls_id': []}
  122. lables_with_frame = labels[labels[:, 0] == frame_i + 1]
  123. # each line of lables_with_frame:
  124. # [frame_id],[x0],[y0],[w],[h],[score],[class_id]
  125. for l in lables_with_frame:
  126. results['bbox'].append(l[1:5])
  127. results['score'].append(l[5:6])
  128. results['cls_id'].append(l[6:7])
  129. results_list.append(results)
  130. return results_list
  131. def scale_coords(coords, input_shape, im_shape, scale_factor):
  132. # Note: ratio has only one value, scale_factor[0] == scale_factor[1]
  133. #
  134. # This function only used for JDE YOLOv3 or other detectors with
  135. # LetterBoxResize and JDEBBoxPostProcess, coords output from detector had
  136. # not scaled back to the origin image.
  137. ratio = scale_factor[0]
  138. pad_w = (input_shape[1] - int(im_shape[1])) / 2
  139. pad_h = (input_shape[0] - int(im_shape[0])) / 2
  140. coords[:, 0::2] -= pad_w
  141. coords[:, 1::2] -= pad_h
  142. coords[:, 0:4] /= ratio
  143. coords[:, :4] = np.clip(coords[:, :4], a_min=0, a_max=coords[:, :4].max())
  144. return coords.round()
  145. def clip_box(xyxy, ori_image_shape):
  146. H, W = ori_image_shape
  147. xyxy[:, 0::2] = np.clip(xyxy[:, 0::2], a_min=0, a_max=W)
  148. xyxy[:, 1::2] = np.clip(xyxy[:, 1::2], a_min=0, a_max=H)
  149. w = xyxy[:, 2:3] - xyxy[:, 0:1]
  150. h = xyxy[:, 3:4] - xyxy[:, 1:2]
  151. mask = np.logical_and(h > 0, w > 0)
  152. keep_idx = np.nonzero(mask)
  153. return xyxy[keep_idx[0]], keep_idx
  154. def get_crops(xyxy, ori_img, w, h):
  155. crops = []
  156. xyxy = xyxy.astype(np.int64)
  157. ori_img = ori_img.transpose(1, 0, 2) # [h,w,3]->[w,h,3]
  158. for i, bbox in enumerate(xyxy):
  159. crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
  160. crops.append(crop)
  161. crops = preprocess_reid(crops, w, h)
  162. return crops
  163. def preprocess_reid(imgs,
  164. w=64,
  165. h=192,
  166. mean=[0.485, 0.456, 0.406],
  167. std=[0.229, 0.224, 0.225]):
  168. im_batch = []
  169. for img in imgs:
  170. img = cv2.resize(img, (w, h))
  171. img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
  172. img_mean = np.array(mean).reshape((3, 1, 1))
  173. img_std = np.array(std).reshape((3, 1, 1))
  174. img -= img_mean
  175. img /= img_std
  176. img = np.expand_dims(img, axis=0)
  177. im_batch.append(img)
  178. im_batch = np.concatenate(im_batch, 0)
  179. return im_batch
  180. def flow_statistic(result,
  181. secs_interval,
  182. do_entrance_counting,
  183. video_fps,
  184. entrance,
  185. id_set,
  186. interval_id_set,
  187. in_id_list,
  188. out_id_list,
  189. prev_center,
  190. records,
  191. data_type='mot',
  192. num_classes=1):
  193. # Count in and out number:
  194. # Use horizontal center line as the entrance just for simplification.
  195. # If a person located in the above the horizontal center line
  196. # at the previous frame and is in the below the line at the current frame,
  197. # the in number is increased by one.
  198. # If a person was in the below the horizontal center line
  199. # at the previous frame and locates in the below the line at the current frame,
  200. # the out number is increased by one.
  201. # TODO: if the entrance is not the horizontal center line,
  202. # the counting method should be optimized.
  203. if do_entrance_counting:
  204. entrance_y = entrance[1] # xmin, ymin, xmax, ymax
  205. frame_id, tlwhs, tscores, track_ids = result
  206. for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
  207. if track_id < 0: continue
  208. if data_type == 'kitti':
  209. frame_id -= 1
  210. x1, y1, w, h = tlwh
  211. center_x = x1 + w / 2.
  212. center_y = y1 + h / 2.
  213. if track_id in prev_center:
  214. if prev_center[track_id][1] <= entrance_y and \
  215. center_y > entrance_y:
  216. in_id_list.append(track_id)
  217. if prev_center[track_id][1] >= entrance_y and \
  218. center_y < entrance_y:
  219. out_id_list.append(track_id)
  220. prev_center[track_id][0] = center_x
  221. prev_center[track_id][1] = center_y
  222. else:
  223. prev_center[track_id] = [center_x, center_y]
  224. # Count totol number, number at a manual-setting interval
  225. frame_id, tlwhs, tscores, track_ids = result
  226. for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
  227. if track_id < 0: continue
  228. id_set.add(track_id)
  229. interval_id_set.add(track_id)
  230. # Reset counting at the interval beginning
  231. if frame_id % video_fps == 0 and frame_id / video_fps % secs_interval == 0:
  232. curr_interval_count = len(interval_id_set)
  233. interval_id_set.clear()
  234. info = "Frame id: {}, Total count: {}".format(frame_id, len(id_set))
  235. if do_entrance_counting:
  236. info += ", In count: {}, Out count: {}".format(
  237. len(in_id_list), len(out_id_list))
  238. if frame_id % video_fps == 0 and frame_id / video_fps % secs_interval == 0:
  239. info += ", Count during {} secs: {}".format(secs_interval,
  240. curr_interval_count)
  241. interval_id_set.clear()
  242. print(info)
  243. info += "\n"
  244. records.append(info)
  245. return {
  246. "id_set": id_set,
  247. "interval_id_set": interval_id_set,
  248. "in_id_list": in_id_list,
  249. "out_id_list": out_id_list,
  250. "prev_center": prev_center,
  251. "records": records
  252. }