# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os

import cv2
import numpy as np

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

from ppdet.core.workspace import register, serializable
from ppdet.data.crop_utils.annotation_cropper import AnnoCropper
from ppdet.utils.logger import setup_logger

from .coco import COCODataSet
from .dataset import _make_dataset, _is_valid_file

logger = setup_logger('sniper_coco_dataset')
- @register
- @serializable
- class SniperCOCODataSet(COCODataSet):
- """SniperCOCODataSet"""
- def __init__(self,
- dataset_dir=None,
- image_dir=None,
- anno_path=None,
- proposals_file=None,
- data_fields=['image'],
- sample_num=-1,
- load_crowd=False,
- allow_empty=True,
- empty_ratio=1.,
- is_trainset=True,
- image_target_sizes=[2000, 1000],
- valid_box_ratio_ranges=[[-1, 0.1],[0.08, -1]],
- chip_target_size=500,
- chip_target_stride=200,
- use_neg_chip=False,
- max_neg_num_per_im=8,
- max_per_img=-1,
- nms_thresh=0.5):
- super(SniperCOCODataSet, self).__init__(
- dataset_dir=dataset_dir,
- image_dir=image_dir,
- anno_path=anno_path,
- data_fields=data_fields,
- sample_num=sample_num,
- load_crowd=load_crowd,
- allow_empty=allow_empty,
- empty_ratio=empty_ratio
- )
- self.proposals_file = proposals_file
- self.proposals = None
- self.anno_cropper = None
- self.is_trainset = is_trainset
- self.image_target_sizes = image_target_sizes
- self.valid_box_ratio_ranges = valid_box_ratio_ranges
- self.chip_target_size = chip_target_size
- self.chip_target_stride = chip_target_stride
- self.use_neg_chip = use_neg_chip
- self.max_neg_num_per_im = max_neg_num_per_im
- self.max_per_img = max_per_img
- self.nms_thresh = nms_thresh
- def parse_dataset(self):
- if not hasattr(self, "roidbs"):
- super(SniperCOCODataSet, self).parse_dataset()
- if self.is_trainset:
- self._parse_proposals()
- self._merge_anno_proposals()
- self.ori_roidbs = copy.deepcopy(self.roidbs)
- self.init_anno_cropper()
- self.roidbs = self.generate_chips_roidbs(self.roidbs, self.is_trainset)
- def set_proposals_file(self, file_path):
- self.proposals_file = file_path
- def init_anno_cropper(self):
- logger.info("Init AnnoCropper...")
- self.anno_cropper = AnnoCropper(
- image_target_sizes=self.image_target_sizes,
- valid_box_ratio_ranges=self.valid_box_ratio_ranges,
- chip_target_size=self.chip_target_size,
- chip_target_stride=self.chip_target_stride,
- use_neg_chip=self.use_neg_chip,
- max_neg_num_per_im=self.max_neg_num_per_im,
- max_per_img=self.max_per_img,
- nms_thresh=self.nms_thresh
- )
- def generate_chips_roidbs(self, roidbs, is_trainset):
- if is_trainset:
- roidbs = self.anno_cropper.crop_anno_records(roidbs)
- else:
- roidbs = self.anno_cropper.crop_infer_anno_records(roidbs)
- return roidbs
- def _parse_proposals(self):
- if self.proposals_file:
- self.proposals = {}
- logger.info("Parse proposals file:{}".format(self.proposals_file))
- with open(self.proposals_file, 'r') as f:
- proposals = json.load(f)
- for prop in proposals:
- image_id = prop["image_id"]
- if image_id not in self.proposals:
- self.proposals[image_id] = []
- x, y, w, h = prop["bbox"]
- self.proposals[image_id].append([x, y, x + w, y + h])
- def _merge_anno_proposals(self):
- assert self.roidbs
- if self.proposals and len(self.proposals.keys()) > 0:
- logger.info("merge proposals to annos")
- for id, record in enumerate(self.roidbs):
- image_id = int(record["im_id"])
- if image_id not in self.proposals.keys():
- logger.info("image id :{} no proposals".format(image_id))
- record["proposals"] = np.array(self.proposals.get(image_id, []), dtype=np.float32)
- self.roidbs[id] = record
- def get_ori_roidbs(self):
- if not hasattr(self, "ori_roidbs"):
- return None
- return self.ori_roidbs
- def get_roidbs(self):
- if not hasattr(self, "roidbs"):
- self.parse_dataset()
- return self.roidbs
- def set_roidbs(self, roidbs):
- self.roidbs = roidbs
- def check_or_download_dataset(self):
- return
- def _parse(self):
- image_dir = self.image_dir
- if not isinstance(image_dir, Sequence):
- image_dir = [image_dir]
- images = []
- for im_dir in image_dir:
- if os.path.isdir(im_dir):
- im_dir = os.path.join(self.dataset_dir, im_dir)
- images.extend(_make_dataset(im_dir))
- elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
- images.append(im_dir)
- return images
- def _load_images(self):
- images = self._parse()
- ct = 0
- records = []
- for image in images:
- assert image != '' and os.path.isfile(image), \
- "Image {} not found".format(image)
- if self.sample_num > 0 and ct >= self.sample_num:
- break
- im = cv2.imread(image)
- h, w, c = im.shape
- rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w}
- self._imid2path[ct] = image
- ct += 1
- records.append(rec)
- assert len(records) > 0, "No image file found"
- return records
- def get_imid2path(self):
- return self._imid2path
- def set_images(self, images):
- self._imid2path = {}
- self.image_dir = images
- self.roidbs = self._load_images()
|