keypoint_hrhrnet.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from scipy.optimize import linear_sum_assignment
  18. from collections import abc, defaultdict
  19. import numpy as np
  20. import paddle
  21. from ppdet.core.workspace import register, create, serializable
  22. from .meta_arch import BaseArch
  23. from .. import layers as L
  24. from ..keypoint_utils import transpred
  25. __all__ = ['HigherHRNet']
  26. @register
  27. class HigherHRNet(BaseArch):
  28. __category__ = 'architecture'
  29. def __init__(self,
  30. backbone='HRNet',
  31. hrhrnet_head='HrHRNetHead',
  32. post_process='HrHRNetPostProcess',
  33. eval_flip=True,
  34. flip_perm=None,
  35. max_num_people=30):
  36. """
  37. HigherHRNet network, see https://arxiv.org/abs/1908.10357;
  38. HigherHRNet+swahr, see https://arxiv.org/abs/2012.15175
  39. Args:
  40. backbone (nn.Layer): backbone instance
  41. hrhrnet_head (nn.Layer): keypoint_head instance
  42. bbox_post_process (object): `BBoxPostProcess` instance
  43. """
  44. super(HigherHRNet, self).__init__()
  45. self.backbone = backbone
  46. self.hrhrnet_head = hrhrnet_head
  47. self.post_process = post_process
  48. self.flip = eval_flip
  49. self.flip_perm = paddle.to_tensor(flip_perm)
  50. self.deploy = False
  51. self.interpolate = L.Upsample(2, mode='bilinear')
  52. self.pool = L.MaxPool(5, 1, 2)
  53. self.max_num_people = max_num_people
  54. @classmethod
  55. def from_config(cls, cfg, *args, **kwargs):
  56. # backbone
  57. backbone = create(cfg['backbone'])
  58. # head
  59. kwargs = {'input_shape': backbone.out_shape}
  60. hrhrnet_head = create(cfg['hrhrnet_head'], **kwargs)
  61. post_process = create(cfg['post_process'])
  62. return {
  63. 'backbone': backbone,
  64. "hrhrnet_head": hrhrnet_head,
  65. "post_process": post_process,
  66. }
  67. def _forward(self):
  68. if self.flip and not self.training and not self.deploy:
  69. self.inputs['image'] = paddle.concat(
  70. (self.inputs['image'], paddle.flip(self.inputs['image'], [3])))
  71. body_feats = self.backbone(self.inputs)
  72. if self.training:
  73. return self.hrhrnet_head(body_feats, self.inputs)
  74. else:
  75. outputs = self.hrhrnet_head(body_feats)
  76. if self.flip and not self.deploy:
  77. outputs = [paddle.split(o, 2) for o in outputs]
  78. output_rflip = [
  79. paddle.flip(paddle.gather(o[1], self.flip_perm, 1), [3])
  80. for o in outputs
  81. ]
  82. output1 = [o[0] for o in outputs]
  83. heatmap = (output1[0] + output_rflip[0]) / 2.
  84. tagmaps = [output1[1], output_rflip[1]]
  85. outputs = [heatmap] + tagmaps
  86. outputs = self.get_topk(outputs)
  87. if self.deploy:
  88. return outputs
  89. res_lst = []
  90. h = self.inputs['im_shape'][0, 0].numpy().item()
  91. w = self.inputs['im_shape'][0, 1].numpy().item()
  92. kpts, scores = self.post_process(*outputs, h, w)
  93. res_lst.append([kpts, scores])
  94. return res_lst
  95. def get_loss(self):
  96. return self._forward()
  97. def get_pred(self):
  98. outputs = {}
  99. res_lst = self._forward()
  100. outputs['keypoint'] = res_lst
  101. return outputs
  102. def get_topk(self, outputs):
  103. # resize to image size
  104. outputs = [self.interpolate(x) for x in outputs]
  105. if len(outputs) == 3:
  106. tagmap = paddle.concat(
  107. (outputs[1].unsqueeze(4), outputs[2].unsqueeze(4)), axis=4)
  108. else:
  109. tagmap = outputs[1].unsqueeze(4)
  110. heatmap = outputs[0]
  111. N, J = 1, self.hrhrnet_head.num_joints
  112. heatmap_maxpool = self.pool(heatmap)
  113. # topk
  114. maxmap = heatmap * (heatmap == heatmap_maxpool)
  115. maxmap = maxmap.reshape([N, J, -1])
  116. heat_k, inds_k = maxmap.topk(self.max_num_people, axis=2)
  117. outputs = [heatmap, tagmap, heat_k, inds_k]
  118. return outputs
  119. @register
  120. @serializable
  121. class HrHRNetPostProcess(object):
  122. '''
  123. HrHRNet postprocess contain:
  124. 1) get topk keypoints in the output heatmap
  125. 2) sample the tagmap's value corresponding to each of the topk coordinate
  126. 3) match different joints to combine to some people with Hungary algorithm
  127. 4) adjust the coordinate by +-0.25 to decrease error std
  128. 5) salvage missing joints by check positivity of heatmap - tagdiff_norm
  129. Args:
  130. max_num_people (int): max number of people support in postprocess
  131. heat_thresh (float): value of topk below this threshhold will be ignored
  132. tag_thresh (float): coord's value sampled in tagmap below this threshold belong to same people for init
  133. inputs(list[heatmap]): the output list of model, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk
  134. original_height, original_width (float): the original image size
  135. '''
  136. def __init__(self, max_num_people=30, heat_thresh=0.1, tag_thresh=1.):
  137. self.max_num_people = max_num_people
  138. self.heat_thresh = heat_thresh
  139. self.tag_thresh = tag_thresh
  140. def lerp(self, j, y, x, heatmap):
  141. H, W = heatmap.shape[-2:]
  142. left = np.clip(x - 1, 0, W - 1)
  143. right = np.clip(x + 1, 0, W - 1)
  144. up = np.clip(y - 1, 0, H - 1)
  145. down = np.clip(y + 1, 0, H - 1)
  146. offset_y = np.where(heatmap[j, down, x] > heatmap[j, up, x], 0.25,
  147. -0.25)
  148. offset_x = np.where(heatmap[j, y, right] > heatmap[j, y, left], 0.25,
  149. -0.25)
  150. return offset_y + 0.5, offset_x + 0.5
  151. def __call__(self, heatmap, tagmap, heat_k, inds_k, original_height,
  152. original_width):
  153. N, J, H, W = heatmap.shape
  154. assert N == 1, "only support batch size 1"
  155. heatmap = heatmap[0].cpu().detach().numpy()
  156. tagmap = tagmap[0].cpu().detach().numpy()
  157. heats = heat_k[0].cpu().detach().numpy()
  158. inds_np = inds_k[0].cpu().detach().numpy()
  159. y = inds_np // W
  160. x = inds_np % W
  161. tags = tagmap[np.arange(J)[None, :].repeat(self.max_num_people),
  162. y.flatten(), x.flatten()].reshape(J, -1, tagmap.shape[-1])
  163. coords = np.stack((y, x), axis=2)
  164. # threshold
  165. mask = heats > self.heat_thresh
  166. # cluster
  167. cluster = defaultdict(lambda: {
  168. 'coords': np.zeros((J, 2), dtype=np.float32),
  169. 'scores': np.zeros(J, dtype=np.float32),
  170. 'tags': []
  171. })
  172. for jid, m in enumerate(mask):
  173. num_valid = m.sum()
  174. if num_valid == 0:
  175. continue
  176. valid_inds = np.where(m)[0]
  177. valid_tags = tags[jid, m, :]
  178. if len(cluster) == 0: # initialize
  179. for i in valid_inds:
  180. tag = tags[jid, i]
  181. key = tag[0]
  182. cluster[key]['tags'].append(tag)
  183. cluster[key]['scores'][jid] = heats[jid, i]
  184. cluster[key]['coords'][jid] = coords[jid, i]
  185. continue
  186. candidates = list(cluster.keys())[:self.max_num_people]
  187. centroids = [
  188. np.mean(
  189. cluster[k]['tags'], axis=0) for k in candidates
  190. ]
  191. num_clusters = len(centroids)
  192. # shape is (num_valid, num_clusters, tag_dim)
  193. dist = valid_tags[:, None, :] - np.array(centroids)[None, ...]
  194. l2_dist = np.linalg.norm(dist, ord=2, axis=2)
  195. # modulate dist with heat value, see `use_detection_val`
  196. cost = np.round(l2_dist) * 100 - heats[jid, m, None]
  197. # pad the cost matrix, otherwise new pose are ignored
  198. if num_valid > num_clusters:
  199. cost = np.pad(cost, ((0, 0), (0, num_valid - num_clusters)),
  200. 'constant',
  201. constant_values=((0, 0), (0, 1e-10)))
  202. rows, cols = linear_sum_assignment(cost)
  203. for y, x in zip(rows, cols):
  204. tag = tags[jid, y]
  205. if y < num_valid and x < num_clusters and \
  206. l2_dist[y, x] < self.tag_thresh:
  207. key = candidates[x] # merge to cluster
  208. else:
  209. key = tag[0] # initialize new cluster
  210. cluster[key]['tags'].append(tag)
  211. cluster[key]['scores'][jid] = heats[jid, y]
  212. cluster[key]['coords'][jid] = coords[jid, y]
  213. # shape is [k, J, 2] and [k, J]
  214. pose_tags = np.array([cluster[k]['tags'] for k in cluster])
  215. pose_coords = np.array([cluster[k]['coords'] for k in cluster])
  216. pose_scores = np.array([cluster[k]['scores'] for k in cluster])
  217. valid = pose_scores > 0
  218. pose_kpts = np.zeros((pose_scores.shape[0], J, 3), dtype=np.float32)
  219. if valid.sum() == 0:
  220. return pose_kpts, pose_kpts
  221. # refine coords
  222. valid_coords = pose_coords[valid].astype(np.int32)
  223. y = valid_coords[..., 0].flatten()
  224. x = valid_coords[..., 1].flatten()
  225. _, j = np.nonzero(valid)
  226. offsets = self.lerp(j, y, x, heatmap)
  227. pose_coords[valid, 0] += offsets[0]
  228. pose_coords[valid, 1] += offsets[1]
  229. # mean score before salvage
  230. mean_score = pose_scores.mean(axis=1)
  231. pose_kpts[valid, 2] = pose_scores[valid]
  232. # salvage missing joints
  233. if True:
  234. for pid, coords in enumerate(pose_coords):
  235. tag_mean = np.array(pose_tags[pid]).mean(axis=0)
  236. norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5
  237. score = heatmap - np.round(norm) # (J, H, W)
  238. flat_score = score.reshape(J, -1)
  239. max_inds = np.argmax(flat_score, axis=1)
  240. max_scores = np.max(flat_score, axis=1)
  241. salvage_joints = (pose_scores[pid] == 0) & (max_scores > 0)
  242. if salvage_joints.sum() == 0:
  243. continue
  244. y = max_inds[salvage_joints] // W
  245. x = max_inds[salvage_joints] % W
  246. offsets = self.lerp(salvage_joints.nonzero()[0], y, x, heatmap)
  247. y = y.astype(np.float32) + offsets[0]
  248. x = x.astype(np.float32) + offsets[1]
  249. pose_coords[pid][salvage_joints, 0] = y
  250. pose_coords[pid][salvage_joints, 1] = x
  251. pose_kpts[pid][salvage_joints, 2] = max_scores[salvage_joints]
  252. pose_kpts[..., :2] = transpred(pose_coords[..., :2][..., ::-1],
  253. original_height, original_width,
  254. min(H, W))
  255. return pose_kpts, mean_score