# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import functools
import logging
import os
import traceback

import numpy as np

from ppdet.core.workspace import register, serializable

from .parallel_map import ParallelMap
from .transform.batch_operators import Gt2YoloTarget
# Public API of this module.
__all__ = ['Reader', 'create_reader']

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Compose(object):
    """Chain a list of operators into a single callable.

    Each operator is invoked as ``op(data, ctx)`` and its return value is
    fed to the next operator.  ``ctx`` is the shared context dict given at
    construction time (an empty dict when none was given).  Any exception
    raised by an operator is logged with its traceback and re-raised.
    """

    def __init__(self, transforms, ctx=None):
        self.transforms = transforms
        self.ctx = ctx

    def __call__(self, data):
        shared_ctx = self.ctx if self.ctx else {}
        for op in self.transforms:
            try:
                data = op(data, shared_ctx)
            except Exception as e:
                # surface which operator failed before propagating
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".format(
                        op, e, str(traceback.format_exc())))
                raise e
        return data
- def _calc_img_weights(roidbs):
- """ calculate the probabilities of each sample
- """
- imgs_cls = []
- num_per_cls = {}
- img_weights = []
- for i, roidb in enumerate(roidbs):
- img_cls = set([k for cls in roidbs[i]['gt_class'] for k in cls])
- imgs_cls.append(img_cls)
- for c in img_cls:
- if c not in num_per_cls:
- num_per_cls[c] = 1
- else:
- num_per_cls[c] += 1
- for i in range(len(roidbs)):
- weights = 0
- for c in imgs_cls[i]:
- weights += 1 / num_per_cls[c]
- img_weights.append(weights)
- # probabilities sum to 1
- img_weights = img_weights / np.sum(img_weights)
- return img_weights
- def _has_empty(item):
- def empty(x):
- if isinstance(x, np.ndarray) and x.size == 0:
- return True
- elif isinstance(x, collections.Sequence) and len(x) == 0:
- return True
- else:
- return False
- if isinstance(item, collections.Sequence) and len(item) == 0:
- return True
- if item is None:
- return True
- if empty(item):
- return True
- return False
- def _segm(samples):
- assert 'gt_poly' in samples
- segms = samples['gt_poly']
- if 'is_crowd' in samples:
- is_crowd = samples['is_crowd']
- if len(segms) != 0:
- assert len(segms) == is_crowd.shape[0]
- gt_masks = []
- valid = True
- for i in range(len(segms)):
- segm = segms[i]
- gt_segm = []
- if 'is_crowd' in samples and is_crowd[i]:
- gt_segm.append([[0, 0]])
- else:
- for poly in segm:
- if len(poly) == 0:
- valid = False
- break
- gt_segm.append(np.array(poly).reshape(-1, 2))
- if (not valid) or len(gt_segm) == 0:
- break
- gt_masks.append(gt_segm)
- return gt_masks
def batch_arrange(batch_samples, fields):
    """Re-order every sample dict of a batch into a tuple following `fields`.

    Special fields: 'gt_mask' is computed via _segm, 'im_shape' becomes a
    float32 (h, w, 1) array, 'im_size' an int32 (h, w) array, and
    'is_difficult' is read from the sample's 'difficult' key.  Every other
    field is looked up verbatim in the sample.

    Returns:
        list of tuple: one tuple per input sample, field order preserved.
    """

    def im_shape(samples, dim=3):
        # hard code
        assert 'h' in samples
        assert 'w' in samples
        if dim == 3:  # RCNN, ..
            return np.array((samples['h'], samples['w'], 1), dtype=np.float32)
        else:  # YOLOv3, ..
            return np.array((samples['h'], samples['w']), dtype=np.int32)

    arranged = []
    for samples in batch_samples:
        ins = []
        for field in fields:
            if field == 'gt_mask':
                ins.append(_segm(samples))
            elif field == 'im_shape':
                ins.append(im_shape(samples))
            elif field == 'im_size':
                ins.append(im_shape(samples, 2))
            else:
                key = 'difficult' if field == 'is_difficult' else field
                assert key in samples, '{} not in samples'.format(key)
                ins.append(samples[key])
        arranged.append(tuple(ins))
    return arranged
@register
@serializable
class Reader(object):
    """
    Args:
        dataset (DataSet): DataSet object
        sample_transforms (list of BaseOperator): a list of sample transforms
            operators.
        batch_transforms (list of BaseOperator): a list of batch transforms
            operators.
        batch_size (int): batch size.
        shuffle (bool): whether shuffle dataset or not. Default False.
        drop_last (bool): whether drop last batch or not. Default False.
        drop_empty (bool): whether drop sample when it's gt is empty or not.
            Default True.
        mixup_epoch (int): mixup epoc number. Default is -1, meaning
            not use mixup.
        cutmix_epoch (int): cutmix epoc number. Default is -1, meaning
            not use cutmix.
        class_aware_sampling (bool): whether use class-aware sampling or not.
            Default False.
        worker_num (int): number of working threads/processes.
            Default -1, meaning not use multi-threads/multi-processes.
        use_process (bool): whether use multi-processes or not.
            It only works when worker_num > 1. Default False.
        bufsize (int): buffer size for multi-threads/multi-processes,
            please note, one instance in buffer is one batch data.
        memsize (str): size of shared memory used in result queue when
            use_process is true. Default 3G.
        inputs_def (dict): network input definition use to get input fields,
            which is used to determine the order of returned data.
        devices_num (int): number of devices.
        num_trainers (int): number of trainers. Default 1.
    """

    def __init__(self,
                 dataset=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 mixup_epoch=-1,
                 cutmix_epoch=-1,
                 class_aware_sampling=False,
                 worker_num=-1,
                 use_process=False,
                 use_fine_grained_loss=False,
                 num_classes=80,
                 bufsize=-1,
                 memsize='3G',
                 inputs_def=None,
                 devices_num=1,
                 num_trainers=1):
        self._dataset = dataset
        self._roidbs = self._dataset.get_roidb()
        # field order requested by the network input definition; deep-copied
        # because it may be extended below (multi-scale fields)
        self._fields = copy.deepcopy(inputs_def[
            'fields']) if inputs_def else None

        # transform
        self._sample_transforms = Compose(sample_transforms,
                                          {'fields': self._fields})
        self._batch_transforms = None
        if use_fine_grained_loss:
            # propagate num_classes into YOLO target builders
            for bt in batch_transforms:
                if isinstance(bt, Gt2YoloTarget):
                    bt.num_classes = num_classes
        elif batch_transforms:
            # fine-grained loss disabled: Gt2YoloTarget outputs are unused,
            # so strip those operators from the batch transforms
            batch_transforms = [
                bt for bt in batch_transforms
                if not isinstance(bt, Gt2YoloTarget)
            ]

        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms,
                                             {'fields': self._fields})

        # data
        if inputs_def and inputs_def.get('multi_scale', False):
            from ppdet.modeling.architectures.input_helper import multiscale_def
            im_shape = inputs_def[
                'image_shape'] if 'image_shape' in inputs_def else [
                    3, None, None
                ]
            _, ms_fields = multiscale_def(im_shape, inputs_def['num_scales'],
                                          inputs_def['use_flip'])
            # extra per-scale fields appended to the arranged output
            self._fields += ms_fields
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._drop_last = drop_last
        self._drop_empty = drop_empty

        # sampling
        # mixup/cutmix epochs are divided among trainers
        self._mixup_epoch = mixup_epoch // num_trainers
        self._cutmix_epoch = cutmix_epoch // num_trainers
        self._class_aware_sampling = class_aware_sampling

        # whether next()/_load_batch() should eagerly read image bytes
        self._load_img = False
        self._sample_num = len(self._roidbs)

        if self._class_aware_sampling:
            # per-image sampling probabilities, favoring rare classes
            self.img_weights = _calc_img_weights(self._roidbs)
        self._indexes = None

        # iteration state: -1 means reset() has not run yet
        self._pos = -1
        self._epoch = -1
        self._curr_iter = 0

        # multi-process
        self._worker_num = worker_num
        self._parallel = None
        if self._worker_num > -1:
            # workers apply sample/batch transforms in parallel
            task = functools.partial(self.worker, self._drop_empty)
            bufsize = devices_num * 2 if bufsize == -1 else bufsize
            self._parallel = ParallelMap(self, task, worker_num, bufsize,
                                         use_process, memsize)

    def __call__(self):
        # return the parallel wrapper when workers are enabled, else self
        if self._worker_num > -1:
            return self._parallel
        else:
            return self

    def __iter__(self):
        return self

    def reset(self):
        """implementation of Dataset.reset
        """
        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self.indexes = [i for i in range(self.size())]
        if self._class_aware_sampling:
            # resample (with replacement) according to class-aware weights
            self.indexes = np.random.choice(
                self._sample_num,
                self._sample_num,
                replace=True,
                p=self.img_weights)

        if self._shuffle:
            # seed with epoch + trainer id so each trainer shuffles
            # differently but reproducibly per epoch
            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
            np.random.seed(self._epoch + trainer_id)
            np.random.shuffle(self.indexes)

        # mixup/cutmix need at least two samples to blend
        if self._mixup_epoch > 0 and len(self.indexes) < 2:
            logger.debug("Disable mixup for dataset samples "
                         "less than 2 samples")
            self._mixup_epoch = -1
        if self._cutmix_epoch > 0 and len(self.indexes) < 2:
            logger.info("Disable cutmix for dataset samples "
                        "less than 2 samples")
            self._cutmix_epoch = -1

        self._pos = 0

    def __next__(self):
        # Python 3 iterator protocol delegates to next()
        return self.next()

    def next(self):
        """Return the next (possibly transformed) batch, or raise
        StopIteration when the epoch is drained."""
        if self._epoch < 0:
            self.reset()
        if self.drained():
            raise StopIteration
        batch = self._load_batch()
        self._curr_iter += 1
        if self._drop_last and len(batch) < self._batch_size:
            raise StopIteration
        if self._worker_num > -1:
            # raw batch: transforms are applied by the parallel workers
            return batch
        else:
            return self.worker(self._drop_empty, batch)

    def _load_batch(self):
        """Collect up to batch_size raw samples, skipping empty-gt samples
        and attaching mixup/cutmix partners when those epochs are active."""
        batch = []
        bs = 0
        while bs != self._batch_size:
            if self._pos >= self.size():
                break
            pos = self.indexes[self._pos]
            # deep copy so transforms cannot mutate the cached roidb
            sample = copy.deepcopy(self._roidbs[pos])
            sample["curr_iter"] = self._curr_iter
            self._pos += 1

            if self._drop_empty and self._fields and 'gt_bbox' in sample:
                if _has_empty(sample['gt_bbox']):
                    # intentionally silent: drop samples with empty gt_bbox
                    continue
            has_mask = 'gt_mask' in self._fields or 'gt_segm' in self._fields
            if self._drop_empty and self._fields and has_mask:
                if _has_empty(_segm(sample)):
                    # intentionally silent: drop samples with empty gt masks
                    continue

            if self._load_img:
                sample['image'] = self._load_image(sample['im_file'])

            if self._epoch < self._mixup_epoch:
                # pick a random partner (offset by _pos so it differs from
                # the current sample) and attach it for the mixup operator
                num = len(self.indexes)
                mix_idx = np.random.randint(1, num)
                mix_idx = self.indexes[(mix_idx + self._pos - 1) % num]
                sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx])
                sample['mixup']["curr_iter"] = self._curr_iter
                if self._load_img:
                    sample['mixup']['image'] = self._load_image(sample['mixup'][
                        'im_file'])
            if self._epoch < self._cutmix_epoch:
                # attach a random partner for the cutmix operator
                num = len(self.indexes)
                mix_idx = np.random.randint(1, num)
                sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx])
                sample['cutmix']["curr_iter"] = self._curr_iter
                if self._load_img:
                    sample['cutmix']['image'] = self._load_image(sample[
                        'cutmix']['im_file'])

            batch.append(sample)
            bs += 1
        return batch

    def worker(self, drop_empty=True, batch_samples=None):
        """
        sample transform and batch transform.
        """
        batch = []
        for sample in batch_samples:
            sample = self._sample_transforms(sample)
            if drop_empty and 'gt_bbox' in sample:
                if _has_empty(sample['gt_bbox']):
                    # intentionally silent: drop samples whose gt_bbox became
                    # empty after the sample transforms
                    continue
            batch.append(sample)
        if len(batch) > 0 and self._batch_transforms:
            batch = self._batch_transforms(batch)
        if len(batch) > 0 and self._fields:
            # arrange dict samples into tuples in network-input order
            batch = batch_arrange(batch, self._fields)
        return batch

    def _load_image(self, filename):
        # return raw image bytes; decoding happens in the transforms
        with open(filename, 'rb') as f:
            return f.read()

    def size(self):
        """ implementation of Dataset.size
        """
        return self._sample_num

    def drained(self):
        """ implementation of Dataset.drained
        """
        assert self._epoch >= 0, 'The first epoch has not begin!'
        return self._pos >= self.size()

    def stop(self):
        # shut down parallel workers, if any
        if self._parallel:
            self._parallel.stop()
def create_reader(cfg,
                  max_iter=0,
                  global_cfg=None,
                  devices_num=1,
                  num_trainers=1):
    """
    Return iterable data reader.

    Args:
        max_iter (int): number of iterations; 0 or less means one full
            epoch with no restart.
    """
    if not isinstance(cfg, dict):
        raise TypeError("The config should be a dict when creating reader.")

    # synchornize use_fine_grained_loss/num_classes from global_cfg to reader cfg
    if global_cfg:
        cfg['use_fine_grained_loss'] = getattr(global_cfg,
                                               'use_fine_grained_loss', False)
        cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80)
    cfg['devices_num'] = devices_num
    cfg['num_trainers'] = num_trainers
    reader = Reader(**cfg)()

    def _reader():
        produced = 0
        while True:
            for batch in reader:
                if len(batch) > 0:
                    # empty batches are skipped, not counted
                    yield batch
                    produced += 1
                if max_iter > 0 and produced == max_iter:
                    return
            # epoch exhausted: restart only when iterating by max_iter
            reader.reset()
            if max_iter <= 0:
                return

    return _reader