# voc_eval.py (6.4 KB)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import os
  19. from ..data.source.voc import pascalvoc_label
  20. from .map_utils import DetectionMAP
  21. from .coco_eval import bbox2out
  22. import logging
  23. logger = logging.getLogger(__name__)
  24. __all__ = ['bbox_eval', 'bbox2out', 'get_category_info']
  25. def bbox_eval(results,
  26. class_num,
  27. overlap_thresh=0.5,
  28. map_type='11point',
  29. is_bbox_normalized=False,
  30. evaluate_difficult=False):
  31. """
  32. Bounding box evaluation for VOC dataset
  33. Args:
  34. results (list): prediction bounding box results.
  35. class_num (int): evaluation class number.
  36. overlap_thresh (float): the postive threshold of
  37. bbox overlap
  38. map_type (string): method for mAP calcualtion,
  39. can only be '11point' or 'integral'
  40. is_bbox_normalized (bool): whether bbox is normalized
  41. to range [0, 1].
  42. evaluate_difficult (bool): whether to evaluate
  43. difficult gt bbox.
  44. """
  45. assert 'bbox' in results[0]
  46. logger.info("Start evaluate...")
  47. detection_map = DetectionMAP(
  48. class_num=class_num,
  49. overlap_thresh=overlap_thresh,
  50. map_type=map_type,
  51. is_bbox_normalized=is_bbox_normalized,
  52. evaluate_difficult=evaluate_difficult)
  53. for t in results:
  54. bboxes = t['bbox'][0]
  55. bbox_lengths = t['bbox'][1][0]
  56. if bboxes.shape == (1, 1) or bboxes is None:
  57. continue
  58. gt_boxes = t['gt_bbox'][0]
  59. gt_labels = t['gt_class'][0]
  60. difficults = t['is_difficult'][0] if not evaluate_difficult \
  61. else None
  62. if len(t['gt_bbox'][1]) == 0:
  63. # gt_bbox, gt_class, difficult read as zero padded Tensor
  64. bbox_idx = 0
  65. for i in range(len(gt_boxes)):
  66. gt_box = gt_boxes[i]
  67. gt_label = gt_labels[i]
  68. difficult = None if difficults is None \
  69. else difficults[i]
  70. bbox_num = bbox_lengths[i]
  71. bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
  72. gt_box, gt_label, difficult = prune_zero_padding(
  73. gt_box, gt_label, difficult)
  74. detection_map.update(bbox, gt_box, gt_label, difficult)
  75. bbox_idx += bbox_num
  76. else:
  77. # gt_box, gt_label, difficult read as LoDTensor
  78. gt_box_lengths = t['gt_bbox'][1][0]
  79. bbox_idx = 0
  80. gt_box_idx = 0
  81. for i in range(len(bbox_lengths)):
  82. bbox_num = bbox_lengths[i]
  83. gt_box_num = gt_box_lengths[i]
  84. bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
  85. gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num]
  86. gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num]
  87. difficult = None if difficults is None else \
  88. difficults[gt_box_idx: gt_box_idx + gt_box_num]
  89. detection_map.update(bbox, gt_box, gt_label, difficult)
  90. bbox_idx += bbox_num
  91. gt_box_idx += gt_box_num
  92. logger.info("Accumulating evaluatation results...")
  93. detection_map.accumulate()
  94. map_stat = 100. * detection_map.get_map()
  95. logger.info("mAP({:.2f}, {}) = {:.2f}%".format(overlap_thresh, map_type,
  96. map_stat))
  97. return map_stat
  98. def prune_zero_padding(gt_box, gt_label, difficult=None):
  99. valid_cnt = 0
  100. for i in range(len(gt_box)):
  101. if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
  102. gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
  103. break
  104. valid_cnt += 1
  105. return (gt_box[:valid_cnt], gt_label[:valid_cnt], difficult[:valid_cnt]
  106. if difficult is not None else None)
  107. def get_category_info(anno_file=None,
  108. with_background=True,
  109. use_default_label=False):
  110. if use_default_label or anno_file is None \
  111. or not os.path.exists(anno_file):
  112. logger.info("Not found annotation file {}, load "
  113. "voc2012 categories.".format(anno_file))
  114. return vocall_category_info(with_background)
  115. else:
  116. logger.info("Load categories from {}".format(anno_file))
  117. return get_category_info_from_anno(anno_file, with_background)
  118. def get_category_info_from_anno(anno_file, with_background=True):
  119. """
  120. Get class id to category id map and category id
  121. to category name map from annotation file.
  122. Args:
  123. anno_file (str): annotation file path
  124. with_background (bool, default True):
  125. whether load background as class 0.
  126. """
  127. cats = []
  128. with open(anno_file) as f:
  129. for line in f.readlines():
  130. cats.append(line.strip())
  131. if cats[0] != 'background' and with_background:
  132. cats.insert(0, 'background')
  133. if cats[0] == 'background' and not with_background:
  134. cats = cats[1:]
  135. clsid2catid = {i: i for i in range(len(cats))}
  136. catid2name = {i: name for i, name in enumerate(cats)}
  137. return clsid2catid, catid2name
  138. def vocall_category_info(with_background=True):
  139. """
  140. Get class id to category id map and category id
  141. to category name map of mixup voc dataset
  142. Args:
  143. with_background (bool, default True):
  144. whether load background as class 0.
  145. """
  146. label_map = pascalvoc_label(with_background)
  147. label_map = sorted(label_map.items(), key=lambda x: x[1])
  148. cats = [l[0] for l in label_map]
  149. if with_background:
  150. cats.insert(0, 'background')
  151. clsid2catid = {i: i for i in range(len(cats))}
  152. catid2name = {i: name for i, name in enumerate(cats)}
  153. return clsid2catid, catid2name