coco.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. from ppdet.core.workspace import register, serializable
  17. from .dataset import DetDataset
  18. from ppdet.utils.logger import setup_logger
  19. logger = setup_logger(__name__)
  20. @register
  21. @serializable
  22. class COCODataSet(DetDataset):
  23. """
  24. Load dataset with COCO format.
  25. Args:
  26. dataset_dir (str): root directory for dataset.
  27. image_dir (str): directory for images.
  28. anno_path (str): coco annotation file path.
  29. data_fields (list): key name of data dictionary, at least have 'image'.
  30. sample_num (int): number of samples to load, -1 means all.
  31. load_crowd (bool): whether to load crowded ground-truth.
  32. False as default
  33. allow_empty (bool): whether to load empty entry. False as default
  34. empty_ratio (float): the ratio of empty record number to total
  35. record's, if empty_ratio is out of [0. ,1.), do not sample the
  36. records and use all the empty entries. 1. as default
  37. """
  38. def __init__(self,
  39. dataset_dir=None,
  40. image_dir=None,
  41. anno_path=None,
  42. data_fields=['image'],
  43. sample_num=-1,
  44. load_crowd=False,
  45. allow_empty=False,
  46. empty_ratio=1.):
  47. super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
  48. data_fields, sample_num)
  49. self.load_image_only = False
  50. self.load_semantic = False
  51. self.load_crowd = load_crowd
  52. self.allow_empty = allow_empty
  53. self.empty_ratio = empty_ratio
  54. def _sample_empty(self, records, num):
  55. # if empty_ratio is out of [0. ,1.), do not sample the records
  56. if self.empty_ratio < 0. or self.empty_ratio >= 1.:
  57. return records
  58. import random
  59. sample_num = min(
  60. int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
  61. records = random.sample(records, sample_num)
  62. return records
  63. def parse_dataset(self):
  64. anno_path = os.path.join(self.dataset_dir, self.anno_path)
  65. image_dir = os.path.join(self.dataset_dir, self.image_dir)
  66. assert anno_path.endswith('.json'), \
  67. 'invalid coco annotation file: ' + anno_path
  68. from pycocotools.coco import COCO
  69. coco = COCO(anno_path)
  70. img_ids = coco.getImgIds()
  71. img_ids.sort()
  72. cat_ids = coco.getCatIds()
  73. records = []
  74. empty_records = []
  75. ct = 0
  76. self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  77. self.cname2cid = dict({
  78. coco.loadCats(catid)[0]['name']: clsid
  79. for catid, clsid in self.catid2clsid.items()
  80. })
  81. if 'annotations' not in coco.dataset:
  82. self.load_image_only = True
  83. logger.warning('Annotation file: {} does not contains ground truth '
  84. 'and load image information only.'.format(anno_path))
  85. for img_id in img_ids:
  86. img_anno = coco.loadImgs([img_id])[0]
  87. im_fname = img_anno['file_name']
  88. im_w = float(img_anno['width'])
  89. im_h = float(img_anno['height'])
  90. im_path = os.path.join(image_dir,
  91. im_fname) if image_dir else im_fname
  92. is_empty = False
  93. if not os.path.exists(im_path):
  94. logger.warning('Illegal image file: {}, and it will be '
  95. 'ignored'.format(im_path))
  96. continue
  97. if im_w < 0 or im_h < 0:
  98. logger.warning('Illegal width: {} or height: {} in annotation, '
  99. 'and im_id: {} will be ignored'.format(
  100. im_w, im_h, img_id))
  101. continue
  102. coco_rec = {
  103. 'im_file': im_path,
  104. 'im_id': np.array([img_id]),
  105. 'h': im_h,
  106. 'w': im_w,
  107. } if 'image' in self.data_fields else {}
  108. if not self.load_image_only:
  109. ins_anno_ids = coco.getAnnIds(
  110. imgIds=[img_id], iscrowd=None if self.load_crowd else False)
  111. instances = coco.loadAnns(ins_anno_ids)
  112. bboxes = []
  113. is_rbox_anno = False
  114. for inst in instances:
  115. # check gt bbox
  116. if inst.get('ignore', False):
  117. continue
  118. if 'bbox' not in inst.keys():
  119. continue
  120. else:
  121. if not any(np.array(inst['bbox'])):
  122. continue
  123. # read rbox anno or not
  124. is_rbox_anno = True if len(inst['bbox']) == 5 else False
  125. if is_rbox_anno:
  126. xc, yc, box_w, box_h, angle = inst['bbox']
  127. x1 = xc - box_w / 2.0
  128. y1 = yc - box_h / 2.0
  129. x2 = x1 + box_w
  130. y2 = y1 + box_h
  131. else:
  132. x1, y1, box_w, box_h = inst['bbox']
  133. x2 = x1 + box_w
  134. y2 = y1 + box_h
  135. eps = 1e-5
  136. if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
  137. inst['clean_bbox'] = [
  138. round(float(x), 3) for x in [x1, y1, x2, y2]
  139. ]
  140. if is_rbox_anno:
  141. inst['clean_rbox'] = [xc, yc, box_w, box_h, angle]
  142. bboxes.append(inst)
  143. else:
  144. logger.warning(
  145. 'Found an invalid bbox in annotations: im_id: {}, '
  146. 'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
  147. img_id, float(inst['area']), x1, y1, x2, y2))
  148. num_bbox = len(bboxes)
  149. if num_bbox <= 0 and not self.allow_empty:
  150. continue
  151. elif num_bbox <= 0:
  152. is_empty = True
  153. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  154. if is_rbox_anno:
  155. gt_rbox = np.zeros((num_bbox, 5), dtype=np.float32)
  156. gt_theta = np.zeros((num_bbox, 1), dtype=np.int32)
  157. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  158. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  159. gt_poly = [None] * num_bbox
  160. has_segmentation = False
  161. for i, box in enumerate(bboxes):
  162. catid = box['category_id']
  163. gt_class[i][0] = self.catid2clsid[catid]
  164. gt_bbox[i, :] = box['clean_bbox']
  165. # xc, yc, w, h, theta
  166. if is_rbox_anno:
  167. gt_rbox[i, :] = box['clean_rbox']
  168. is_crowd[i][0] = box['iscrowd']
  169. # check RLE format
  170. if 'segmentation' in box and box['iscrowd'] == 1:
  171. gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
  172. elif 'segmentation' in box and box['segmentation']:
  173. if not np.array(box['segmentation']
  174. ).size > 0 and not self.allow_empty:
  175. bboxes.pop(i)
  176. gt_poly.pop(i)
  177. np.delete(is_crowd, i)
  178. np.delete(gt_class, i)
  179. np.delete(gt_bbox, i)
  180. else:
  181. gt_poly[i] = box['segmentation']
  182. has_segmentation = True
  183. if has_segmentation and not any(
  184. gt_poly) and not self.allow_empty:
  185. continue
  186. if is_rbox_anno:
  187. gt_rec = {
  188. 'is_crowd': is_crowd,
  189. 'gt_class': gt_class,
  190. 'gt_bbox': gt_bbox,
  191. 'gt_rbox': gt_rbox,
  192. 'gt_poly': gt_poly,
  193. }
  194. else:
  195. gt_rec = {
  196. 'is_crowd': is_crowd,
  197. 'gt_class': gt_class,
  198. 'gt_bbox': gt_bbox,
  199. 'gt_poly': gt_poly,
  200. }
  201. for k, v in gt_rec.items():
  202. if k in self.data_fields:
  203. coco_rec[k] = v
  204. # TODO: remove load_semantic
  205. if self.load_semantic and 'semantic' in self.data_fields:
  206. seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
  207. 'train2017', im_fname[:-3] + 'png')
  208. coco_rec.update({'semantic': seg_path})
  209. logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
  210. im_path, img_id, im_h, im_w))
  211. if is_empty:
  212. empty_records.append(coco_rec)
  213. else:
  214. records.append(coco_rec)
  215. ct += 1
  216. if self.sample_num > 0 and ct >= self.sample_num:
  217. break
  218. assert ct > 0, 'not found any coco record in %s' % (anno_path)
  219. logger.debug('{} samples in file {}'.format(ct, anno_path))
  220. if self.allow_empty and len(empty_records) > 0:
  221. empty_records = self._sample_empty(empty_records, len(records))
  222. records += empty_records
  223. self.roidbs = records