parallel_map.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Function:
# transform samples read from 'source' with 'worker', using multiple
# threads or processes in parallel
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import os
  20. import sys
  21. import six
  22. if six.PY3:
  23. from queue import Empty
  24. else:
  25. from Queue import Empty
  26. import uuid
  27. import logging
  28. import signal
  29. import threading
  30. logger = logging.getLogger(__name__)
  31. main_pid = os.getpid()
  32. worker_set = set()
  33. class EndSignal(object):
  34. """ signal used to notify worker to exit
  35. """
  36. def __init__(self, id, errno=0, errmsg=''):
  37. self.id = id
  38. self.errno = errno
  39. self.errmsg = errmsg
  40. class ParallelMap(object):
  41. """
  42. Transform samples to mapped samples which is similar to
  43. 'basic.MappedDataset', but multiple workers (threads or processes)
  44. will be used
  45. Notes:
  46. this class is not thread-safe
  47. """
  48. def __init__(self,
  49. source,
  50. worker,
  51. worker_num,
  52. bufsize=100,
  53. use_process=False,
  54. memsize='3G'):
  55. self._worker_num = worker_num
  56. self._bufsize = bufsize
  57. self._use_process = use_process
  58. if self._use_process and sys.platform == "win32":
  59. logger.debug("Use multi-thread reader instead of "
  60. "multi-process reader on Windows.")
  61. self._use_process = False
  62. if self._use_process and type(memsize) is str:
  63. assert memsize[-1].lower() in ['g', 'm'], \
  64. "invalid param for memsize[%s], should be " \
  65. "ended with 'G' or 'g' or 'M' or 'm'" % (memsize)
  66. power = 3 if memsize[-1].lower() == 'g' else 2
  67. self._memsize = int(memsize[:-1]) * (1024**power)
  68. self._started = False
  69. self._source = source
  70. self._worker = worker
  71. self._exit = False
  72. self._setup()
  73. self._souce_drained = False
  74. def __iter__(self):
  75. return self
  76. def __next__(self):
  77. return self.next()
  78. def _setup(self):
  79. """setup input/output queues and workers """
  80. use_process = self._use_process
  81. bufsize = self._bufsize
  82. if use_process:
  83. from .shared_queue import SharedQueue as Queue
  84. from multiprocessing import Process as Worker
  85. from multiprocessing import Event
  86. memsize = self._memsize
  87. self._inq = Queue(bufsize, memsize=memsize)
  88. self._outq = Queue(bufsize, memsize=memsize)
  89. else:
  90. if six.PY3:
  91. from queue import Queue
  92. else:
  93. from Queue import Queue
  94. from threading import Thread as Worker
  95. from threading import Event
  96. self._inq = Queue(bufsize)
  97. self._outq = Queue(bufsize)
  98. consumer_num = self._worker_num
  99. id = str(uuid.uuid4())[-3:]
  100. self._producer = threading.Thread(
  101. target=self._produce,
  102. args=('producer-' + id, self._source, self._inq))
  103. self._producer.daemon = True
  104. self._consumers = []
  105. self._consumer_endsig = {}
  106. global worker_set
  107. for i in range(consumer_num):
  108. consumer_id = 'consumer-' + id + '-' + str(i)
  109. p = Worker(
  110. target=self._consume,
  111. args=(consumer_id, self._inq, self._outq, self._worker))
  112. self._consumers.append(p)
  113. p.daemon = True
  114. setattr(p, 'id', consumer_id)
  115. if use_process:
  116. worker_set.add(p)
  117. self._epoch = -1
  118. self._feeding_ev = Event()
  119. self._produced = 0 # produced sample in self._produce
  120. self._consumed = 0 # consumed sample in self.next
  121. def _produce(self, id, source, inq):
  122. """Fetch data from source and feed it to 'inq' queue"""
  123. endsig = EndSignal(id)
  124. while True:
  125. self._feeding_ev.wait()
  126. if self._exit:
  127. break
  128. try:
  129. s = source.next()
  130. inq.put(s)
  131. self._produced += 1
  132. except StopIteration:
  133. self._souce_drained = True
  134. self._feeding_ev.clear()
  135. self._feeding_ev.wait()
  136. except Exception as e:
  137. endsig.errno = -1
  138. endsig.errmsg = "producer[{}] failed with error: {}" \
  139. .format(id, str(e))
  140. inq.put(endsig)
  141. break
  142. def _consume(self, id, inq, outq, worker):
  143. """Fetch data from 'inq', process it and put result to 'outq'"""
  144. if self._use_process:
  145. # handle SIGTERM signal to exit to prevent print stack frame
  146. signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
  147. endsig = EndSignal(id)
  148. while True:
  149. sample = inq.get()
  150. if isinstance(sample, EndSignal):
  151. endsig.errno = sample.errno
  152. endsig.errmsg = "consumer[{}] exits for reason[{}]" \
  153. .format(id, sample.errmsg)
  154. outq.put(endsig)
  155. break
  156. try:
  157. result = worker(sample)
  158. outq.put(result)
  159. except Exception as e:
  160. endsig.errno = -2
  161. endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
  162. .format(id, str(e))
  163. outq.put(endsig)
  164. break
  165. def drained(self):
  166. assert self._epoch >= 0, "first epoch has not started yet"
  167. return self._source.drained() and self._produced == self._consumed
  168. def stop(self):
  169. """ notify to exit
  170. """
  171. self._exit = True
  172. self._feeding_ev.set()
  173. for _ in range(len(self._consumers)):
  174. self._inq.put(EndSignal(0, "notify consumers to exit"))
  175. def _consumer_healthy(self):
  176. abnormal_num = 0
  177. for w in self._consumers:
  178. if not w.is_alive() and w.id not in self._consumer_endsig:
  179. abnormal_num += 1
  180. if self._use_process:
  181. errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
  182. .format(w.pid, w.exitcode)
  183. else:
  184. errmsg = "consumer[{}] exit abnormally".format(w.ident)
  185. logger.warning(errmsg)
  186. if abnormal_num > 0:
  187. logger.warning("{} consumers have exited abnormally!!!" \
  188. .format(abnormal_num))
  189. return abnormal_num == 0
  190. def next(self):
  191. """ get next transformed sample
  192. """
  193. if self._epoch < 0:
  194. self.reset()
  195. if self.drained():
  196. raise StopIteration()
  197. while not self._exit:
  198. try:
  199. sample = self._outq.get(timeout=3)
  200. except Empty as e:
  201. if not self._consumer_healthy():
  202. raise StopIteration()
  203. else:
  204. continue
  205. if isinstance(sample, EndSignal):
  206. self._consumer_endsig[sample.id] = sample
  207. logger.warning("recv endsignal from outq with errmsg[{}]" \
  208. .format(sample.errmsg))
  209. if len(self._consumer_endsig.keys()) < len(self._consumers):
  210. self._inq.put(sample)
  211. else:
  212. self._exit = True
  213. raise StopIteration("all consumers exited, no more samples")
  214. else:
  215. self._consumed += 1
  216. return sample
  217. raise StopIteration()
  218. def reset(self):
  219. """ reset for a new epoch of samples
  220. """
  221. assert not self._exit, "cannot reset for already stopped dataset"
  222. if self._epoch < 0:
  223. self._epoch = 0
  224. for w in self._consumers:
  225. w.start()
  226. self._producer.start()
  227. else:
  228. assert self._consumer_healthy(), "cannot start another pass of data" \
  229. " for some consumers exited abnormally before!!!"
  230. if not self.drained():
  231. logger.warning("reset before epoch[{}] finishes".format(
  232. self._epoch))
  233. self._produced = self._produced - self._consumed
  234. else:
  235. self._produced = 0
  236. self._epoch += 1
  237. assert len(self._consumer_endsig.keys()) == 0, "some consumers already exited," \
  238. + " cannot start another epoch"
  239. self._source.reset()
  240. self._souce_drained = False
  241. self._consumed = 0
  242. self._feeding_ev.set()
  243. # FIXME: fix me if you have better impliment
  244. # handle terminate reader process, do not print stack frame
  245. signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
  246. # FIXME(dkp): KeyboardInterrupt should be handled inside ParallelMap
  247. # and do such as: 1. exit workers 2. close queues 3. release shared
  248. # memory, HACK KeyboardInterrupt with global signal.SIGINT handler
  249. # here, should be refined later
  250. def _term_workers(sig_num, frame):
  251. global worker_set, main_pid
  252. # only do subporcess killing in main process
  253. if os.getpid() != main_pid:
  254. return
  255. logger.info("KeyboardInterrupt: main proc {} exit, kill subprocess {}" \
  256. .format(os.getpid(), [w.pid for w in worker_set]))
  257. for w in worker_set:
  258. if w.pid is not None:
  259. os.kill(w.pid, signal.SIGINT)
  260. sys.exit()
  261. signal.signal(signal.SIGINT, _term_workers)