# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import functools
import collections
import traceback
import numpy as np
import logging

from ppdet.core.workspace import register, serializable

from .parallel_map import ParallelMap
from .transform.batch_operators import Gt2YoloTarget

__all__ = ['Reader', 'create_reader']

logger = logging.getLogger(__name__)


class Compose(object):
    def __init__(self, transforms, ctx=None):
        self.transforms = transforms
        self.ctx = ctx

    def __call__(self, data):
        ctx = self.ctx if self.ctx else {}
        for f in self.transforms:
            try:
                data = f(data, ctx)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".format(
                        f, e, str(stack_info)))
                raise e
        return data
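
# Illustrative sketch (not part of the original file): Compose expects every
# transform to be callable as op(sample, ctx) and to return the (possibly
# modified) sample; the operator names below are hypothetical placeholders.
#
#   pipeline = Compose([DecodeOp(), ResizeOp(target_size=608)],
#                      ctx={'fields': ['image', 'gt_bbox']})
#   sample = pipeline({'im_file': 'path/to/img.jpg'})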


def _calc_img_weights(roidbs):
    """ calculate the probabilities of each sample
    """
    imgs_cls = []
    num_per_cls = {}
    img_weights = []
    for i, roidb in enumerate(roidbs):
        img_cls = set([k for cls in roidbs[i]['gt_class'] for k in cls])
        imgs_cls.append(img_cls)
        for c in img_cls:
            if c not in num_per_cls:
                num_per_cls[c] = 1
            else:
                num_per_cls[c] += 1

    for i in range(len(roidbs)):
        weights = 0
        for c in imgs_cls[i]:
            weights += 1 / num_per_cls[c]
        img_weights.append(weights)
    # probabilities sum to 1
    img_weights = img_weights / np.sum(img_weights)
    return img_weights
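
# Worked example (illustrative, not part of the original file): for three
# images whose gt classes are {0}, {0, 1} and {1}, each class appears in two
# images, so every image/class pair contributes 1/2:
#   raw weights = [0.5, 1.0, 0.5]  ->  normalized = [0.25, 0.5, 0.25]
# Images containing rarer classes therefore get a higher sampling probability.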


def _has_empty(item):
    def empty(x):
        if isinstance(x, np.ndarray) and x.size == 0:
            return True
        elif isinstance(x, collections.Sequence) and len(x) == 0:
            return True
        else:
            return False

    if isinstance(item, collections.Sequence) and len(item) == 0:
        return True
    if item is None:
        return True
    if empty(item):
        return True
    return False
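
# Illustrative note (not part of the original file): _has_empty treats None,
# empty sequences and zero-size ndarrays as empty, e.g.
#   _has_empty(np.zeros((0, 4), dtype=np.float32))  ->  True
#   _has_empty(np.zeros((3, 4), dtype=np.float32))  ->  False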


def _segm(samples):
    assert 'gt_poly' in samples
    segms = samples['gt_poly']
    if 'is_crowd' in samples:
        is_crowd = samples['is_crowd']
        if len(segms) != 0:
            assert len(segms) == is_crowd.shape[0]

    gt_masks = []
    valid = True
    for i in range(len(segms)):
        segm = segms[i]
        gt_segm = []
        if 'is_crowd' in samples and is_crowd[i]:
            gt_segm.append([[0, 0]])
        else:
            for poly in segm:
                if len(poly) == 0:
                    valid = False
                    break
                gt_segm.append(np.array(poly).reshape(-1, 2))
        if (not valid) or len(gt_segm) == 0:
            break
        gt_masks.append(gt_segm)
    return gt_masks


def batch_arrange(batch_samples, fields):
    def im_shape(samples, dim=3):
        # hard code
        assert 'h' in samples
        assert 'w' in samples
        if dim == 3:  # RCNN, ..
            return np.array((samples['h'], samples['w'], 1), dtype=np.float32)
        else:  # YOLOv3, ..
            return np.array((samples['h'], samples['w']), dtype=np.int32)

    arrange_batch = []
    for samples in batch_samples:
        one_ins = ()
        for i, field in enumerate(fields):
            if field == 'gt_mask':
                one_ins += (_segm(samples), )
            elif field == 'im_shape':
                one_ins += (im_shape(samples), )
            elif field == 'im_size':
                one_ins += (im_shape(samples, 2), )
            else:
                if field == 'is_difficult':
                    field = 'difficult'
                assert field in samples, '{} not in samples'.format(field)
                one_ins += (samples[field], )
        arrange_batch.append(one_ins)
    return arrange_batch
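
# Illustrative sketch (not part of the original file): with
# fields = ['image', 'im_size', 'gt_bbox'], every entry of the returned batch
# is a tuple laid out in exactly that order:
#   (sample['image'], np.array([h, w], dtype=np.int32), sample['gt_bbox'])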


@register
@serializable
class Reader(object):
    """
    Args:
        dataset (DataSet): DataSet object
        sample_transforms (list of BaseOperator): a list of sample transform
            operators.
        batch_transforms (list of BaseOperator): a list of batch transform
            operators.
        batch_size (int): batch size.
        shuffle (bool): whether to shuffle the dataset. Default False.
        drop_last (bool): whether to drop the last incomplete batch.
            Default False.
        drop_empty (bool): whether to drop samples whose ground truth is empty.
            Default True.
        mixup_epoch (int): mixup epoch number. Default is -1, meaning
            mixup is not used.
        cutmix_epoch (int): cutmix epoch number. Default is -1, meaning
            cutmix is not used.
        class_aware_sampling (bool): whether to use class-aware sampling.
            Default False.
        worker_num (int): number of worker threads/processes.
            Default -1, meaning multi-threading/multi-processing is not used.
        use_process (bool): whether to use multi-processing.
            It only works when worker_num > 1. Default False.
        use_fine_grained_loss (bool): whether the fine-grained YOLOv3 loss is
            used; if so, Gt2YoloTarget batch transforms are kept and updated
            with num_classes. Default False.
        num_classes (int): number of classes passed to Gt2YoloTarget when
            use_fine_grained_loss is True. Default 80.
        bufsize (int): buffer size for multi-threading/multi-processing;
            note that one instance in the buffer is one batch of data.
        memsize (str): size of the shared memory used in the result queue when
            use_process is True. Default 3G.
        inputs_def (dict): network input definition used to get the input
            fields, which determine the order of the returned data.
        devices_num (int): number of devices. Default 1.
        num_trainers (int): number of trainers. Default 1.
    """

    def __init__(self,
                 dataset=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 mixup_epoch=-1,
                 cutmix_epoch=-1,
                 class_aware_sampling=False,
                 worker_num=-1,
                 use_process=False,
                 use_fine_grained_loss=False,
                 num_classes=80,
                 bufsize=-1,
                 memsize='3G',
                 inputs_def=None,
                 devices_num=1,
                 num_trainers=1):
        self._dataset = dataset
        self._roidbs = self._dataset.get_roidb()
        self._fields = copy.deepcopy(inputs_def[
            'fields']) if inputs_def else None

        # transform
        self._sample_transforms = Compose(sample_transforms,
                                          {'fields': self._fields})
        self._batch_transforms = None
        if use_fine_grained_loss:
            for bt in batch_transforms:
                if isinstance(bt, Gt2YoloTarget):
                    bt.num_classes = num_classes
        elif batch_transforms:
            batch_transforms = [
                bt for bt in batch_transforms
                if not isinstance(bt, Gt2YoloTarget)
            ]

        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms,
                                             {'fields': self._fields})

        # data
        if inputs_def and inputs_def.get('multi_scale', False):
            from ppdet.modeling.architectures.input_helper import multiscale_def
            im_shape = inputs_def[
                'image_shape'] if 'image_shape' in inputs_def else [
                    3, None, None
                ]
            _, ms_fields = multiscale_def(im_shape, inputs_def['num_scales'],
                                          inputs_def['use_flip'])
            self._fields += ms_fields
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._drop_last = drop_last
        self._drop_empty = drop_empty

        # sampling
        self._mixup_epoch = mixup_epoch // num_trainers
        self._cutmix_epoch = cutmix_epoch // num_trainers
        self._class_aware_sampling = class_aware_sampling

        self._load_img = False
        self._sample_num = len(self._roidbs)

        if self._class_aware_sampling:
            self.img_weights = _calc_img_weights(self._roidbs)
        self._indexes = None

        self._pos = -1
        self._epoch = -1
        self._curr_iter = 0

        # multi-process
        self._worker_num = worker_num
        self._parallel = None
        if self._worker_num > -1:
            task = functools.partial(self.worker, self._drop_empty)
            bufsize = devices_num * 2 if bufsize == -1 else bufsize
            self._parallel = ParallelMap(self, task, worker_num, bufsize,
                                         use_process, memsize)

    def __call__(self):
        if self._worker_num > -1:
            return self._parallel
        else:
            return self

    def __iter__(self):
        return self

    def reset(self):
        """implementation of Dataset.reset
        """
        if self._epoch < 0:
            self._epoch = 0
        else:
            self._epoch += 1

        self.indexes = [i for i in range(self.size())]
        if self._class_aware_sampling:
            self.indexes = np.random.choice(
                self._sample_num,
                self._sample_num,
                replace=True,
                p=self.img_weights)

        if self._shuffle:
            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
            np.random.seed(self._epoch + trainer_id)
            np.random.shuffle(self.indexes)

        if self._mixup_epoch > 0 and len(self.indexes) < 2:
            logger.debug("Disable mixup for datasets with "
                         "fewer than 2 samples")
            self._mixup_epoch = -1
        if self._cutmix_epoch > 0 and len(self.indexes) < 2:
            logger.info("Disable cutmix for datasets with "
                        "fewer than 2 samples")
            self._cutmix_epoch = -1

        self._pos = 0

    def __next__(self):
        return self.next()

    def next(self):
        if self._epoch < 0:
            self.reset()
        if self.drained():
            raise StopIteration
        batch = self._load_batch()
        self._curr_iter += 1
        if self._drop_last and len(batch) < self._batch_size:
            raise StopIteration
        if self._worker_num > -1:
            return batch
        else:
            return self.worker(self._drop_empty, batch)

    def _load_batch(self):
        batch = []
        bs = 0
        while bs != self._batch_size:
            if self._pos >= self.size():
                break
            pos = self.indexes[self._pos]
            sample = copy.deepcopy(self._roidbs[pos])
            sample["curr_iter"] = self._curr_iter
            self._pos += 1

            if self._drop_empty and self._fields and 'gt_bbox' in sample:
                if _has_empty(sample['gt_bbox']):
                    #logger.warn('gt_bbox {} is empty or not valid in {}, '
                    #            'drop this sample'.format(
                    #                sample['im_file'], sample['gt_bbox']))
                    continue
            # guard against self._fields being None before membership tests
            has_mask = self._fields and ('gt_mask' in self._fields or
                                         'gt_segm' in self._fields)
            if self._drop_empty and has_mask:
                if _has_empty(_segm(sample)):
                    #logger.warn('gt_mask is empty or not valid in {}'.format(
                    #    sample['im_file']))
                    continue

            if self._load_img:
                sample['image'] = self._load_image(sample['im_file'])

            if self._epoch < self._mixup_epoch:
                num = len(self.indexes)
                mix_idx = np.random.randint(1, num)
                mix_idx = self.indexes[(mix_idx + self._pos - 1) % num]
                sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx])
                sample['mixup']["curr_iter"] = self._curr_iter
                if self._load_img:
                    sample['mixup']['image'] = self._load_image(
                        sample['mixup']['im_file'])
            if self._epoch < self._cutmix_epoch:
                num = len(self.indexes)
                mix_idx = np.random.randint(1, num)
                sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx])
                sample['cutmix']["curr_iter"] = self._curr_iter
                if self._load_img:
                    sample['cutmix']['image'] = self._load_image(
                        sample['cutmix']['im_file'])

            batch.append(sample)
            bs += 1
        return batch

    def worker(self, drop_empty=True, batch_samples=None):
        """
        sample transform and batch transform.
        """
        batch = []
        for sample in batch_samples:
            sample = self._sample_transforms(sample)
            if drop_empty and 'gt_bbox' in sample:
                if _has_empty(sample['gt_bbox']):
                    #logger.warn('gt_bbox {} is empty or not valid in {}, '
                    #            'drop this sample'.format(
                    #                sample['im_file'], sample['gt_bbox']))
                    continue
            batch.append(sample)

        if len(batch) > 0 and self._batch_transforms:
            batch = self._batch_transforms(batch)

        if len(batch) > 0 and self._fields:
            batch = batch_arrange(batch, self._fields)
        return batch

    def _load_image(self, filename):
        with open(filename, 'rb') as f:
            return f.read()

    def size(self):
        """ implementation of Dataset.size
        """
        return self._sample_num

    def drained(self):
        """ implementation of Dataset.drained
        """
        assert self._epoch >= 0, 'The first epoch has not begun!'
        return self._pos >= self.size()

    def stop(self):
        if self._parallel:
            self._parallel.stop()
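
# Illustrative usage sketch (not part of the original file; `coco_dataset` and
# the transform lists are hypothetical placeholders built elsewhere from a
# reader configuration):
#
#   reader = Reader(
#       dataset=coco_dataset,
#       sample_transforms=sample_ops,
#       batch_transforms=batch_ops,
#       batch_size=8,
#       shuffle=True,
#       worker_num=4,
#       inputs_def={'fields': ['image', 'im_size', 'gt_bbox', 'gt_class']})
#   for batch in reader():
#       ...  # each batch is a list of field tuples built by batch_arrange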


def create_reader(cfg,
                  max_iter=0,
                  global_cfg=None,
                  devices_num=1,
                  num_trainers=1):
    """
    Return an iterable data reader.

    Args:
        max_iter (int): number of iterations.
    """
    if not isinstance(cfg, dict):
        raise TypeError("The config should be a dict when creating reader.")

    # synchronize use_fine_grained_loss/num_classes from global_cfg to reader cfg
    if global_cfg:
        cfg['use_fine_grained_loss'] = getattr(global_cfg,
                                               'use_fine_grained_loss', False)
        cfg['num_classes'] = getattr(global_cfg, 'num_classes', 80)
    cfg['devices_num'] = devices_num
    cfg['num_trainers'] = num_trainers
    reader = Reader(**cfg)()

    def _reader():
        n = 0
        while True:
            for _batch in reader:
                if len(_batch) > 0:
                    yield _batch
                    n += 1
                if max_iter > 0 and n == max_iter:
                    return
            reader.reset()
            if max_iter <= 0:
                return

    return _reader
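
# Illustrative usage sketch (not part of the original file; the cfg dict is a
# hypothetical stand-in for a TrainReader section loaded from a YAML config):
#
#   train_loader = create_reader(train_reader_cfg, max_iter=1000,
#                                global_cfg=cfg, devices_num=1)
#   for step, batch in enumerate(train_loader()):
#       ...  # feed `batch` to the training program / executor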