123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import numpy as np
- from ppdet.core.workspace import register, serializable
- from .dataset import DetDataset
- from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class COCODataSet(DetDataset):
    """
    Load dataset with COCO format.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): coco annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        load_crowd (bool): whether to load crowded ground-truth.
            False as default
        allow_empty (bool): whether to load empty entry. False as default
        empty_ratio (float): the ratio of empty records to all loaded
            records. If empty_ratio is outside [0., 1.), empty records are
            not subsampled and all of them are used. 1. as default
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.):
        super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
                                          data_fields, sample_num)
        # Switched to True by parse_dataset when the annotation file has no
        # 'annotations' section, so only image information is loaded.
        self.load_image_only = False
        # Legacy switch for the 'semantic' field (see TODO in parse_dataset).
        self.load_semantic = False
        self.load_crowd = load_crowd
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        """Randomly subsample empty records according to ``empty_ratio``.

        Args:
            records (list): records of images without ground truth.
            num (int): number of non-empty records.

        Returns:
            list: the selected empty records, or ``records`` unchanged when
            ``empty_ratio`` is outside [0., 1.).
        """
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        # Pick k so that k / (k + num) == empty_ratio, capped by the number
        # of empty records that are actually available.
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    @staticmethod
    def _is_empty_polygon(segm):
        """Return True if a polygon-style segmentation holds no coordinates.

        Matches ``np.array(segm).size == 0`` for regular inputs, but never
        builds an array from ``segm``: doing so raises ValueError on recent
        NumPy when the polygons of one object have different lengths.
        """
        if isinstance(segm, np.ndarray):
            return segm.size == 0
        if not isinstance(segm, (list, tuple)):
            # e.g. an RLE dict: np.array(dict) has size 1, i.e. not empty
            return False
        return all(np.size(part) == 0 for part in segm)

    def _load_valid_instances(self, coco, img_id):
        """Collect the usable annotations of one image.

        Args:
            coco: the loaded pycocotools ``COCO`` object.
            img_id (int): id of the image whose annotations are read.

        Returns:
            tuple: ``(bboxes, is_rbox_anno)``. ``bboxes`` is the list of
            annotation dicts that passed validation, each given a
            ``'clean_bbox'`` entry ``[x1, y1, x2, y2]`` (rounded to 3
            decimals) and, for 5-value boxes, a ``'clean_rbox'`` entry
            ``[xc, yc, w, h, angle]``. ``is_rbox_anno`` is True when the
            annotations use the 5-value rotated-box format.
        """
        ins_anno_ids = coco.getAnnIds(
            imgIds=[img_id], iscrowd=None if self.load_crowd else False)
        instances = coco.loadAnns(ins_anno_ids)

        bboxes = []
        is_rbox_anno = False
        eps = 1e-5
        for inst in instances:
            # check gt bbox
            if inst.get('ignore', False):
                continue
            if 'bbox' not in inst:
                continue
            if not any(np.array(inst['bbox'])):
                # all-zero box carries no information
                continue

            # NOTE: the format flag follows the last instance reaching this
            # point; mixing 4- and 5-value boxes in one image is unsupported
            # (it raises KeyError on 'clean_rbox' in _build_gt_record).
            is_rbox_anno = len(inst['bbox']) == 5
            if is_rbox_anno:
                xc, yc, box_w, box_h, angle = inst['bbox']
                x1 = xc - box_w / 2.0
                y1 = yc - box_h / 2.0
            else:
                x1, y1, box_w, box_h = inst['bbox']
            x2 = x1 + box_w
            y2 = y1 + box_h

            if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                inst['clean_bbox'] = [
                    round(float(x), 3) for x in [x1, y1, x2, y2]
                ]
                if is_rbox_anno:
                    inst['clean_rbox'] = [xc, yc, box_w, box_h, angle]
                bboxes.append(inst)
            else:
                logger.warning(
                    'Found an invalid bbox in annotations: im_id: {}, '
                    'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                        img_id, float(inst['area']), x1, y1, x2, y2))
        return bboxes, is_rbox_anno

    def _build_gt_record(self, bboxes, is_rbox_anno):
        """Stack validated annotations into ground-truth arrays.

        Boxes whose polygon segmentation is empty are dropped (unless
        ``allow_empty``) from every output at once, so all arrays and
        ``gt_poly`` stay aligned row for row.

        Args:
            bboxes (list): annotation dicts from ``_load_valid_instances``.
            is_rbox_anno (bool): whether to also build ``gt_rbox``.

        Returns:
            dict or None: the gt fields ``is_crowd`` (int32, Nx1),
            ``gt_class`` (int32, Nx1), ``gt_bbox`` (float32, Nx4), optional
            ``gt_rbox`` (float32, Nx5) and ``gt_poly`` (list of length N);
            None if the image should be skipped because segmentation was
            present but no polygon survived and ``allow_empty`` is False.
        """
        gt_class, gt_bbox, gt_rbox, is_crowd, gt_poly = [], [], [], [], []
        has_segmentation = False
        for box in bboxes:
            clsid = self.catid2clsid[box['category_id']]
            crowd = box['iscrowd']
            poly = None
            # check RLE format
            if 'segmentation' in box and crowd == 1:
                # crowd regions are RLE masks; store a dummy polygon instead
                poly = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
            elif 'segmentation' in box and box['segmentation']:
                has_segmentation = True
                if (self._is_empty_polygon(box['segmentation']) and
                        not self.allow_empty):
                    # drop this box from every gt field, not just some
                    continue
                poly = box['segmentation']
            gt_class.append([clsid])
            gt_bbox.append(box['clean_bbox'])
            if is_rbox_anno:
                # xc, yc, w, h, theta
                gt_rbox.append(box['clean_rbox'])
            is_crowd.append([crowd])
            gt_poly.append(poly)

        if has_segmentation and not any(gt_poly) and not self.allow_empty:
            return None

        # reshape keeps the (0, k) shapes when no box is left
        gt_rec = {
            'is_crowd': np.array(is_crowd, dtype=np.int32).reshape(-1, 1),
            'gt_class': np.array(gt_class, dtype=np.int32).reshape(-1, 1),
            'gt_bbox': np.array(gt_bbox, dtype=np.float32).reshape(-1, 4),
        }
        if is_rbox_anno:
            gt_rec['gt_rbox'] = np.array(
                gt_rbox, dtype=np.float32).reshape(-1, 5)
        gt_rec['gt_poly'] = gt_poly
        return gt_rec

    def parse_dataset(self):
        """Parse the COCO annotation file into ``self.roidbs``.

        Builds ``self.catid2clsid`` (COCO category id -> contiguous class
        index following ``coco.getCatIds()`` order) and ``self.cname2cid``
        (category name -> class index), then creates one record dict per
        usable image. Images whose file does not exist or whose size is
        negative are skipped. Images without any valid box are skipped
        unless ``allow_empty`` is set; in that case they are collected
        separately, subsampled by ``_sample_empty`` and appended at the end.
        At most ``sample_num`` images are kept when ``sample_num > 0``.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        records = []
        empty_records = []
        ct = 0

        self.catid2clsid = {catid: i for i, catid in enumerate(cat_ids)}
        self.cname2cid = {
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in self.catid2clsid.items()
        }

        if 'annotations' not in coco.dataset:
            self.load_image_only = True
            logger.warning('Annotation file: {} does not contains ground truth '
                           'and load image information only.'.format(anno_path))

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            is_empty = False
            if not os.path.exists(im_path):
                logger.warning('Illegal image file: {}, and it will be '
                               'ignored'.format(im_path))
                continue

            if im_w < 0 or im_h < 0:
                logger.warning('Illegal width: {} or height: {} in annotation, '
                               'and im_id: {} will be ignored'.format(
                                   im_w, im_h, img_id))
                continue

            coco_rec = {
                'im_file': im_path,
                'im_id': np.array([img_id]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                bboxes, is_rbox_anno = self._load_valid_instances(coco, img_id)
                if not bboxes:
                    if not self.allow_empty:
                        continue
                    is_empty = True

                gt_rec = self._build_gt_record(bboxes, is_rbox_anno)
                if gt_rec is None:
                    continue

                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

                # TODO: remove load_semantic
                if self.load_semantic and 'semantic' in self.data_fields:
                    # splitext handles extensions of any length
                    # (the old [:-3] slice assumed 3 characters)
                    seg_path = os.path.join(
                        self.dataset_dir, 'stuffthingmaps', 'train2017',
                        os.path.splitext(im_fname)[0] + '.png')
                    coco_rec.update({'semantic': seg_path})

            logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                im_path, img_id, im_h, im_w))
            if is_empty:
                empty_records.append(coco_rec)
            else:
                records.append(coco_rec)
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert ct > 0, 'not found any coco record in %s' % (anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs = records
|