123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import os
- import json
- from collections import defaultdict, OrderedDict
- import numpy as np
- from pycocotools.coco import COCO
- from pycocotools.cocoeval import COCOeval
- from ..modeling.keypoint_utils import oks_nms
- from scipy.io import loadmat, savemat
- from ppdet.utils.logger import setup_logger
# Public evaluators exported by this module.
__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']

# Module-level logger used by the evaluators below.
logger = setup_logger(__name__)
class KeyPointTopDownCOCOEval(object):
    """Evaluate top-down keypoint predictions against COCO annotations.

    Predictions are buffered batch by batch in ``update``. ``accumulate``
    rescores every instance (mean of its keypoint scores above
    ``in_vis_thre`` multiplied by its box score), applies OKS-NMS per image,
    writes the results to ``<output_eval>/keypoints_results.json`` and,
    unless ``save_prediction_only`` is set, scores that file with
    ``COCOeval``.

    refer to
    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
    Copyright (c) Microsoft, under the MIT License.
    """

    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 iou_type='keypoints',
                 in_vis_thre=0.2,
                 oks_thre=0.9,
                 save_prediction_only=False):
        """
        Args:
            anno_file (str): COCO annotation file, loaded with ``COCO``.
            num_samples (int): capacity of the prediction buffers, i.e. the
                maximum number of samples accepted across ``update`` calls.
            num_joints (int): number of keypoints per instance.
            output_eval (str): directory the result json is written to.
            iou_type (str): stored on the instance; ``accumulate`` always
                evaluates with ``'keypoints'``.
            in_vis_thre (float): keypoints scoring above this value
                contribute to the rescored instance score.
            oks_thre (float): OKS threshold passed to ``oks_nms``.
            save_prediction_only (bool): only write the result json and skip
                the mAP computation.
        """
        super(KeyPointTopDownCOCOEval, self).__init__()
        self.coco = COCO(anno_file)
        self.num_samples = num_samples
        self.num_joints = num_joints
        self.iou_type = iou_type
        self.in_vis_thre = in_vis_thre
        self.oks_thre = oks_thre
        self.output_eval = output_eval
        self.res_file = os.path.join(output_eval, "keypoints_results.json")
        self.save_prediction_only = save_prediction_only
        self.reset()

    def reset(self):
        """Clear all buffered predictions and previous evaluation results."""
        self.results = {
            'all_preds': np.zeros(
                (self.num_samples, self.num_joints, 3), dtype=np.float32),
            # [center_x, center_y, scale_w, scale_h, area, score] per sample
            'all_boxes': np.zeros((self.num_samples, 6)),
            'image_path': []
        }
        self.eval_results = {}
        # number of buffer rows filled so far
        self.idx = 0

    def update(self, inputs, outputs):
        """Copy one batch of predictions into the preallocated buffers.

        Args:
            inputs (dict): batch inputs. ``image`` is used only for its
                leading (batch) dimension; ``center``, ``scale``, ``score``
                and ``im_id`` are converted with ``.numpy()``.
            outputs (dict): model outputs; ``outputs['keypoint'][0]`` is
                unpacked as ``(kpts, _)`` and ``kpts[:, :, 0:3]`` is stored.
        """
        kpts, _ = outputs['keypoint'][0]
        num_images = inputs['image'].shape[0]
        start, end = self.idx, self.idx + num_images
        scale = inputs['scale'].numpy()
        boxes = self.results['all_boxes']

        self.results['all_preds'][start:end, :, 0:3] = kpts[:, :, 0:3]
        boxes[start:end, 0:2] = inputs['center'].numpy()[:, 0:2]
        boxes[start:end, 2:4] = scale[:, 0:2]
        # NOTE(review): the factor 200 presumably converts the normalized
        # scale to pixels (HRNet convention) -- confirm against the dataset.
        boxes[start:end, 4] = np.prod(scale * 200, 1)
        boxes[start:end, 5] = np.squeeze(inputs['score'].numpy())
        self.results['image_path'].extend(inputs['im_id'].numpy())
        self.idx = end

    def _write_coco_keypoint_results(self, keypoints):
        """Convert NMS-filtered keypoints to COCO results and dump them.

        Args:
            keypoints (list[list[dict]]): one list of instance dicts per
                image, as built by ``get_final_results``.
        """
        data_pack = {
            'cat_id': 1,
            'cls': 'person',
            'ann_type': 'keypoints',
            'keypoints': keypoints
        }
        results = self._coco_keypoint_results_one_category_kernel(data_pack)
        if not os.path.exists(self.output_eval):
            os.makedirs(self.output_eval)
        with open(self.res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
        logger.info(f'The keypoint result is saved to {self.res_file}.')

        # Legacy safeguard from the reference implementation: if the written
        # file does not parse as JSON, replace its last line with ']'.
        try:
            with open(self.res_file, 'r') as f:
                json.load(f)
        except ValueError:  # includes json.JSONDecodeError
            with open(self.res_file, 'r') as f:
                content = f.readlines()
            if content:
                content[-1] = ']'
                with open(self.res_file, 'w') as f:
                    f.writelines(content)

    def _coco_keypoint_results_one_category_kernel(self, data_pack):
        """Flatten per-image instance dicts into COCO result entries.

        Args:
            data_pack (dict): ``cat_id`` (int) and ``keypoints`` (a list with
                one list of instance dicts per image; empty lists are
                skipped).

        Returns:
            list[dict]: one entry per instance with ``image_id``,
            ``category_id``, flattened ``keypoints``, ``score``, ``center``
            and ``scale``.
        """
        cat_id = data_pack['cat_id']
        cat_results = []
        for img_kpts in data_pack['keypoints']:
            if len(img_kpts) == 0:
                continue
            _key_points = np.array([kpt['keypoints'] for kpt in img_kpts])
            # (num_instances, num_joints, 3) -> (num_instances, num_joints*3)
            _key_points = _key_points.reshape(_key_points.shape[0], -1)
            cat_results.extend({
                'image_id': kpt['image'],
                'category_id': cat_id,
                'keypoints': flat.tolist(),
                'score': kpt['score'],
                'center': list(kpt['center']),
                'scale': list(kpt['scale'])
            } for kpt, flat in zip(img_kpts, _key_points))
        return cat_results

    def get_final_results(self, preds, all_boxes, img_path):
        """Rescore instances, apply OKS-NMS per image and write the results.

        Args:
            preds (np.ndarray): (N, num_joints, 3) keypoints; channel 2 is
                the keypoint score.
            all_boxes (np.ndarray): (N, 6) rows of
                [center_x, center_y, scale_w, scale_h, area, score].
            img_path (sequence): N image ids, converted with ``int``.
        """
        # image id -> list of instance dicts, in input order
        kpts = defaultdict(list)
        for idx, kpt in enumerate(preds):
            box = all_boxes[idx]
            image = int(img_path[idx])
            kpts[image].append({
                'keypoints': kpt,
                'center': box[0:2],
                'scale': box[2:4],
                'area': box[4],
                'score': box[5],
                'image': image
            })

        num_joints = preds.shape[1]
        in_vis_thre = self.in_vis_thre
        oks_thre = self.oks_thre
        oks_nmsed_kpts = []
        for img_kpts in kpts.values():
            # rescoring: mean of the confident keypoint scores (0 when none
            # pass the threshold) times the box score
            for n_p in img_kpts:
                kpt_score = 0
                valid_num = 0
                for n_jt in range(num_joints):
                    t_s = n_p['keypoints'][n_jt][2]
                    if t_s > in_vis_thre:
                        kpt_score = kpt_score + t_s
                        valid_num = valid_num + 1
                if valid_num != 0:
                    kpt_score = kpt_score / valid_num
                n_p['score'] = kpt_score * n_p['score']

            keep = oks_nms(list(img_kpts), oks_thre)
            if len(keep) == 0:
                oks_nmsed_kpts.append(img_kpts)
            else:
                oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])

        self._write_coco_keypoint_results(oks_nmsed_kpts)

    def accumulate(self):
        """Write the buffered results and, if requested, compute COCO mAP."""
        # Only the first ``self.idx`` rows of the preallocated buffers were
        # filled by ``update``; ``image_path`` has exactly that many entries,
        # so indexing past it would raise IndexError.
        num = self.idx
        self.get_final_results(self.results['all_preds'][:num],
                               self.results['all_boxes'][:num],
                               self.results['image_path'][:num])
        if self.save_prediction_only:
            logger.info(f'The keypoint result is saved to {self.res_file} '
                        'and do not evaluate the mAP.')
            return
        coco_dt = self.coco.loadRes(self.res_file)
        coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
        coco_eval.params.useSegm = None
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        self.eval_results['keypoint'] = list(coco_eval.stats)

    def log(self):
        """Print the COCO keypoint stats as a markdown table."""
        if self.save_prediction_only:
            return
        stats_names = [
            'AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
            'AR .75', 'AR (M)', 'AR (L)'
        ]
        num_values = len(stats_names)
        print(' '.join(['| {}'.format(name) for name in stats_names]) + ' |')
        # exactly one delimiter cell per header cell; otherwise markdown
        # renderers do not recognise the table
        print('|---' * num_values + '|')
        print(' '.join([
            '| {:.3f}'.format(value) for value in self.eval_results['keypoint']
        ]) + ' |')

    def get_results(self):
        """Return the dict filled by ``accumulate`` (empty before it runs)."""
        return self.eval_results
class KeyPointTopDownMPIIEval(object):
    """Evaluate top-down keypoint predictions on MPII with the PCKh metric.

    Batches are stored as-is in ``update``; ``accumulate`` writes all
    predictions to ``<output_eval>/keypoints_results.json`` and, unless
    ``save_prediction_only`` is set, computes PCKh against
    ``mpii_gt_val.mat`` located next to the annotation file.
    """

    def __init__(self,
                 anno_file,
                 num_samples,
                 num_joints,
                 output_eval,
                 oks_thre=0.9,
                 save_prediction_only=False):
        """
        Args:
            anno_file (str): annotation file; its directory must contain
                ``mpii_gt_val.mat``.
            num_samples (int): not used by this evaluator.
            num_joints (int): not used by this evaluator.
            output_eval (str): directory the result json is written to.
            oks_thre (float): not used by this evaluator.
            save_prediction_only (bool): only write the result json and skip
                the PCKh computation.
        """
        super(KeyPointTopDownMPIIEval, self).__init__()
        self.ann_file = anno_file
        self.res_file = os.path.join(output_eval, "keypoints_results.json")
        self.save_prediction_only = save_prediction_only
        self.reset()

    def reset(self):
        """Clear all stored batches and previous evaluation results."""
        self.results = []
        self.eval_results = {}
        self.idx = 0

    def update(self, inputs, outputs):
        """Store one batch of predictions.

        Args:
            inputs (dict): batch inputs. ``image`` is used only for its
                leading (batch) dimension; ``center``, ``scale`` and
                ``score`` are converted with ``.numpy()``; ``image_file`` is
                stored unchanged.
            outputs (dict): model outputs; ``outputs['keypoint'][0]`` is
                unpacked as ``(kpts, _)`` and ``kpts[:, :, 0:3]`` is stored.
        """
        kpts, _ = outputs['keypoint'][0]
        num_images = inputs['image'].shape[0]
        scale = inputs['scale'].numpy()
        # [center_x, center_y, scale_w, scale_h, area, score] per sample
        boxes = np.zeros((num_images, 6))
        boxes[:, 0:2] = inputs['center'].numpy()[:, 0:2]
        boxes[:, 2:4] = scale[:, 0:2]
        boxes[:, 4] = np.prod(scale * 200, 1)
        boxes[:, 5] = np.squeeze(inputs['score'].numpy())
        self.results.append({
            'preds': kpts[:, :, 0:3],
            'boxes': boxes,
            'image_path': inputs['image_file'],
        })

    def accumulate(self):
        """Save the predictions and, if requested, compute PCKh."""
        self._mpii_keypoint_results_save()
        if self.save_prediction_only:
            logger.info(f'The keypoint result is saved to {self.res_file} '
                        'and do not evaluate the mAP.')
            return

        self.eval_results = self.evaluate(self.results)

    def _collect_results(self):
        """Flatten the stored batches into one JSON-ready dict per sample."""
        results = []
        for res in self.results:
            if len(res) == 0:
                continue
            # ``res`` is a dict of per-batch arrays: iterate over the samples
            # of the batch, not over the dict's keys.
            for k in range(len(res['preds'])):
                results.append({
                    'preds': res['preds'][k].tolist(),
                    'boxes': res['boxes'][k].tolist(),
                    'image_path': res['image_path'][k],
                })
        return results

    def _mpii_keypoint_results_save(self):
        """Write every stored prediction to ``self.res_file`` as JSON."""
        results = self._collect_results()
        out_dir = os.path.dirname(self.res_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.res_file, 'w') as f:
            json.dump(results, f, sort_keys=True, indent=4)
        logger.info(f'The keypoint result is saved to {self.res_file}.')

    def log(self):
        """Print each evaluation entry as ``name : value``."""
        if self.save_prediction_only:
            return
        for item, value in self.eval_results.items():
            print("{} : {}".format(item, value))

    def get_results(self):
        """Return the dict filled by ``accumulate`` (empty before it runs)."""
        return self.eval_results

    def evaluate(self, outputs, savepath=None):
        """Evaluate PCKh for MPII dataset. refer to
        https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
        Copyright (c) Microsoft, under the MIT License.

        Args:
            outputs (list[dict]): the batches stored by ``update``; each dict
                holds
                * preds (np.ndarray[N,K,3]): The first two dimensions are
                  coordinates, score is the third dimension of the array.
                * boxes (np.ndarray[N,6]): [center[0], center[1], scale[0],
                  scale[1], area, score]
            savepath (str, optional): if given, the 1-based predicted
                coordinates are also written to ``<savepath>/pred.mat``.

        Returns:
            OrderedDict: PCKh per joint group (left/right averaged), the
            weighted mean ``PCKh`` and the value reported as ``PCKh@0.1``.
        """
        preds = np.stack([
            output['preds'][i]
            for output in outputs
            for i in range(output['preds'].shape[0])
        ])

        # convert 0-based index to 1-based index,
        # and get the first two dimensions.
        preds = preds[..., :2] + 1.0

        if savepath is not None:
            pred_file = os.path.join(savepath, 'pred.mat')
            savemat(pred_file, mdict={'preds': preds})

        SC_BIAS = 0.6
        threshold = 0.5

        gt_file = os.path.join(
            os.path.dirname(self.ann_file), 'mpii_gt_val.mat')
        gt_dict = loadmat(gt_file)
        dataset_joints = gt_dict['dataset_joints']
        jnt_missing = gt_dict['jnt_missing']
        pos_gt_src = gt_dict['pos_gt_src']
        headboxes_src = gt_dict['headboxes_src']

        # (N, K, 2) -> (K, 2, N); presumably the layout of pos_gt_src
        pos_pred_src = np.transpose(preds, [1, 2, 0])

        def joint_index(name):
            # column index of ``name`` in the ground-truth joint table
            return np.where(dataset_joints == name)[1][0]

        head = joint_index('head')
        lsho = joint_index('lsho')
        lelb = joint_index('lelb')
        lwri = joint_index('lwri')
        lhip = joint_index('lhip')
        lkne = joint_index('lkne')
        lank = joint_index('lank')
        rsho = joint_index('rsho')
        relb = joint_index('relb')
        rwri = joint_index('rwri')
        rkne = joint_index('rkne')
        rank = joint_index('rank')
        rhip = joint_index('rhip')

        jnt_visible = 1 - jnt_missing
        uv_error = pos_pred_src - pos_gt_src
        uv_err = np.linalg.norm(uv_error, axis=1)
        # normalizer: SC_BIAS * ||headboxes_src[1] - headboxes_src[0]||
        headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
        headsizes = np.linalg.norm(headsizes, axis=0)
        headsizes *= SC_BIAS
        scale = headsizes * np.ones((len(uv_err), 1), dtype=np.float32)
        scaled_uv_err = uv_err / scale
        scaled_uv_err = scaled_uv_err * jnt_visible
        jnt_count = np.sum(jnt_visible, axis=1)
        less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
        PCKh = 100. * np.sum(less_than_threshold, axis=1) / jnt_count

        # PCKh curve over thresholds 0.00, 0.01, ..., 0.50 (one column per
        # joint instead of a hard-coded 16)
        rng = np.arange(0, 0.5 + 0.01, 0.01)
        pckAll = np.zeros(
            (len(rng), scaled_uv_err.shape[0]), dtype=np.float32)
        for r, thr in enumerate(rng):
            less_than_thr = (scaled_uv_err <= thr) * jnt_visible
            pckAll[r, :] = 100. * np.sum(less_than_thr, axis=1) / jnt_count

        # joints 6 and 7 are excluded from the averages below
        # (NOTE(review): pelvis/thorax in the usual MPII order -- confirm)
        PCKh = np.ma.array(PCKh, mask=False)
        PCKh.mask[6:8] = True

        jnt_count = np.ma.array(jnt_count, mask=False)
        jnt_count.mask[6:8] = True
        jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)

        # NOTE(review): pckAll[11] corresponds to rng[11] == 0.11, not 0.1;
        # kept as in the reference implementation so reported numbers match.
        name_value = [ #noqa
            ('Head', PCKh[head]),
            ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
            ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
            ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
            ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
            ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
            ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
            ('PCKh', np.sum(PCKh * jnt_ratio)),
            ('PCKh@0.1', np.sum(pckAll[11, :] * jnt_ratio))
        ]
        return OrderedDict(name_value)

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        # walk backwards so deletions do not shift indices still to visit
        for i in range(len(kpts) - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]
        return kpts
|