queue.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import six
  19. if six.PY3:
  20. import pickle
  21. from io import BytesIO as StringIO
  22. from queue import Empty
  23. else:
  24. import cPickle as pickle
  25. from cStringIO import StringIO
  26. from Queue import Empty
  27. import logging
  28. import traceback
  29. import multiprocessing as mp
  30. from multiprocessing.queues import Queue
  31. from .sharedmemory import SharedMemoryMgr
  32. logger = logging.getLogger(__name__)
  33. class SharedQueueError(ValueError):
  34. """ SharedQueueError
  35. """
  36. pass
  37. class SharedQueue(Queue):
  38. """ a Queue based on shared memory to communicate data between Process,
  39. and it's interface is compatible with 'multiprocessing.queues.Queue'
  40. """
  41. def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None):
  42. """ init
  43. """
  44. if six.PY3:
  45. super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context())
  46. else:
  47. super(SharedQueue, self).__init__(maxsize)
  48. if mem_mgr is not None:
  49. self._shared_mem = mem_mgr
  50. else:
  51. self._shared_mem = SharedMemoryMgr(
  52. capacity=memsize, pagesize=pagesize)
  53. def put(self, obj, **kwargs):
  54. """ put an object to this queue
  55. """
  56. obj = pickle.dumps(obj, -1)
  57. buff = None
  58. try:
  59. buff = self._shared_mem.malloc(len(obj))
  60. buff.put(obj)
  61. super(SharedQueue, self).put(buff, **kwargs)
  62. except Exception as e:
  63. stack_info = traceback.format_exc()
  64. err_msg = 'failed to put a element to SharedQueue '\
  65. 'with stack info[%s]' % (stack_info)
  66. logger.warning(err_msg)
  67. if buff is not None:
  68. buff.free()
  69. raise e
  70. def get(self, **kwargs):
  71. """ get an object from this queue
  72. """
  73. buff = None
  74. try:
  75. buff = super(SharedQueue, self).get(**kwargs)
  76. data = buff.get()
  77. return pickle.load(StringIO(data))
  78. except Empty as e:
  79. raise e
  80. except Exception as e:
  81. stack_info = traceback.format_exc()
  82. err_msg = 'failed to get element from SharedQueue '\
  83. 'with stack info[%s]' % (stack_info)
  84. logger.warning(err_msg)
  85. raise e
  86. finally:
  87. if buff is not None:
  88. buff.free()
  89. def release(self):
  90. self._shared_mem.release()
  91. self._shared_mem = None