123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310 |
- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # function:
- # transform samples in 'source' using 'worker'
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import sys
- import six
- if six.PY3:
- from queue import Empty
- else:
- from Queue import Empty
- import uuid
- import logging
- import signal
- import threading
# PID of the process that imported this module; _term_workers compares the
# current PID against it so only that process signals the subprocesses.
main_pid = os.getpid()

# Worker processes created by ParallelMap in multi-process mode; read by
# _term_workers to find the subprocesses to interrupt.
worker_set = set()

# Logger named after this module.
logger = logging.getLogger(__name__)
class EndSignal(object):
    """Marker object sent through the queues to tell a worker to exit.

    Attributes:
        id: identifier of whoever created the signal
        errno: error code, 0 by default
        errmsg: text describing why the signal was sent, '' by default
    """

    def __init__(self, id, errno=0, errmsg=''):
        # plain attributes: the creator may overwrite errno/errmsg later
        self.id, self.errno, self.errmsg = id, errno, errmsg
class ParallelMap(object):
    """
    Transform samples to mapped samples which is similar to
    'basic.MappedDataset', but multiple workers (threads or processes)
    will be used

    One producer thread reads samples from 'source' into an input queue,
    'worker_num' consumers apply 'worker' to them and put the results into
    an output queue which is read by 'next'.

    Args:
        source: sample source; this class calls its 'next', 'reset' and
            'drained' methods
        worker (callable): function applied to every sample
        worker_num (int): number of consumer threads/processes
        bufsize (int): capacity of the input and output queues
        use_process (bool): use processes instead of threads as consumers
            (threads are always used on Windows)
        memsize (str|int): memory size handed to the shared queues in
            multi-process mode, either a string ending with 'G'/'M'
            (e.g. '3G') or a plain number which is used as is

    Notes:
        this class is not thread-safe
    """

    def __init__(self,
                 source,
                 worker,
                 worker_num,
                 bufsize=100,
                 use_process=False,
                 memsize='3G'):
        self._worker_num = worker_num
        self._bufsize = bufsize
        self._use_process = use_process
        if self._use_process and sys.platform == "win32":
            logger.debug("Use multi-thread reader instead of "
                         "multi-process reader on Windows.")
            self._use_process = False
        if self._use_process:
            # always defined in multi-process mode, '_setup' reads it
            self._memsize = self._parse_memsize(memsize)

        self._started = False
        self._source = source
        self._worker = worker
        self._exit = False
        self._souce_drained = False
        self._setup()

    @staticmethod
    def _parse_memsize(memsize):
        """Convert 'memsize' to a number.

        Strings must end with 'G'/'g' or 'M'/'m' and are scaled by
        1024**3 or 1024**2; any other value is converted with int().

        Raises:
            AssertionError: if a string does not end with a valid suffix
        """
        if not isinstance(memsize, str):
            return int(memsize)
        assert memsize and memsize[-1].lower() in ['g', 'm'], \
            "invalid param for memsize[%s], should be " \
            "ended with 'G' or 'g' or 'M' or 'm'" % (memsize)
        power = 3 if memsize[-1].lower() == 'g' else 2
        return int(memsize[:-1]) * (1024**power)

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def _setup(self):
        """setup input/output queues and workers """
        use_process = self._use_process
        bufsize = self._bufsize
        if use_process:
            from .shared_queue import SharedQueue as Queue
            from multiprocessing import Process as Worker
            from multiprocessing import Event
            memsize = self._memsize
            self._inq = Queue(bufsize, memsize=memsize)
            self._outq = Queue(bufsize, memsize=memsize)
        else:
            try:
                from queue import Queue
            except ImportError:  # Python 2
                from Queue import Queue
            from threading import Thread as Worker
            from threading import Event
            self._inq = Queue(bufsize)
            self._outq = Queue(bufsize)

        # short random tag shared by the names of this instance's workers
        uid = str(uuid.uuid4())[-3:]
        # the producer is always a thread, even in multi-process mode
        self._producer = threading.Thread(
            target=self._produce,
            args=('producer-' + uid, self._source, self._inq))
        self._producer.daemon = True

        self._consumers = []
        # end signals received so far from consumers, keyed by consumer id
        self._consumer_endsig = {}
        global worker_set
        for i in range(self._worker_num):
            consumer_id = 'consumer-' + uid + '-' + str(i)
            p = Worker(
                target=self._consume,
                args=(consumer_id, self._inq, self._outq, self._worker))
            self._consumers.append(p)
            p.daemon = True
            p.id = consumer_id
            if use_process:
                # registered so the module-level SIGINT handler can reach it
                worker_set.add(p)

        # -1 means workers have not been started yet, see 'reset'
        self._epoch = -1
        self._feeding_ev = Event()
        self._produced = 0  # produced sample in self._produce
        self._consumed = 0  # consumed sample in self.next

    def _produce(self, id, source, inq):
        """Fetch data from source and feed it to 'inq' queue"""
        endsig = EndSignal(id)
        while True:
            # blocks until 'reset' or 'stop' sets the event
            self._feeding_ev.wait()
            if self._exit:
                break
            try:
                s = source.next()
                inq.put(s)
                self._produced += 1
            except StopIteration:
                # source exhausted: pause until the event is set again
                self._souce_drained = True
                self._feeding_ev.clear()
                self._feeding_ev.wait()
            except Exception as e:
                endsig.errno = -1
                endsig.errmsg = "producer[{}] failed with error: {}" \
                    .format(id, str(e))
                inq.put(endsig)
                break

    def _consume(self, id, inq, outq, worker):
        """Fetch data from 'inq', process it and put result to 'outq'"""
        if self._use_process:
            # handle SIGTERM signal to exit to prevent print stack frame
            signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())

        endsig = EndSignal(id)
        while True:
            sample = inq.get()
            if isinstance(sample, EndSignal):
                # forward the reason under this consumer's own id and exit
                endsig.errno = sample.errno
                endsig.errmsg = "consumer[{}] exits for reason[{}]" \
                    .format(id, sample.errmsg)
                outq.put(endsig)
                break

            try:
                result = worker(sample)
                outq.put(result)
            except Exception as e:
                endsig.errno = -2
                endsig.errmsg = "consumer[{}] failed to map with error:[{}]" \
                    .format(id, str(e))
                outq.put(endsig)
                break

    def drained(self):
        """Return True if the source is drained and every produced sample
        has been consumed; only valid after the first epoch has started
        """
        assert self._epoch >= 0, "first epoch has not started yet"
        return self._source.drained() and self._produced == self._consumed

    def stop(self):
        """ notify to exit
        """
        self._exit = True
        # wake up the producer so that it can see '_exit'
        self._feeding_ev.set()
        for _ in range(len(self._consumers)):
            # the message must be passed as 'errmsg'; the second positional
            # parameter of EndSignal is 'errno'
            self._inq.put(EndSignal(0, errmsg="notify consumers to exit"))

    def _consumer_healthy(self):
        """Return True if no consumer has died without sending an end
        signal; every such consumer is logged as a warning
        """
        abnormal_num = 0
        for w in self._consumers:
            if not w.is_alive() and w.id not in self._consumer_endsig:
                abnormal_num += 1
                if self._use_process:
                    errmsg = "consumer[{}] exit abnormally with exitcode[{}]" \
                        .format(w.pid, w.exitcode)
                else:
                    errmsg = "consumer[{}] exit abnormally".format(w.ident)
                logger.warning(errmsg)

        if abnormal_num > 0:
            logger.warning("{} consumers have exited abnormally!!!" \
                .format(abnormal_num))

        return abnormal_num == 0

    def next(self):
        """ get next transformed sample

        Raises:
            StopIteration: if the epoch is drained, the dataset was
                stopped, a consumer died abnormally or all consumers
                have exited
        """
        if self._epoch < 0:
            self.reset()

        if self.drained():
            raise StopIteration()

        while not self._exit:
            try:
                sample = self._outq.get(timeout=3)
            except Empty:
                # nothing arrived in time: give up only if a consumer died
                if not self._consumer_healthy():
                    raise StopIteration()
                continue

            if not isinstance(sample, EndSignal):
                self._consumed += 1
                return sample

            self._consumer_endsig[sample.id] = sample
            logger.warning("recv endsignal from outq with errmsg[{}]" \
                .format(sample.errmsg))

            if len(self._consumer_endsig) < len(self._consumers):
                # pass the signal on so the remaining consumers exit too
                self._inq.put(sample)
            else:
                self._exit = True
                raise StopIteration("all consumers exited, no more samples")

        raise StopIteration()

    def reset(self):
        """ reset for a new epoch of samples
        """
        assert not self._exit, "cannot reset for already stopped dataset"

        if self._epoch < 0:
            # first call: start the consumers and the producer
            self._epoch = 0
            for w in self._consumers:
                w.start()
            self._producer.start()
        else:
            assert self._consumer_healthy(), "cannot start another pass of data" \
                " for some consumers exited abnormally before!!!"

            if not self.drained():
                logger.warning("reset before epoch[{}] finishes".format(
                    self._epoch))
                # keep counting the samples that are still unconsumed
                self._produced = self._produced - self._consumed
            else:
                self._produced = 0

            self._epoch += 1

        assert len(self._consumer_endsig) == 0, "some consumers already exited," \
            + " cannot start another epoch"

        self._source.reset()
        self._souce_drained = False
        self._consumed = 0
        # let the producer start (or resume) feeding
        self._feeding_ev.set()
# FIXME: replace this if there is a better implementation
# exit on SIGTERM so that a terminated reader process does not print a
# stack frame
def _exit_on_sigterm(signum, frame):
    sys.exit()


signal.signal(signal.SIGTERM, _exit_on_sigterm)
# FIXME(dkp): KeyboardInterrupt should be handled inside ParallelMap
# and do such as: 1. exit workers 2. close queues 3. release shared
# memory. Handling KeyboardInterrupt with a global signal.SIGINT handler
# here is a hack and should be refined later
def _term_workers(sig_num, frame):
    """SIGINT handler: send SIGINT to every registered worker process and
    exit; does nothing in a process other than the one recorded in
    'main_pid'

    Args:
        sig_num: signal number supplied by the signal module (unused)
        frame: stack frame supplied by the signal module (unused)
    """
    # only do subprocess killing in main process
    if os.getpid() != main_pid:
        return

    logger.info("KeyboardInterrupt: main proc {} exit, kill subprocess {}" \
                .format(os.getpid(), [w.pid for w in worker_set]))
    for w in worker_set:
        if w.pid is None:
            continue
        try:
            os.kill(w.pid, signal.SIGINT)
        except OSError as e:
            # the worker may already be gone; keep signalling the others
            # and still reach sys.exit() below
            logger.debug("failed to signal subprocess {}: {}".format(
                w.pid, e))
    sys.exit()


signal.signal(signal.SIGINT, _term_workers)
|