# voc.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import numpy as np
  16. import xml.etree.ElementTree as ET
  17. from ppdet.core.workspace import register, serializable
  18. from .dataset import DataSet
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. @register
  22. @serializable
  23. class VOCDataSet(DataSet):
  24. """
  25. Load dataset with PascalVOC format.
  26. Notes:
  27. `anno_path` must contains xml file and image file path for annotations.
  28. Args:
  29. dataset_dir (str): root directory for dataset.
  30. image_dir (str): directory for images.
  31. anno_path (str): voc annotation file path.
  32. sample_num (int): number of samples to load, -1 means all.
  33. use_default_label (bool): whether use the default mapping of
  34. label to integer index. Default True.
  35. with_background (bool): whether load background as a class,
  36. default True.
  37. label_list (str): if use_default_label is False, will load
  38. mapping between category and class index.
  39. """
  40. def __init__(self,
  41. dataset_dir=None,
  42. image_dir=None,
  43. anno_path=None,
  44. sample_num=-1,
  45. use_default_label=False,
  46. with_background=True,
  47. label_list='label_list.txt'):
  48. super(VOCDataSet, self).__init__(
  49. image_dir=image_dir,
  50. anno_path=anno_path,
  51. sample_num=sample_num,
  52. dataset_dir=dataset_dir,
  53. with_background=with_background)
  54. # roidbs is list of dict whose structure is:
  55. # {
  56. # 'im_file': im_fname, # image file name
  57. # 'im_id': im_id, # image id
  58. # 'h': im_h, # height of image
  59. # 'w': im_w, # width
  60. # 'is_crowd': is_crowd,
  61. # 'gt_class': gt_class,
  62. # 'gt_score': gt_score,
  63. # 'gt_bbox': gt_bbox,
  64. # 'difficult': difficult
  65. # }
  66. self.roidbs = None
  67. # 'cname2id' is a dict to map category name to class id
  68. self.cname2cid = None
  69. self.use_default_label = use_default_label
  70. self.label_list = label_list
  71. def load_roidb_and_cname2cid(self):
  72. anno_path = os.path.join(self.dataset_dir, self.anno_path)
  73. image_dir = os.path.join(self.dataset_dir, self.image_dir)
  74. # mapping category name to class id
  75. # if with_background is True:
  76. # background:0, first_class:1, second_class:2, ...
  77. # if with_background is False:
  78. # first_class:0, second_class:1, ...
  79. records = []
  80. ct = 0
  81. cname2cid = {}
  82. if not self.use_default_label:
  83. label_path = os.path.join(self.dataset_dir, self.label_list)
  84. if not os.path.exists(label_path):
  85. raise ValueError("label_list {} does not exists".format(
  86. label_path))
  87. with open(label_path, 'r') as fr:
  88. label_id = int(self.with_background)
  89. for line in fr.readlines():
  90. cname2cid[line.strip()] = label_id
  91. label_id += 1
  92. else:
  93. cname2cid = pascalvoc_label(self.with_background)
  94. with open(anno_path, 'r') as fr:
  95. while True:
  96. line = fr.readline()
  97. if not line:
  98. break
  99. img_file, xml_file = [os.path.join(image_dir, x) \
  100. for x in line.strip().split()[:2]]
  101. if not os.path.exists(img_file):
  102. logger.warning(
  103. 'Illegal image file: {}, and it will be ignored'.format(
  104. img_file))
  105. continue
  106. if not os.path.isfile(xml_file):
  107. logger.warning(
  108. 'Illegal xml file: {}, and it will be ignored'.format(
  109. xml_file))
  110. continue
  111. tree = ET.parse(xml_file)
  112. if tree.find('id') is None:
  113. im_id = np.array([ct])
  114. else:
  115. im_id = np.array([int(tree.find('id').text)])
  116. objs = tree.findall('object')
  117. im_w = float(tree.find('size').find('width').text)
  118. im_h = float(tree.find('size').find('height').text)
  119. if im_w < 0 or im_h < 0:
  120. logger.warning(
  121. 'Illegal width: {} or height: {} in annotation, '
  122. 'and {} will be ignored'.format(im_w, im_h, xml_file))
  123. continue
  124. gt_bbox = []
  125. gt_class = []
  126. gt_score = []
  127. is_crowd = []
  128. difficult = []
  129. for i, obj in enumerate(objs):
  130. cname = obj.find('name').text
  131. _difficult = int(obj.find('difficult').text)
  132. x1 = float(obj.find('bndbox').find('xmin').text)
  133. y1 = float(obj.find('bndbox').find('ymin').text)
  134. x2 = float(obj.find('bndbox').find('xmax').text)
  135. y2 = float(obj.find('bndbox').find('ymax').text)
  136. x1 = max(0, x1)
  137. y1 = max(0, y1)
  138. x2 = min(im_w - 1, x2)
  139. y2 = min(im_h - 1, y2)
  140. if x2 > x1 and y2 > y1:
  141. gt_bbox.append([x1, y1, x2, y2])
  142. gt_class.append([cname2cid[cname]])
  143. gt_score.append([1.])
  144. is_crowd.append([0])
  145. difficult.append([_difficult])
  146. else:
  147. logger.warning(
  148. 'Found an invalid bbox in annotations: xml_file: {}'
  149. ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
  150. xml_file, x1, y1, x2, y2))
  151. gt_bbox = np.array(gt_bbox).astype('float32')
  152. gt_class = np.array(gt_class).astype('int32')
  153. gt_score = np.array(gt_score).astype('float32')
  154. is_crowd = np.array(is_crowd).astype('int32')
  155. difficult = np.array(difficult).astype('int32')
  156. voc_rec = {
  157. 'im_file': img_file,
  158. 'im_id': im_id,
  159. 'h': im_h,
  160. 'w': im_w,
  161. 'is_crowd': is_crowd,
  162. 'gt_class': gt_class,
  163. 'gt_score': gt_score,
  164. 'gt_bbox': gt_bbox,
  165. 'difficult': difficult
  166. }
  167. if len(objs) != 0:
  168. records.append(voc_rec)
  169. ct += 1
  170. if self.sample_num > 0 and ct >= self.sample_num:
  171. break
  172. assert len(records) > 0, 'not found any voc record in %s' % (
  173. self.anno_path)
  174. logger.debug('{} samples in file {}'.format(ct, anno_path))
  175. self.roidbs, self.cname2cid = records, cname2cid
  176. def pascalvoc_label(with_background=True):
  177. labels_map = {
  178. 'aeroplane': 1,
  179. 'bicycle': 2,
  180. 'bird': 3,
  181. 'boat': 4,
  182. 'bottle': 5,
  183. 'bus': 6,
  184. 'car': 7,
  185. 'cat': 8,
  186. 'chair': 9,
  187. 'cow': 10,
  188. 'diningtable': 11,
  189. 'dog': 12,
  190. 'horse': 13,
  191. 'motorbike': 14,
  192. 'person': 15,
  193. 'pottedplant': 16,
  194. 'sheep': 17,
  195. 'sofa': 18,
  196. 'train': 19,
  197. 'tvmonitor': 20
  198. }
  199. if not with_background:
  200. labels_map = {k: v - 1 for k, v in labels_map.items()}
  201. return labels_map