# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import xml.etree.ElementTree as ET

from ppdet.core.workspace import register, serializable
from .dataset import DetDataset

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class VOCDataSet(DetDataset):
    """
    Load dataset with PascalVOC format.

    Notes:
        Each line of `anno_path` must contain an image file path and an
        xml annotation file path.

    Args:
        dataset_dir (str): root directory for the dataset.
        image_dir (str): directory for images.
        anno_path (str): voc annotation file path.
        data_fields (list): key names of the data dictionary, must at least
            contain 'image'.
        sample_num (int): number of samples to load, -1 means all.
        label_list (str): if use_default_label is False, load the mapping
            between category name and class index from this file.
        allow_empty (bool): whether to load empty entries. False as default.
        empty_ratio (float): the ratio of empty records to the total
            records; if empty_ratio is out of [0., 1.), do not sample the
            records and use all the empty entries. 1. as default.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 label_list=None,
                 allow_empty=False,
                 empty_ratio=1.):
        super(VOCDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num)
        self.label_list = label_list
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        # if empty_ratio is out of [0., 1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
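        # e.g. with empty_ratio=0.25 and num=300 non-empty records, at most
        # int(300 * 0.25 / 0.75) = 100 empty records are kept, so empty
        # records stay at roughly empty_ratio of the combined set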
        import random
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self):
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        # mapping category name to class id
        # first_class: 0, second_class: 1, ...
        records = []
        empty_records = []
        ct = 0
        cname2cid = {}
        if self.label_list:
            label_path = os.path.join(self.dataset_dir, self.label_list)
            if not os.path.exists(label_path):
                raise ValueError(
                    "label_list {} does not exist".format(label_path))
            with open(label_path, 'r') as fr:
                label_id = 0
                for line in fr.readlines():
                    cname2cid[line.strip()] = label_id
                    label_id += 1
        else:
            cname2cid = pascalvoc_label()

        with open(anno_path, 'r') as fr:
            while True:
                line = fr.readline()
                if not line:
                    break
                img_file, xml_file = [
                    os.path.join(image_dir, x)
                    for x in line.strip().split()[:2]
                ]
                if not os.path.exists(img_file):
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.
                        format(img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning(
                        'Illegal xml file: {}, and it will be ignored'.format(
                            xml_file))
                    continue
                tree = ET.parse(xml_file)
                if tree.find('id') is None:
                    im_id = np.array([ct])
                else:
                    im_id = np.array([int(tree.find('id').text)])

                objs = tree.findall('object')
                im_w = float(tree.find('size').find('width').text)
                im_h = float(tree.find('size').find('height').text)
                if im_w < 0 or im_h < 0:
                    logger.warning(
                        'Illegal width: {} or height: {} in annotation, '
                        'and {} will be ignored'.format(im_w, im_h, xml_file))
                    continue

                num_bbox, i = len(objs), 0
                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
                difficult = np.zeros((num_bbox, 1), dtype=np.int32)
                for obj in objs:
                    cname = obj.find('name').text

                    # user dataset may not contain difficult field
                    _difficult = obj.find('difficult')
                    _difficult = int(
                        _difficult.text) if _difficult is not None else 0

                    x1 = float(obj.find('bndbox').find('xmin').text)
                    y1 = float(obj.find('bndbox').find('ymin').text)
                    x2 = float(obj.find('bndbox').find('xmax').text)
                    y2 = float(obj.find('bndbox').find('ymax').text)
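                    # clip coordinates to the image; degenerate boxes
                    # (x2 <= x1 or y2 <= y1) are skipped with a warning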
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    x2 = min(im_w - 1, x2)
                    y2 = min(im_h - 1, y2)
                    if x2 > x1 and y2 > y1:
                        gt_bbox[i, :] = [x1, y1, x2, y2]
                        gt_class[i, 0] = cname2cid[cname]
                        gt_score[i, 0] = 1.
                        difficult[i, 0] = _difficult
                        i += 1
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: xml_file: {}'
                            ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                xml_file, x1, y1, x2, y2))

                gt_bbox = gt_bbox[:i, :]
                gt_class = gt_class[:i, :]
                gt_score = gt_score[:i, :]
                difficult = difficult[:i, :]
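
                # only the keys listed in data_fields end up in the record;
                # 'image' controls the im_file/im_id/h/w entries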
                voc_rec = {
                    'im_file': img_file,
                    'im_id': im_id,
                    'h': im_h,
                    'w': im_w
                } if 'image' in self.data_fields else {}

                gt_rec = {
                    'gt_class': gt_class,
                    'gt_score': gt_score,
                    'gt_bbox': gt_bbox,
                    'difficult': difficult
                }
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        voc_rec[k] = v

                if len(objs) == 0:
                    empty_records.append(voc_rec)
                else:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert ct > 0, 'no voc record found in %s' % self.anno_path
        logger.debug('{} samples in file {}'.format(ct, anno_path))

        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records

        self.roidbs, self.cname2cid = records, cname2cid

    def get_label_list(self):
        return os.path.join(self.dataset_dir, self.label_list)


def pascalvoc_label():
    labels_map = {
        'aeroplane': 0,
        'bicycle': 1,
        'bird': 2,
        'boat': 3,
        'bottle': 4,
        'bus': 5,
        'car': 6,
        'cat': 7,
        'chair': 8,
        'cow': 9,
        'diningtable': 10,
        'dog': 11,
        'horse': 12,
        'motorbike': 13,
        'person': 14,
        'pottedplant': 15,
        'sheep': 16,
        'sofa': 17,
        'train': 18,
        'tvmonitor': 19
    }
    return labels_map
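

# A minimal usage sketch, not part of the original file: the dataset layout
# and paths below are hypothetical placeholders. Note the relative import at
# the top means this module must be executed as part of the ppdet package
# (e.g. with `python -m`), not as a standalone script.
if __name__ == '__main__':
    dataset = VOCDataSet(
        dataset_dir='dataset/voc',          # hypothetical dataset root
        image_dir='VOCdevkit/VOC2007',      # hypothetical image directory
        anno_path='trainval.txt',           # lines of "image_path xml_path"
        label_list='label_list.txt',        # one category name per line
        data_fields=['image', 'gt_bbox', 'gt_class', 'difficult'],
        sample_num=10)
    dataset.parse_dataset()
    print('loaded {} records, {} classes'.format(
        len(dataset.roidbs), len(dataset.cname2cid)))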