keypoint_metrics.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
from collections import defaultdict, OrderedDict

import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from scipy.io import loadmat, savemat

from ..modeling.keypoint_utils import oks_nms
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']

class KeyPointTopDownCOCOEval(object):
    """refer to
        https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
        Copyright (c) Microsoft, under the MIT License.
    """

    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 iou_type='keypoints',
                 in_vis_thre=0.2,
                 oks_thre=0.9,
                 save_prediction_only=False):
        super(KeyPointTopDownCOCOEval, self).__init__()
        self.coco = COCO(anno_file)
        self.num_samples = num_samples
        self.num_joints = num_joints
        self.iou_type = iou_type
        self.in_vis_thre = in_vis_thre
        self.oks_thre = oks_thre
        self.output_eval = output_eval
        self.res_file = os.path.join(output_eval, "keypoints_results.json")
        self.save_prediction_only = save_prediction_only
        self.reset()

    def reset(self):
        self.results = {
            'all_preds': np.zeros(
                (self.num_samples, self.num_joints, 3), dtype=np.float32),
            'all_boxes': np.zeros((self.num_samples, 6)),
            'image_path': []
        }
        self.eval_results = {}
        self.idx = 0

    def update(self, inputs, outputs):
        kpts, _ = outputs['keypoint'][0]

        num_images = inputs['image'].shape[0]
        # all_boxes columns: [center_x, center_y, scale_w, scale_h, area, score]
        self.results['all_preds'][self.idx:self.idx + num_images, :,
                                  0:3] = kpts[:, :, 0:3]
        self.results['all_boxes'][self.idx:self.idx + num_images,
                                  0:2] = inputs['center'].numpy()[:, 0:2]
        self.results['all_boxes'][self.idx:self.idx + num_images,
                                  2:4] = inputs['scale'].numpy()[:, 0:2]
        self.results['all_boxes'][self.idx:self.idx + num_images, 4] = np.prod(
            inputs['scale'].numpy() * 200, 1)
        self.results['all_boxes'][self.idx:self.idx + num_images,
                                  5] = np.squeeze(inputs['score'].numpy())
        self.results['image_path'].extend(inputs['im_id'].numpy())

        self.idx += num_images

    def _write_coco_keypoint_results(self, keypoints):
        data_pack = [{
            'cat_id': 1,
            'cls': 'person',
            'ann_type': 'keypoints',
            'keypoints': keypoints
        }]
        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
        if not os.path.exists(self.output_eval):
            os.makedirs(self.output_eval)
        with open(self.res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
            logger.info(f'The keypoint result is saved to {self.res_file}.')
        try:
            json.load(open(self.res_file))
        except Exception:
            content = []
            with open(self.res_file, 'r') as f:
                for line in f:
                    content.append(line)
            content[-1] = ']'
            with open(self.res_file, 'w') as f:
                for c in content:
                    f.write(c)

    def _coco_keypoint_results_one_category_kernel(self, data_pack):
        cat_id = data_pack['cat_id']
        keypoints = data_pack['keypoints']
        cat_results = []

        for img_kpts in keypoints:
            if len(img_kpts) == 0:
                continue

            _key_points = np.array(
                [img_kpts[k]['keypoints'] for k in range(len(img_kpts))])
            _key_points = _key_points.reshape(_key_points.shape[0], -1)

            result = [{
                'image_id': img_kpts[k]['image'],
                'category_id': cat_id,
                'keypoints': _key_points[k].tolist(),
                'score': img_kpts[k]['score'],
                'center': list(img_kpts[k]['center']),
                'scale': list(img_kpts[k]['scale'])
            } for k in range(len(img_kpts))]
            cat_results.extend(result)

        return cat_results
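
    # Illustrative shape of one record written to keypoints_results.json by the
    # kernel above; the keypoints list is the K x 3 prediction array flattened
    # to [x1, y1, s1, ..., xK, yK, sK]. Values below are made-up placeholders:
    #
    #     {"image_id": 397133, "category_id": 1,
    #      "keypoints": [102.4, 87.1, 0.91, ...],
    #      "score": 0.87, "center": [128.0, 96.0], "scale": [1.2, 1.6]}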

    def get_final_results(self, preds, all_boxes, img_path):
        _kpts = []
        for idx, kpt in enumerate(preds):
            _kpts.append({
                'keypoints': kpt,
                'center': all_boxes[idx][0:2],
                'scale': all_boxes[idx][2:4],
                'area': all_boxes[idx][4],
                'score': all_boxes[idx][5],
                'image': int(img_path[idx])
            })
        # image x person x (keypoints)
        kpts = defaultdict(list)
        for kpt in _kpts:
            kpts[kpt['image']].append(kpt)

        # rescoring and oks nms
        num_joints = preds.shape[1]
        in_vis_thre = self.in_vis_thre
        oks_thre = self.oks_thre
        oks_nmsed_kpts = []
        for img in kpts.keys():
            img_kpts = kpts[img]
            for n_p in img_kpts:
                box_score = n_p['score']
                kpt_score = 0
                valid_num = 0
                for n_jt in range(0, num_joints):
                    t_s = n_p['keypoints'][n_jt][2]
                    if t_s > in_vis_thre:
                        kpt_score = kpt_score + t_s
                        valid_num = valid_num + 1
                if valid_num != 0:
                    kpt_score = kpt_score / valid_num
                # rescoring
                n_p['score'] = kpt_score * box_score
            keep = oks_nms([img_kpts[i] for i in range(len(img_kpts))],
                           oks_thre)
            if len(keep) == 0:
                oks_nmsed_kpts.append(img_kpts)
            else:
                oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])

        self._write_coco_keypoint_results(oks_nmsed_kpts)
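
    # Worked example of the rescoring rule above (illustrative numbers): with
    # in_vis_thre = 0.2, joint confidences [0.9, 0.8, 0.1] leave two valid
    # joints, so kpt_score = (0.9 + 0.8) / 2 = 0.85; with a detector box score
    # of 0.6, the rescored instance score is 0.85 * 0.6 = 0.51.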

    def accumulate(self):
        self.get_final_results(self.results['all_preds'],
                               self.results['all_boxes'],
                               self.results['image_path'])
        if self.save_prediction_only:
            logger.info(f'The keypoint result is saved to {self.res_file} '
                        'and the mAP is not evaluated.')
            return
        coco_dt = self.coco.loadRes(self.res_file)
        coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
        coco_eval.params.useSegm = None
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        keypoint_stats = []
        for ind in range(len(coco_eval.stats)):
            keypoint_stats.append(coco_eval.stats[ind])
        self.eval_results['keypoint'] = keypoint_stats

    def log(self):
        if self.save_prediction_only:
            return
        stats_names = [
            'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
            'AR .75', 'AR (M)', 'AR (L)'
        ]
        num_values = len(stats_names)
        print(' '.join(['| {}'.format(name) for name in stats_names]) + ' |')
        print('|---' * (num_values + 1) + '|')
        print(' '.join([
            '| {:.3f}'.format(value) for value in self.eval_results['keypoint']
        ]) + ' |')

    def get_results(self):
        return self.eval_results
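
# A minimal usage sketch for KeyPointTopDownCOCOEval, assuming a top-down
# keypoint model whose batched inputs/outputs follow the layout consumed by
# update() above. `val_dataset`, `val_loader` and `model` are hypothetical
# placeholders, not part of this module, and the annotation path is an
# assumption:
#
#     metric = KeyPointTopDownCOCOEval(
#         anno_file='annotations/person_keypoints_val2017.json',
#         num_samples=len(val_dataset),
#         num_joints=17,
#         output_eval='./output_eval')
#     for inputs in val_loader:
#         outputs = model(inputs)  # outputs['keypoint'][0] -> (kpts, scores)
#         metric.update(inputs, outputs)
#     metric.accumulate()  # writes keypoints_results.json and runs COCOeval
#     metric.log()         # prints the AP/AR table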

class KeyPointTopDownMPIIEval(object):
    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 oks_thre=0.9,
                 save_prediction_only=False):
        super(KeyPointTopDownMPIIEval, self).__init__()
        self.ann_file = anno_file
        self.res_file = os.path.join(output_eval, "keypoints_results.json")
        self.save_prediction_only = save_prediction_only
        self.reset()

    def reset(self):
        self.results = []
        self.eval_results = {}
        self.idx = 0

    def update(self, inputs, outputs):
        kpts, _ = outputs['keypoint'][0]

        num_images = inputs['image'].shape[0]
        results = {}
        results['preds'] = kpts[:, :, 0:3]
        # boxes columns: [center_x, center_y, scale_w, scale_h, area, score]
        results['boxes'] = np.zeros((num_images, 6))
        results['boxes'][:, 0:2] = inputs['center'].numpy()[:, 0:2]
        results['boxes'][:, 2:4] = inputs['scale'].numpy()[:, 0:2]
        results['boxes'][:, 4] = np.prod(inputs['scale'].numpy() * 200, 1)
        results['boxes'][:, 5] = np.squeeze(inputs['score'].numpy())
        results['image_path'] = inputs['image_file']

        self.results.append(results)

    def accumulate(self):
        self._mpii_keypoint_results_save()
        if self.save_prediction_only:
            logger.info(f'The keypoint result is saved to {self.res_file} '
                        'and the PCKh is not evaluated.')
            return

        self.eval_results = self.evaluate(self.results)

    def _mpii_keypoint_results_save(self):
        results = []
        for res in self.results:
            if len(res) == 0:
                continue
            result = [{
                'preds': res['preds'][k].tolist(),
                'boxes': res['boxes'][k].tolist(),
                'image_path': res['image_path'][k],
            } for k in range(len(res))]
            results.extend(result)
        with open(self.res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
            logger.info(f'The keypoint result is saved to {self.res_file}.')

    def log(self):
        if self.save_prediction_only:
            return
        for item, value in self.eval_results.items():
            print("{} : {}".format(item, value))

    def get_results(self):
        return self.eval_results

    def evaluate(self, outputs, savepath=None):
        """Evaluate PCKh for MPII dataset. refer to
        https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
        Copyright (c) Microsoft, under the MIT License.

        Args:
            outputs(list(preds, boxes)):

                * preds (np.ndarray[N,K,3]): The first two dimensions are
                  coordinates, score is the third dimension of the array.
                * boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],
                  scale[1], area, score]

        Returns:
            dict: PCKh for each joint
        """
        kpts = []
        for output in outputs:
            preds = output['preds']
            batch_size = preds.shape[0]
            for i in range(batch_size):
                kpts.append({'keypoints': preds[i]})
        preds = np.stack([kpt['keypoints'] for kpt in kpts])

        # convert 0-based index to 1-based index,
        # and get the first two dimensions.
        preds = preds[..., :2] + 1.0

        if savepath is not None:
            pred_file = os.path.join(savepath, 'pred.mat')
            savemat(pred_file, mdict={'preds': preds})

        SC_BIAS = 0.6
        threshold = 0.5

        gt_file = os.path.join(
            os.path.dirname(self.ann_file), 'mpii_gt_val.mat')
        gt_dict = loadmat(gt_file)
        dataset_joints = gt_dict['dataset_joints']
        jnt_missing = gt_dict['jnt_missing']
        pos_gt_src = gt_dict['pos_gt_src']
        headboxes_src = gt_dict['headboxes_src']

        pos_pred_src = np.transpose(preds, [1, 2, 0])

        head = np.where(dataset_joints == 'head')[1][0]
        lsho = np.where(dataset_joints == 'lsho')[1][0]
        lelb = np.where(dataset_joints == 'lelb')[1][0]
        lwri = np.where(dataset_joints == 'lwri')[1][0]
        lhip = np.where(dataset_joints == 'lhip')[1][0]
        lkne = np.where(dataset_joints == 'lkne')[1][0]
        lank = np.where(dataset_joints == 'lank')[1][0]

        rsho = np.where(dataset_joints == 'rsho')[1][0]
        relb = np.where(dataset_joints == 'relb')[1][0]
        rwri = np.where(dataset_joints == 'rwri')[1][0]
        rkne = np.where(dataset_joints == 'rkne')[1][0]
        rank = np.where(dataset_joints == 'rank')[1][0]
        rhip = np.where(dataset_joints == 'rhip')[1][0]

        jnt_visible = 1 - jnt_missing
        uv_error = pos_pred_src - pos_gt_src
        uv_err = np.linalg.norm(uv_error, axis=1)
        # normalize errors by the head box diagonal scaled by SC_BIAS
        headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
        headsizes = np.linalg.norm(headsizes, axis=0)
        headsizes *= SC_BIAS
        scale = headsizes * np.ones((len(uv_err), 1), dtype=np.float32)
        scaled_uv_err = uv_err / scale
        scaled_uv_err = scaled_uv_err * jnt_visible
        jnt_count = np.sum(jnt_visible, axis=1)
        less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
        PCKh = 100. * np.sum(less_than_threshold, axis=1) / jnt_count

        # PCK over a range of thresholds from 0 to 0.5
        rng = np.arange(0, 0.5 + 0.01, 0.01)
        pckAll = np.zeros((len(rng), 16), dtype=np.float32)

        for r, threshold in enumerate(rng):
            less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
            pckAll[r, :] = 100. * np.sum(less_than_threshold,
                                         axis=1) / jnt_count

        # exclude pelvis (6) and thorax (7) from the reported averages
        PCKh = np.ma.array(PCKh, mask=False)
        PCKh.mask[6:8] = True

        jnt_count = np.ma.array(jnt_count, mask=False)
        jnt_count.mask[6:8] = True
        jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)

        name_value = [  #noqa
            ('Head', PCKh[head]),
            ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
            ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
            ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
            ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
            ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
            ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
            ('PCKh', np.sum(PCKh * jnt_ratio)),
            ('PCKh@0.1', np.sum(pckAll[11, :] * jnt_ratio))
        ]
        name_value = OrderedDict(name_value)

        return name_value
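
    # Worked example of the PCKh rule above (illustrative numbers, not from the
    # dataset): with SC_BIAS = 0.6 and a head box diagonal of 50 px, the
    # normalizing head size is 0.6 * 50 = 30 px. A joint predicted 12 px away
    # from its ground truth has a scaled error of 12 / 30 = 0.4, so it counts
    # as correct at threshold 0.5 but not at threshold 0.1.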

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts
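
# A minimal usage sketch for KeyPointTopDownMPIIEval, assuming the ground-truth
# file `mpii_gt_val.mat` sits next to the annotation file as expected by
# evaluate(). `val_dataset`, `val_loader` and `model` are hypothetical
# placeholders, not part of this module, and the annotation path is an
# assumption:
#
#     metric = KeyPointTopDownMPIIEval(
#         anno_file='annotations/mpii_val.json',
#         num_samples=len(val_dataset),
#         num_joints=16,
#         output_eval='./output_eval')
#     for inputs in val_loader:
#         outputs = model(inputs)
#         metric.update(inputs, outputs)
#     metric.accumulate()  # saves predictions and computes PCKh
#     metric.log()         # prints PCKh per joint group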