coco.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. from .dataset import DataSet
  17. from ppdet.core.workspace import register, serializable
  18. import logging
  19. logger = logging.getLogger(__name__)
  20. @register
  21. @serializable
  22. class COCODataSet(DataSet):
  23. """
  24. Load COCO records with annotations in json file 'anno_path'
  25. Args:
  26. dataset_dir (str): root directory for dataset.
  27. image_dir (str): directory for images.
  28. anno_path (str): json file path.
  29. sample_num (int): number of samples to load, -1 means all.
  30. with_background (bool): whether load background as a class.
  31. if True, total class number will be 81. default True.
  32. """
  33. def __init__(self,
  34. image_dir=None,
  35. anno_path=None,
  36. dataset_dir=None,
  37. sample_num=-1,
  38. with_background=True,
  39. load_semantic=False):
  40. super(COCODataSet, self).__init__(
  41. image_dir=image_dir,
  42. anno_path=anno_path,
  43. dataset_dir=dataset_dir,
  44. sample_num=sample_num,
  45. with_background=with_background)
  46. self.anno_path = anno_path
  47. self.sample_num = sample_num
  48. self.with_background = with_background
  49. # `roidbs` is list of dict whose structure is:
  50. # {
  51. # 'im_file': im_fname, # image file name
  52. # 'im_id': img_id, # image id
  53. # 'h': im_h, # height of image
  54. # 'w': im_w, # width
  55. # 'is_crowd': is_crowd,
  56. # 'gt_score': gt_score,
  57. # 'gt_class': gt_class,
  58. # 'gt_bbox': gt_bbox,
  59. # 'gt_poly': gt_poly,
  60. # }
  61. self.roidbs = None
  62. # a dict used to map category name to class id
  63. self.cname2cid = None
  64. self.load_image_only = False
  65. self.load_semantic = load_semantic
  66. def load_roidb_and_cname2cid(self):
  67. anno_path = os.path.join(self.dataset_dir, self.anno_path)
  68. image_dir = os.path.join(self.dataset_dir, self.image_dir)
  69. assert anno_path.endswith('.json'), \
  70. 'invalid coco annotation file: ' + anno_path
  71. from pycocotools.coco import COCO
  72. coco = COCO(anno_path)
  73. img_ids = coco.getImgIds()
  74. cat_ids = coco.getCatIds()
  75. records = []
  76. ct = 0
  77. # when with_background = True, mapping category to classid, like:
  78. # background:0, first_class:1, second_class:2, ...
  79. catid2clsid = dict({
  80. catid: i + int(self.with_background)
  81. for i, catid in enumerate(cat_ids)
  82. })
  83. cname2cid = dict({
  84. coco.loadCats(catid)[0]['name']: clsid
  85. for catid, clsid in catid2clsid.items()
  86. })
  87. if 'annotations' not in coco.dataset:
  88. self.load_image_only = True
  89. logger.warning('Annotation file: {} does not contains ground truth '
  90. 'and load image information only.'.format(anno_path))
  91. for img_id in img_ids:
  92. img_anno = coco.loadImgs([img_id])[0]
  93. im_fname = img_anno['file_name']
  94. im_w = float(img_anno['width'])
  95. im_h = float(img_anno['height'])
  96. im_path = os.path.join(image_dir,
  97. im_fname) if image_dir else im_fname
  98. if not os.path.exists(im_path):
  99. logger.warning('Illegal image file: {}, and it will be '
  100. 'ignored'.format(im_path))
  101. continue
  102. if im_w < 0 or im_h < 0:
  103. logger.warning('Illegal width: {} or height: {} in annotation, '
  104. 'and im_id: {} will be ignored'.format(
  105. im_w, im_h, img_id))
  106. continue
  107. coco_rec = {
  108. 'im_file': im_path,
  109. 'im_id': np.array([img_id]),
  110. 'h': im_h,
  111. 'w': im_w,
  112. }
  113. if not self.load_image_only:
  114. ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
  115. instances = coco.loadAnns(ins_anno_ids)
  116. bboxes = []
  117. for inst in instances:
  118. x, y, box_w, box_h = inst['bbox']
  119. x1 = max(0, x)
  120. y1 = max(0, y)
  121. x2 = min(im_w - 1, x1 + max(0, box_w - 1))
  122. y2 = min(im_h - 1, y1 + max(0, box_h - 1))
  123. if x2 >= x1 and y2 >= y1:
  124. inst['clean_bbox'] = [x1, y1, x2, y2]
  125. bboxes.append(inst)
  126. else:
  127. logger.warning(
  128. 'Found an invalid bbox in annotations: im_id: {}, '
  129. 'x1: {}, y1: {}, x2: {}, y2: {}.'.format(
  130. img_id, x1, y1, x2, y2))
  131. num_bbox = len(bboxes)
  132. if num_bbox <= 0:
  133. continue
  134. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  135. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  136. gt_score = np.ones((num_bbox, 1), dtype=np.float32)
  137. is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
  138. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  139. gt_poly = [None] * num_bbox
  140. has_segmentation = False
  141. for i, box in enumerate(bboxes):
  142. catid = box['category_id']
  143. gt_class[i][0] = catid2clsid[catid]
  144. gt_bbox[i, :] = box['clean_bbox']
  145. is_crowd[i][0] = box['iscrowd']
  146. if 'segmentation' in box and box['segmentation']:
  147. gt_poly[i] = box['segmentation']
  148. has_segmentation = True
  149. if has_segmentation and not any(gt_poly):
  150. continue
  151. coco_rec.update({
  152. 'is_crowd': is_crowd,
  153. 'gt_class': gt_class,
  154. 'gt_bbox': gt_bbox,
  155. 'gt_score': gt_score,
  156. 'gt_poly': gt_poly,
  157. })
  158. if self.load_semantic:
  159. seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
  160. 'train2017', im_fname[:-3] + 'png')
  161. coco_rec.update({'semantic': seg_path})
  162. logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
  163. im_path, img_id, im_h, im_w))
  164. records.append(coco_rec)
  165. ct += 1
  166. if self.sample_num > 0 and ct >= self.sample_num:
  167. break
  168. assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
  169. logger.debug('{} samples in file {}'.format(ct, anno_path))
  170. self.roidbs, self.cname2cid = records, cname2cid