# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import traceback

import six
import numpy as np
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.fluid.dataloader.collate import default_collate_fn

from ppdet.core.workspace import register
from . import transform
from .shm_utils import _get_shared_memory_size_in_M
from ppdet.utils.logger import setup_logger

logger = setup_logger('reader')

MAIN_PID = os.getpid()


class Compose(object):
    """Compose sample-level transforms built from the reader config."""

    def __init__(self, transforms, num_classes=80):
        self.transforms = transforms
        self.transforms_cls = []
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(transform, k)
                f = op_cls(**v)
                if hasattr(f, 'num_classes'):
                    f.num_classes = num_classes
                self.transforms_cls.append(f)

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to apply sample transform [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e
        return data
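

# A minimal usage sketch (illustrative, not executed): each entry in the
# transforms list is a single-item dict mapping the name of an op registered
# in ppdet.data.transform to its keyword arguments. The op names and kwargs
# below are assumptions based on common PaddleDetection configs.
#
#   sample_transforms = [
#       {'Decode': {}},
#       {'Resize': {'target_size': [640, 640], 'keep_ratio': False}},
#   ]
#   compose = Compose(sample_transforms, num_classes=80)
#   sample = compose({'im_file': 'demo.jpg'})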


class BatchCompose(Compose):
    """Compose batch-level transforms and collate samples into a batch."""

    def __init__(self, transforms, num_classes=80, collate_batch=True):
        super(BatchCompose, self).__init__(transforms, num_classes)
        self.collate_batch = collate_batch

    def __call__(self, data):
        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("failed to apply batch transform [{}] "
                               "with error: {} and stack:\n{}".format(
                                   f, e, str(stack_info)))
                raise e

        # remove keys which are not needed by the model
        extra_key = ['h', 'w', 'flipped']
        for k in extra_key:
            for sample in data:
                if k in sample:
                    sample.pop(k)

        if self.collate_batch:
            # stack every field across samples into batch tensors
            batch_data = default_collate_fn(data)
        else:
            # keep ground-truth fields as lists, since their lengths can
            # differ between samples; stack everything else
            batch_data = {}
            for k in data[0].keys():
                tmp_data = [sample[k] for sample in data]
                if 'gt_' not in k and 'is_crowd' not in k and 'difficult' not in k:
                    tmp_data = np.stack(tmp_data, axis=0)
                batch_data[k] = tmp_data
        return batch_data
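

# A minimal sketch of the two collate modes (illustrative values, not part of
# the module): with collate_batch=True every field must already share one
# shape so default_collate_fn can stack it; with collate_batch=False, the
# ground-truth fields ('gt_*', 'is_crowd', 'difficult') are kept as Python
# lists because their per-sample lengths may differ.
#
#   samples = [
#       {'image': np.zeros((3, 32, 32)), 'gt_bbox': np.zeros((2, 4))},
#       {'image': np.zeros((3, 32, 32)), 'gt_bbox': np.zeros((5, 4))},
#   ]
#   batch = BatchCompose([], collate_batch=False)(samples)
#   # batch['image'].shape == (2, 3, 32, 32); batch['gt_bbox'] is a 2-item list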


class BaseDataLoader(object):
    """
    Base DataLoader implementation for detection models

    Args:
        sample_transforms (list): a list of transforms to perform
            on each sample
        batch_transforms (list): a list of transforms to perform
            on the batch
        batch_size (int): batch size for batch collating, default 1.
        shuffle (bool): whether to shuffle samples
        drop_last (bool): whether to drop the last incomplete batch,
            default False
        num_classes (int): class number of dataset, default 80
        collate_batch (bool): whether to collate the batch in the dataloader.
            If set to True, the samples will be collated into a batch
            according to the batch size. Otherwise, the ground truth will
            not be collated, which is used when the number of ground
            truths differs between samples.
        use_shared_memory (bool): whether to use shared memory to
            accelerate data loading, enable this only if you
            are sure that the shared memory size of your OS
            is larger than the memory cost of the model's input data.
            Note that shared memory will be automatically
            disabled if the shared memory of the OS is less than
            1G, which is not enough for detection models.
            Default False.
    """
    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=80,
                 collate_batch=True,
                 use_shared_memory=False,
                 **kwargs):
        # sample transform
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)
        # batch transform
        self._batch_transforms = BatchCompose(batch_transforms, num_classes,
                                              collate_batch)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.use_shared_memory = use_shared_memory
        self.kwargs = kwargs
    def __call__(self,
                 dataset,
                 worker_num,
                 batch_sampler=None,
                 return_list=False):
        self.dataset = dataset
        self.dataset.check_or_download_dataset()
        self.dataset.parse_dataset()
        # get data
        self.dataset.set_transform(self._sample_transforms)
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        # DataLoader does not start worker subprocesses on Windows and
        # macOS, so shared memory is not needed on those systems
        use_shared_memory = self.use_shared_memory and \
            sys.platform not in ['win32', 'darwin']
        # check whether the shared memory size is larger than 1G (1024M)
        if use_shared_memory:
            shm_size = _get_shared_memory_size_in_M()
            if shm_size is not None and shm_size < 1024.:
                logger.warning("Shared memory size is less than 1G, "
                               "disable shared_memory in DataLoader")
                use_shared_memory = False

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            return_list=return_list,
            use_shared_memory=use_shared_memory)
        self.loader = iter(self.dataloader)
        return self
    def __len__(self):
        return len(self._batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.loader)
        except StopIteration:
            # reset the underlying iterator for the next epoch, then
            # re-raise StopIteration so the caller sees the epoch end
            self.loader = iter(self.dataloader)
            six.reraise(*sys.exc_info())

    def next(self):
        # python2 compatibility
        return self.__next__()
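

# A minimal usage sketch (illustrative, not executed): a reader is configured
# first and then bound to a dataset by calling it; `my_dataset` stands in for
# any ppdet dataset object exposing check_or_download_dataset, parse_dataset,
# set_transform and set_kwargs.
#
#   loader = BaseDataLoader(
#       sample_transforms=[{'Decode': {}}], batch_size=2, shuffle=True)
#   loader(my_dataset, worker_num=2)   # builds the underlying paddle DataLoader
#   for batch in loader:               # one epoch; the iterator resets itself
#       images = batch['image']        # in __next__ after StopIteration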


@register
class TrainReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 num_classes=80,
                 collate_batch=True,
                 **kwargs):
        super(TrainReader, self).__init__(sample_transforms, batch_transforms,
                                          batch_size, shuffle, drop_last,
                                          num_classes, collate_batch, **kwargs)
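

# The reader classes below are wired into YAML configs through @register, and
# the __shared__ entry lets num_classes be filled in from the global config.
# A hedged sketch of a corresponding config fragment (field names follow
# common PaddleDetection configs, not this file):
#
#   TrainReader:
#     sample_transforms:
#       - Decode: {}
#     batch_size: 8
#     shuffle: true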


@register
class EvalReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 num_classes=80,
                 **kwargs):
        super(EvalReader, self).__init__(sample_transforms, batch_transforms,
                                         batch_size, shuffle, drop_last,
                                         num_classes, **kwargs)


@register
class TestReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=80,
                 **kwargs):
        super(TestReader, self).__init__(sample_transforms, batch_transforms,
                                         batch_size, shuffle, drop_last,
                                         num_classes, **kwargs)


@register
class EvalMOTReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=1,
                 **kwargs):
        super(EvalMOTReader, self).__init__(sample_transforms, batch_transforms,
                                            batch_size, shuffle, drop_last,
                                            num_classes, **kwargs)


@register
class TestMOTReader(BaseDataLoader):
    __shared__ = ['num_classes']

    def __init__(self,
                 sample_transforms=[],
                 batch_transforms=[],
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 num_classes=1,
                 **kwargs):
        super(TestMOTReader, self).__init__(sample_transforms, batch_transforms,
                                            batch_size, shuffle, drop_last,
                                            num_classes, **kwargs)