keypoint_postprocess.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from scipy.optimize import linear_sum_assignment
  15. from collections import abc, defaultdict
  16. import cv2
  17. import numpy as np
  18. import math
  19. import paddle
  20. import paddle.nn as nn
  21. from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform
  22. class HrHRNetPostProcess(object):
  23. """
  24. HrHRNet postprocess contain:
  25. 1) get topk keypoints in the output heatmap
  26. 2) sample the tagmap's value corresponding to each of the topk coordinate
  27. 3) match different joints to combine to some people with Hungary algorithm
  28. 4) adjust the coordinate by +-0.25 to decrease error std
  29. 5) salvage missing joints by check positivity of heatmap - tagdiff_norm
  30. Args:
  31. max_num_people (int): max number of people support in postprocess
  32. heat_thresh (float): value of topk below this threshhold will be ignored
  33. tag_thresh (float): coord's value sampled in tagmap below this threshold belong to same people for init
  34. inputs(list[heatmap]): the output list of model, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk
  35. original_height, original_width (float): the original image size
  36. """
  37. def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.):
  38. self.max_num_people = max_num_people
  39. self.heat_thresh = heat_thresh
  40. self.tag_thresh = tag_thresh
  41. def lerp(self, j, y, x, heatmap):
  42. H, W = heatmap.shape[-2:]
  43. left = np.clip(x - 1, 0, W - 1)
  44. right = np.clip(x + 1, 0, W - 1)
  45. up = np.clip(y - 1, 0, H - 1)
  46. down = np.clip(y + 1, 0, H - 1)
  47. offset_y = np.where(heatmap[j, down, x] > heatmap[j, up, x], 0.25,
  48. -0.25)
  49. offset_x = np.where(heatmap[j, y, right] > heatmap[j, y, left], 0.25,
  50. -0.25)
  51. return offset_y + 0.5, offset_x + 0.5
  52. def __call__(self, heatmap, tagmap, heat_k, inds_k, original_height,
  53. original_width):
  54. N, J, H, W = heatmap.shape
  55. assert N == 1, "only support batch size 1"
  56. heatmap = heatmap[0]
  57. tagmap = tagmap[0]
  58. heats = heat_k[0]
  59. inds_np = inds_k[0]
  60. y = inds_np // W
  61. x = inds_np % W
  62. tags = tagmap[np.arange(J)[None, :].repeat(self.max_num_people),
  63. y.flatten(), x.flatten()].reshape(J, -1, tagmap.shape[-1])
  64. coords = np.stack((y, x), axis=2)
  65. # threshold
  66. mask = heats > self.heat_thresh
  67. # cluster
  68. cluster = defaultdict(lambda: {
  69. 'coords': np.zeros((J, 2), dtype=np.float32),
  70. 'scores': np.zeros(J, dtype=np.float32),
  71. 'tags': []
  72. })
  73. for jid, m in enumerate(mask):
  74. num_valid = m.sum()
  75. if num_valid == 0:
  76. continue
  77. valid_inds = np.where(m)[0]
  78. valid_tags = tags[jid, m, :]
  79. if len(cluster) == 0: # initialize
  80. for i in valid_inds:
  81. tag = tags[jid, i]
  82. key = tag[0]
  83. cluster[key]['tags'].append(tag)
  84. cluster[key]['scores'][jid] = heats[jid, i]
  85. cluster[key]['coords'][jid] = coords[jid, i]
  86. continue
  87. candidates = list(cluster.keys())[:self.max_num_people]
  88. centroids = [
  89. np.mean(
  90. cluster[k]['tags'], axis=0) for k in candidates
  91. ]
  92. num_clusters = len(centroids)
  93. # shape is (num_valid, num_clusters, tag_dim)
  94. dist = valid_tags[:, None, :] - np.array(centroids)[None, ...]
  95. l2_dist = np.linalg.norm(dist, ord=2, axis=2)
  96. # modulate dist with heat value, see `use_detection_val`
  97. cost = np.round(l2_dist) * 100 - heats[jid, m, None]
  98. # pad the cost matrix, otherwise new pose are ignored
  99. if num_valid > num_clusters:
  100. cost = np.pad(cost, ((0, 0), (0, num_valid - num_clusters)),
  101. 'constant',
  102. constant_values=((0, 0), (0, 1e-10)))
  103. rows, cols = linear_sum_assignment(cost)
  104. for y, x in zip(rows, cols):
  105. tag = tags[jid, y]
  106. if y < num_valid and x < num_clusters and \
  107. l2_dist[y, x] < self.tag_thresh:
  108. key = candidates[x] # merge to cluster
  109. else:
  110. key = tag[0] # initialize new cluster
  111. cluster[key]['tags'].append(tag)
  112. cluster[key]['scores'][jid] = heats[jid, y]
  113. cluster[key]['coords'][jid] = coords[jid, y]
  114. # shape is [k, J, 2] and [k, J]
  115. pose_tags = np.array([cluster[k]['tags'] for k in cluster])
  116. pose_coords = np.array([cluster[k]['coords'] for k in cluster])
  117. pose_scores = np.array([cluster[k]['scores'] for k in cluster])
  118. valid = pose_scores > 0
  119. pose_kpts = np.zeros((pose_scores.shape[0], J, 3), dtype=np.float32)
  120. if valid.sum() == 0:
  121. return pose_kpts, pose_kpts
  122. # refine coords
  123. valid_coords = pose_coords[valid].astype(np.int32)
  124. y = valid_coords[..., 0].flatten()
  125. x = valid_coords[..., 1].flatten()
  126. _, j = np.nonzero(valid)
  127. offsets = self.lerp(j, y, x, heatmap)
  128. pose_coords[valid, 0] += offsets[0]
  129. pose_coords[valid, 1] += offsets[1]
  130. # mean score before salvage
  131. mean_score = pose_scores.mean(axis=1)
  132. pose_kpts[valid, 2] = pose_scores[valid]
  133. # salvage missing joints
  134. if True:
  135. for pid, coords in enumerate(pose_coords):
  136. tag_mean = np.array(pose_tags[pid]).mean(axis=0)
  137. norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5
  138. score = heatmap - np.round(norm) # (J, H, W)
  139. flat_score = score.reshape(J, -1)
  140. max_inds = np.argmax(flat_score, axis=1)
  141. max_scores = np.max(flat_score, axis=1)
  142. salvage_joints = (pose_scores[pid] == 0) & (max_scores > 0)
  143. if salvage_joints.sum() == 0:
  144. continue
  145. y = max_inds[salvage_joints] // W
  146. x = max_inds[salvage_joints] % W
  147. offsets = self.lerp(salvage_joints.nonzero()[0], y, x, heatmap)
  148. y = y.astype(np.float32) + offsets[0]
  149. x = x.astype(np.float32) + offsets[1]
  150. pose_coords[pid][salvage_joints, 0] = y
  151. pose_coords[pid][salvage_joints, 1] = x
  152. pose_kpts[pid][salvage_joints, 2] = max_scores[salvage_joints]
  153. pose_kpts[..., :2] = transpred(pose_coords[..., :2][..., ::-1],
  154. original_height, original_width,
  155. min(H, W))
  156. return pose_kpts, mean_score
  157. def transpred(kpts, h, w, s):
  158. trans, _ = get_affine_mat_kernel(h, w, s, inv=True)
  159. return warp_affine_joints(kpts[..., :2].copy(), trans)
  160. def warp_affine_joints(joints, mat):
  161. """Apply affine transformation defined by the transform matrix on the
  162. joints.
  163. Args:
  164. joints (np.ndarray[..., 2]): Origin coordinate of joints.
  165. mat (np.ndarray[3, 2]): The affine matrix.
  166. Returns:
  167. matrix (np.ndarray[..., 2]): Result coordinate of joints.
  168. """
  169. joints = np.array(joints)
  170. shape = joints.shape
  171. joints = joints.reshape(-1, 2)
  172. return np.dot(np.concatenate(
  173. (joints, joints[:, 0:1] * 0 + 1), axis=1),
  174. mat.T).reshape(shape)
  175. class HRNetPostProcess(object):
  176. def __init__(self, use_dark=True):
  177. self.use_dark = use_dark
  178. def flip_back(self, output_flipped, matched_parts):
  179. assert output_flipped.ndim == 4,\
  180. 'output_flipped should be [batch_size, num_joints, height, width]'
  181. output_flipped = output_flipped[:, :, :, ::-1]
  182. for pair in matched_parts:
  183. tmp = output_flipped[:, pair[0], :, :].copy()
  184. output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
  185. output_flipped[:, pair[1], :, :] = tmp
  186. return output_flipped
  187. def get_max_preds(self, heatmaps):
  188. """get predictions from score maps
  189. Args:
  190. heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
  191. Returns:
  192. preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
  193. maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
  194. """
  195. assert isinstance(heatmaps,
  196. np.ndarray), 'heatmaps should be numpy.ndarray'
  197. assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
  198. batch_size = heatmaps.shape[0]
  199. num_joints = heatmaps.shape[1]
  200. width = heatmaps.shape[3]
  201. heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
  202. idx = np.argmax(heatmaps_reshaped, 2)
  203. maxvals = np.amax(heatmaps_reshaped, 2)
  204. maxvals = maxvals.reshape((batch_size, num_joints, 1))
  205. idx = idx.reshape((batch_size, num_joints, 1))
  206. preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
  207. preds[:, :, 0] = (preds[:, :, 0]) % width
  208. preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
  209. pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
  210. pred_mask = pred_mask.astype(np.float32)
  211. preds *= pred_mask
  212. return preds, maxvals
  213. def gaussian_blur(self, heatmap, kernel):
  214. border = (kernel - 1) // 2
  215. batch_size = heatmap.shape[0]
  216. num_joints = heatmap.shape[1]
  217. height = heatmap.shape[2]
  218. width = heatmap.shape[3]
  219. for i in range(batch_size):
  220. for j in range(num_joints):
  221. origin_max = np.max(heatmap[i, j])
  222. dr = np.zeros((height + 2 * border, width + 2 * border))
  223. dr[border:-border, border:-border] = heatmap[i, j].copy()
  224. dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
  225. heatmap[i, j] = dr[border:-border, border:-border].copy()
  226. heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
  227. return heatmap
  228. def dark_parse(self, hm, coord):
  229. heatmap_height = hm.shape[0]
  230. heatmap_width = hm.shape[1]
  231. px = int(coord[0])
  232. py = int(coord[1])
  233. if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
  234. dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
  235. dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
  236. dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
  237. dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \
  238. + hm[py-1][px-1])
  239. dyy = 0.25 * (
  240. hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px])
  241. derivative = np.matrix([[dx], [dy]])
  242. hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
  243. if dxx * dyy - dxy**2 != 0:
  244. hessianinv = hessian.I
  245. offset = -hessianinv * derivative
  246. offset = np.squeeze(np.array(offset.T), axis=0)
  247. coord += offset
  248. return coord
  249. def dark_postprocess(self, hm, coords, kernelsize):
  250. """
  251. refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
  252. """
  253. hm = self.gaussian_blur(hm, kernelsize)
  254. hm = np.maximum(hm, 1e-10)
  255. hm = np.log(hm)
  256. for n in range(coords.shape[0]):
  257. for p in range(coords.shape[1]):
  258. coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
  259. return coords
  260. def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
  261. """the highest heatvalue location with a quarter offset in the
  262. direction from the highest response to the second highest response.
  263. Args:
  264. heatmaps (numpy.ndarray): The predicted heatmaps
  265. center (numpy.ndarray): The boxes center
  266. scale (numpy.ndarray): The scale factor
  267. Returns:
  268. preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
  269. maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
  270. """
  271. coords, maxvals = self.get_max_preds(heatmaps)
  272. heatmap_height = heatmaps.shape[2]
  273. heatmap_width = heatmaps.shape[3]
  274. if self.use_dark:
  275. coords = self.dark_postprocess(heatmaps, coords, kernelsize)
  276. else:
  277. for n in range(coords.shape[0]):
  278. for p in range(coords.shape[1]):
  279. hm = heatmaps[n][p]
  280. px = int(math.floor(coords[n][p][0] + 0.5))
  281. py = int(math.floor(coords[n][p][1] + 0.5))
  282. if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
  283. diff = np.array([
  284. hm[py][px + 1] - hm[py][px - 1],
  285. hm[py + 1][px] - hm[py - 1][px]
  286. ])
  287. coords[n][p] += np.sign(diff) * .25
  288. preds = coords.copy()
  289. # Transform back
  290. for i in range(coords.shape[0]):
  291. preds[i] = transform_preds(coords[i], center[i], scale[i],
  292. [heatmap_width, heatmap_height])
  293. return preds, maxvals
  294. def __call__(self, output, center, scale):
  295. preds, maxvals = self.get_final_preds(output, center, scale)
  296. return np.concatenate(
  297. (preds, maxvals), axis=-1), np.mean(
  298. maxvals, axis=1)
  299. def transform_preds(coords, center, scale, output_size):
  300. target_coords = np.zeros(coords.shape)
  301. trans = get_affine_transform(center, scale * 200, 0, output_size, inv=1)
  302. for p in range(coords.shape[0]):
  303. target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
  304. return target_coords
  305. def affine_transform(pt, t):
  306. new_pt = np.array([pt[0], pt[1], 1.]).T
  307. new_pt = np.dot(t, new_pt)
  308. return new_pt[:2]
  309. def translate_to_ori_images(keypoint_result, batch_records):
  310. kpts = keypoint_result['keypoint']
  311. scores = keypoint_result['score']
  312. kpts[..., 0] += batch_records[:, 0:1]
  313. kpts[..., 1] += batch_records[:, 1:2]
  314. return kpts, scores