metrics.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
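"""Evaluation metrics for ppdet.

This module collects the metric classes exported in ``__all__`` below:
COCO-style evaluation (``COCOMetric``, ``SNIPERCOCOMetric``), VOC-style mAP
(``VOCMetric``), WiderFace evaluation (``WiderFaceMetric``) and rotated-box
evaluation (``RBoxMetric``).
"""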
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import json
import paddle
import numpy as np
import typing
from pathlib import Path

from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run
from ppdet.data.source.category import get_categories

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
    'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results',
    'RBoxMetric', 'SNIPERCOCOMetric'
]
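
# Per-keypoint OKS standard deviations: COCO_SIGMAS covers the 17 COCO
# keypoints and CROWD_SIGMAS the 14 CrowdPose keypoints (values below are the
# published constants divided by 10).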
COCO_SIGMAS = np.array([
    .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87,
    .89, .89
]) / 10.0
CROWD_SIGMAS = np.array(
    [.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, .79,
     .79]) / 10.0


class Metric(paddle.metric.Metric):
    def name(self):
        return self.__class__.__name__

    def reset(self):
        pass

    def accumulate(self):
        pass

    # paddle.metric.Metric defines :meth:`update`, :meth:`accumulate` and
    # :meth:`reset`; in ppdet we also need the following 2 methods:

    # abstract method for logging metric results
    def log(self):
        pass

    # abstract method for getting metric results
    def get_results(self):
        pass
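
# Expected call order for a Metric subclass (a sketch based on the classes
# below): reset() once, update(inputs, outputs) per batch, then accumulate(),
# and finally log() / get_results().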


class COCOMetric(Metric):
    def __init__(self, anno_file, **kwargs):
        self.anno_file = anno_file
        self.clsid2catid = kwargs.get('clsid2catid', None)
        if self.clsid2catid is None:
            self.clsid2catid, _ = get_categories('COCO', anno_file)
        self.classwise = kwargs.get('classwise', False)
        self.output_eval = kwargs.get('output_eval', None)
        # TODO: bias should be unified
        self.bias = kwargs.get('bias', 0)
        self.save_prediction_only = kwargs.get('save_prediction_only', False)
        self.iou_type = kwargs.get('IouType', 'bbox')

        if not self.save_prediction_only:
            assert os.path.isfile(anno_file), \
                "anno_file {} not a file".format(anno_file)

        if self.output_eval is not None:
            Path(self.output_eval).mkdir(exist_ok=True)

        self.reset()

    def reset(self):
        # only bbox and mask evaluation are supported currently
        self.results = {'bbox': [], 'mask': [], 'segm': [], 'keypoint': []}
        self.eval_results = {}

    def update(self, inputs, outputs):
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v

        # multi-scale inputs: all inputs have the same im_id
        if isinstance(inputs, typing.Sequence):
            im_id = inputs[0]['im_id']
        else:
            im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        infer_results = get_infer_results(
            outs, self.clsid2catid, bias=self.bias)
        self.results['bbox'] += infer_results[
            'bbox'] if 'bbox' in infer_results else []
        self.results['mask'] += infer_results[
            'mask'] if 'mask' in infer_results else []
        self.results['segm'] += infer_results[
            'segm'] if 'segm' in infer_results else []
        self.results['keypoint'] += infer_results[
            'keypoint'] if 'keypoint' in infer_results else []
    def accumulate(self):
        if len(self.results['bbox']) > 0:
            output = "bbox.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['bbox'], f)
                logger.info('The bbox result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The bbox result is saved to {} and the mAP is '
                            'not evaluated.'.format(output))
            else:
                bbox_stats = cocoapi_eval(
                    output,
                    'bbox',
                    anno_file=self.anno_file,
                    classwise=self.classwise)
                self.eval_results['bbox'] = bbox_stats
                sys.stdout.flush()

        if len(self.results['mask']) > 0:
            output = "mask.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['mask'], f)
                logger.info('The mask result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The mask result is saved to {} and the mAP is '
                            'not evaluated.'.format(output))
            else:
                seg_stats = cocoapi_eval(
                    output,
                    'segm',
                    anno_file=self.anno_file,
                    classwise=self.classwise)
                self.eval_results['mask'] = seg_stats
                sys.stdout.flush()
        if len(self.results['segm']) > 0:
            output = "segm.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['segm'], f)
                logger.info('The segm result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The segm result is saved to {} and the mAP is '
                            'not evaluated.'.format(output))
            else:
                seg_stats = cocoapi_eval(
                    output,
                    'segm',
                    anno_file=self.anno_file,
                    classwise=self.classwise)
                self.eval_results['segm'] = seg_stats
                sys.stdout.flush()
        if len(self.results['keypoint']) > 0:
            output = "keypoint.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['keypoint'], f)
                logger.info('The keypoint result is saved to {}.'.format(
                    output))

            if self.save_prediction_only:
                logger.info('The keypoint result is saved to {} and the mAP '
                            'is not evaluated.'.format(output))
            else:
                style = 'keypoints'
                use_area = True
                sigmas = COCO_SIGMAS
                if self.iou_type == 'keypoints_crowd':
                    style = 'keypoints_crowd'
                    use_area = False
                    sigmas = CROWD_SIGMAS
                keypoint_stats = cocoapi_eval(
                    output,
                    style,
                    anno_file=self.anno_file,
                    classwise=self.classwise,
                    sigmas=sigmas,
                    use_area=use_area)
                self.eval_results['keypoint'] = keypoint_stats
                sys.stdout.flush()

    def log(self):
        pass

    def get_results(self):
        return self.eval_results
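
# A minimal COCOMetric usage sketch (illustrative only; the annotation path and
# the loop below are placeholders, and `inputs`/`outputs` are the per-batch
# dicts produced by ppdet's evaluation pipeline):
#
#   metric = COCOMetric(anno_file='annotations/instances_val2017.json')
#   for inputs, outputs in eval_batches:   # hypothetical iterable of batches
#       metric.update(inputs, outputs)
#   metric.accumulate()                    # writes bbox.json and runs COCO eval
#   results = metric.get_results()         # e.g. {'bbox': [...], ...}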


class VOCMetric(Metric):
    def __init__(self,
                 label_list,
                 class_num=20,
                 overlap_thresh=0.5,
                 map_type='11point',
                 is_bbox_normalized=False,
                 evaluate_difficult=False,
                 classwise=False):
        assert os.path.isfile(label_list), \
            "label_list {} not a file".format(label_list)
        self.clsid2catid, self.catid2name = get_categories('VOC', label_list)

        self.overlap_thresh = overlap_thresh
        self.map_type = map_type
        self.evaluate_difficult = evaluate_difficult
        self.detection_map = DetectionMAP(
            class_num=class_num,
            overlap_thresh=overlap_thresh,
            map_type=map_type,
            is_bbox_normalized=is_bbox_normalized,
            evaluate_difficult=evaluate_difficult,
            catid2name=self.catid2name,
            classwise=classwise)

        self.reset()

    def reset(self):
        self.detection_map.reset()
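
    # `outputs['bbox']` is expected to be an (N, 6) array whose rows are
    # [class_id, score, followed by the four box coordinates], with
    # `outputs['bbox_num']` giving the number of boxes per image; this mirrors
    # how the columns are sliced below.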
    def update(self, inputs, outputs):
        bbox_np = outputs['bbox'].numpy() if isinstance(
            outputs['bbox'], paddle.Tensor) else outputs['bbox']
        bboxes = bbox_np[:, 2:]
        scores = bbox_np[:, 1]
        labels = bbox_np[:, 0]
        bbox_lengths = outputs['bbox_num'].numpy() if isinstance(
            outputs['bbox_num'], paddle.Tensor) else outputs['bbox_num']

        if bboxes is None or bboxes.shape == (1, 1):
            return
        gt_boxes = inputs['gt_bbox']
        gt_labels = inputs['gt_class']
        difficults = inputs['difficult'] if not self.evaluate_difficult \
            else None

        if 'scale_factor' in inputs:
            scale_factor = inputs['scale_factor'].numpy() if isinstance(
                inputs['scale_factor'],
                paddle.Tensor) else inputs['scale_factor']
        else:
            scale_factor = np.ones((gt_boxes.shape[0], 2)).astype('float32')

        bbox_idx = 0
        for i in range(len(gt_boxes)):
            gt_box = gt_boxes[i].numpy() if isinstance(
                gt_boxes[i], paddle.Tensor) else gt_boxes[i]
            h, w = scale_factor[i]
            gt_box = gt_box / np.array([w, h, w, h])
            gt_label = gt_labels[i].numpy() if isinstance(
                gt_labels[i], paddle.Tensor) else gt_labels[i]
            if difficults is not None:
                difficult = difficults[i].numpy() if isinstance(
                    difficults[i], paddle.Tensor) else difficults[i]
            else:
                difficult = None
            bbox_num = bbox_lengths[i]
            bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
            score = scores[bbox_idx:bbox_idx + bbox_num]
            label = labels[bbox_idx:bbox_idx + bbox_num]
            gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label,
                                                             difficult)
            self.detection_map.update(bbox, score, label, gt_box, gt_label,
                                      difficult)
            bbox_idx += bbox_num

    def accumulate(self):
        logger.info("Accumulating evaluation results...")
        self.detection_map.accumulate()

    def log(self):
        map_stat = 100. * self.detection_map.get_map()
        logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
                                                       self.map_type, map_stat))

    def get_results(self):
        return {'bbox': [self.detection_map.get_map()]}
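
# A minimal VOCMetric usage sketch (the label list path is only illustrative):
#
#   metric = VOCMetric(label_list='dataset/voc/label_list.txt', class_num=20)
#   for inputs, outputs in eval_batches:   # hypothetical iterable of batches
#       metric.update(inputs, outputs)
#   metric.accumulate()
#   metric.log()                           # logs mAP(0.50, 11point) = ...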


class WiderFaceMetric(Metric):
    def __init__(self, image_dir, anno_file, multi_scale=True):
        self.image_dir = image_dir
        self.anno_file = anno_file
        self.multi_scale = multi_scale
        self.clsid2catid, self.catid2name = get_categories('widerface')

    def update(self, model):
        face_eval_run(
            model,
            self.image_dir,
            self.anno_file,
            pred_dir='output/pred',
            eval_mode='widerface',
            multi_scale=self.multi_scale)


class RBoxMetric(Metric):
    def __init__(self, anno_file, **kwargs):
        assert os.path.isfile(anno_file), \
            "anno_file {} not a file".format(anno_file)
        self.anno_file = anno_file
        with open(self.anno_file) as f:
            self.gt_anno = json.load(f)
        cats = self.gt_anno['categories']
        self.clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
        self.catid2clsid = {cat['id']: i for i, cat in enumerate(cats)}
        self.catid2name = {cat['id']: cat['name'] for cat in cats}
        self.classwise = kwargs.get('classwise', False)
        self.output_eval = kwargs.get('output_eval', None)
        # TODO: bias should be unified
        self.bias = kwargs.get('bias', 0)
        self.save_prediction_only = kwargs.get('save_prediction_only', False)
        self.iou_type = kwargs.get('IouType', 'bbox')
        self.overlap_thresh = kwargs.get('overlap_thresh', 0.5)
        self.map_type = kwargs.get('map_type', '11point')
        self.evaluate_difficult = kwargs.get('evaluate_difficult', False)
        class_num = len(self.catid2name)
        self.detection_map = DetectionMAP(
            class_num=class_num,
            overlap_thresh=self.overlap_thresh,
            map_type=self.map_type,
            is_bbox_normalized=False,
            evaluate_difficult=self.evaluate_difficult,
            catid2name=self.catid2name,
            classwise=self.classwise)

        self.reset()

    def reset(self):
        self.result_bbox = []
        self.detection_map.reset()
    def update(self, inputs, outputs):
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v
        im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        infer_results = get_infer_results(
            outs, self.clsid2catid, bias=self.bias)
        self.result_bbox += infer_results[
            'bbox'] if 'bbox' in infer_results else []
        bbox = [b['bbox'] for b in self.result_bbox]
        score = [b['score'] for b in self.result_bbox]
        label = [b['category_id'] for b in self.result_bbox]
        label = [self.catid2clsid[e] for e in label]
        gt_box = [
            e['bbox'] for e in self.gt_anno['annotations']
            if e['image_id'] == outs['im_id']
        ]
        gt_label = [
            e['category_id'] for e in self.gt_anno['annotations']
            if e['image_id'] == outs['im_id']
        ]
        gt_label = [self.catid2clsid[e] for e in gt_label]
        self.detection_map.update(bbox, score, label, gt_box, gt_label)

    def accumulate(self):
        if len(self.result_bbox) > 0:
            output = "bbox.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.result_bbox, f)
                logger.info('The bbox result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The bbox result is saved to {} and the mAP is '
                            'not evaluated.'.format(output))
            else:
                logger.info("Accumulating evaluation results...")
                self.detection_map.accumulate()

    def log(self):
        map_stat = 100. * self.detection_map.get_map()
        logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
                                                       self.map_type, map_stat))

    def get_results(self):
        return {'bbox': [self.detection_map.get_map()]}
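
# Note: RBoxMetric saves predictions in COCO-format JSON (bbox.json) but scores
# them with the VOC-style DetectionMAP above, matching rotated-box predictions
# against the annotations loaded from `anno_file`.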


class SNIPERCOCOMetric(COCOMetric):
    def __init__(self, anno_file, **kwargs):
        super(SNIPERCOCOMetric, self).__init__(anno_file, **kwargs)
        self.dataset = kwargs["dataset"]
        self.chip_results = []

    def reset(self):
        # only bbox and mask evaluation are supported currently
        self.results = {'bbox': [], 'mask': [], 'segm': [], 'keypoint': []}
        self.eval_results = {}
        self.chip_results = []

    def update(self, inputs, outputs):
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v
        im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        self.chip_results.append(outs)

    def accumulate(self):
        results = self.dataset.anno_cropper.aggregate_chips_detections(
            self.chip_results)
        for outs in results:
            infer_results = get_infer_results(
                outs, self.clsid2catid, bias=self.bias)
            self.results['bbox'] += infer_results[
                'bbox'] if 'bbox' in infer_results else []

        super(SNIPERCOCOMetric, self).accumulate()
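
# SNIPERCOCOMetric buffers per-chip outputs in `chip_results` during update()
# and only converts them to COCO-format results in accumulate(), after the
# dataset's anno_cropper has aggregated chip-level detections back to full
# images; the rest of the evaluation is inherited from COCOMetric.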