diff piecrust/workerpool.py @ 447:aefe70229fdd

bake: Commonize worker pool code between HTML and asset baking. The `workerpool` module now defines a generic-ish worker pool. It's similar to the standard library's `multiprocessing.Pool`, but with a simpler use-case (only one way to queue jobs) and with support for workers sending a final "report" back to the master process, which we use here to gather timing information. The rest of the changes remove a bunch of duplicated code that's no longer needed. (A hypothetical usage sketch follows the diff.)
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 05 Jul 2015 00:09:41 -0700
parents
children 838f3964f400
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/piecrust/workerpool.py	Sun Jul 05 00:09:41 2015 -0700
@@ -0,0 +1,267 @@
+import os
+import sys
+import logging
+import threading
+import multiprocessing
+import multiprocessing.pool
+
+
+logger = logging.getLogger(__name__)
+
+
+class IWorker(object):
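+    """ Interface for a pool worker: instantiated and initialized once in
+        each child process, fed jobs one at a time, and asked for a final
+        report when the pool is closed. """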
+    def initialize(self):
+        raise NotImplementedError()
+
+    def process(self, job):
+        raise NotImplementedError()
+
+    def getReport(self):
+        return None
+
+
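+# Types of tasks sent down the task queue: normal jobs, and the sentinel
+# telling a worker to send its report back and exit.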
+TASK_JOB = 0
+TASK_END = 1
+
+
+def worker_func(params):
+    if params.is_profiling:
+        try:
+            import cProfile as profile
+        except ImportError:
+            import profile
+
+        params.is_profiling = False
+        name = params.worker_class.__name__
+        profile.runctx('_real_worker_func(params)',
+                       globals(), locals(),
+                       filename='%s-%d.prof' % (name, params.wid))
+    else:
+        _real_worker_func(params)
+
+
+def _real_worker_func(params):
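+    # If the queues expose their underlying pipe connections, close the
+    # ends this child process doesn't use so EOF can be detected.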
+    if hasattr(params.inqueue, '_writer'):
+        params.inqueue._writer.close()
+        params.outqueue._reader.close()
+
+    wid = params.wid
+    logger.debug("Worker %d initializing..." % wid)
+
+    w = params.worker_class(*params.initargs)
+    w.wid = wid
+    w.initialize()
+
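+    # Cache the queue accessors as locals to avoid repeated attribute
+    # lookups in the loop below.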
+    get = params.inqueue.get
+    put = params.outqueue.put
+
+    completed = 0
+    while True:
+        try:
+            task = get()
+        except (EOFError, OSError):
+            logger.debug("Worker %d encountered connection problem." % wid)
+            break
+
+        task_type, task_data = task
+        if task_type == TASK_END:
+            logger.debug("Worker %d got end task, exiting." % wid)
+            try:
+                rep = (task_type, True, wid, (wid, w.getReport()))
+            except Exception as e:
+                if params.wrap_exception:
+                    # Note: `ExceptionWithTraceback` is CPython's private
+                    # helper in `multiprocessing.pool` (3.4+); it isn't an
+                    # attribute of the top-level `multiprocessing` package.
+                    e = multiprocessing.pool.ExceptionWithTraceback(
+                            e, e.__traceback__)
+                rep = (task_type, False, wid, (wid, e))
+            put(rep)
+            break
+
+        try:
+            res = (task_type, True, wid, w.process(task_data))
+        except Exception as e:
+            if params.wrap_exception:
+                e = multiprocessing.pool.ExceptionWithTraceback(
+                        e, e.__traceback__)
+            res = (task_type, False, wid, e)
+        put(res)
+
+        completed += 1
+
+    logger.debug("Worker %d completed %d tasks." % (wid, completed))
+
+
+class _WorkerParams(object):
+    def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
+                 wrap_exception=False, is_profiling=False):
+        self.wid = wid
+        self.inqueue = inqueue
+        self.outqueue = outqueue
+        self.worker_class = worker_class
+        self.initargs = initargs
+        self.wrap_exception = wrap_exception
+        self.is_profiling = is_profiling
+
+
+class WorkerPool(object):
+    def __init__(self, worker_class, worker_count=None, initargs=()):
+        worker_count = worker_count or os.cpu_count() or 1
+
+        self._task_queue = multiprocessing.SimpleQueue()
+        self._result_queue = multiprocessing.SimpleQueue()
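+        # Talk to the queues' underlying pipe connections directly,
+        # skipping SimpleQueue's locking -- only one thread on each side
+        # of the master ever uses them (the same trick is used inside
+        # multiprocessing.Pool).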
+        self._quick_put = self._task_queue._writer.send
+        self._quick_get = self._result_queue._reader.recv
+
+        self._callback = None
+        self._error_callback = None
+        self._listener = None
+
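+        # If the master process is itself running under cProfile/profile,
+        # have each worker profile its own execution too (see worker_func).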
+        main_module = sys.modules['__main__']
+        # `__main__` may have no `__file__` (e.g. interactive sessions).
+        main_file = getattr(main_module, '__file__', '')
+        is_profiling = os.path.basename(main_file) in [
+                'profile.py', 'cProfile.py']
+
+        self._pool = []
+        for i in range(worker_count):
+            worker_params = _WorkerParams(
+                    i, self._task_queue, self._result_queue,
+                    worker_class, initargs,
+                    is_profiling=is_profiling)
+            w = multiprocessing.Process(target=worker_func,
+                                        args=(worker_params,))
+            w.name = w.name.replace('Process', 'PoolWorker')
+            w.daemon = True
+            w.start()
+            self._pool.append(w)
+
+        self._result_handler = threading.Thread(
+                target=WorkerPool._handleResults,
+                args=(self,))
+        self._result_handler.daemon = True
+        self._result_handler.start()
+
+        self._closed = False
+
+    def setHandler(self, callback=None, error_callback=None):
+        self._callback = callback
+        self._error_callback = error_callback
+
+    def queueJobs(self, jobs, handler=None):
+        if self._closed:
+            raise Exception("This worker pool has been closed.")
+        if self._listener is not None:
+            raise Exception("A previous job queue has not finished yet.")
+
+        if handler is not None:
+            self.setHandler(handler)
+
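+        # Materialize iterators so we know how many results to expect.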
+        if not hasattr(jobs, '__len__'):
+            jobs = list(jobs)
+
+        res = AsyncResult(self, len(jobs))
+        if res._count == 0:
+            res._event.set()
+            return res
+
+        self._listener = res
+        for job in jobs:
+            self._quick_put((TASK_JOB, job))
+
+        return res
+
+    def close(self):
+        if self._listener is not None:
+            raise Exception("A previous job queue has not finished yet.")
+
+        logger.debug("Closing worker pool...")
+        handler = _ReportHandler(len(self._pool))
+        self._callback = handler._handle
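+        # Send one end-task per worker; in response, each worker sends
+        # back its report and exits.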
+        for w in self._pool:
+            self._quick_put((TASK_END, None))
+        for w in self._pool:
+            w.join()
+
+        logger.debug("Waiting for reports...")
+        if not handler.wait(2):
+            missing = [str(i) for i, r in enumerate(handler.reports)
+                       if r is None]
+            logger.warning(
+                    "Didn't receive all worker reports before timeout. "
+                    "Missing reports from workers: %s." %
+                    ', '.join(missing))
+
+        logger.debug("Exiting result handler thread...")
+        self._result_queue.put(None)
+        self._result_handler.join()
+        self._closed = True
+
+        return handler.reports
+
+    @staticmethod
+    def _handleResults(pool):
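+        # Runs on the dedicated result-handler thread; a `None` pushed
+        # into the result queue tells it to exit.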
+        while True:
+            try:
+                res = pool._quick_get()
+            except (EOFError, OSError):
+                logger.debug("Result handler thread encountered connection "
+                             "problem, exiting.")
+                return
+
+            if res is None:
+                logger.debug("Result handler exiting.")
+                break
+
+            task_type, success, wid, data = res
+            try:
+                if success and pool._callback:
+                    pool._callback(data)
+                elif not success and pool._error_callback:
+                    pool._error_callback(data)
+            except Exception as ex:
+                logger.exception(ex)
+
+            if task_type == TASK_JOB:
+                pool._listener._onTaskDone()
+
+
+class AsyncResult(object):
+    def __init__(self, pool, count):
+        self._pool = pool
+        self._count = count
+        self._event = threading.Event()
+
+    def ready(self):
+        return self._event.is_set()
+
+    def wait(self, timeout=None):
+        return self._event.wait(timeout)
+
+    def _onTaskDone(self):
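+        # Called from the result-handler thread for each completed job;
+        # no lock is needed since that's the only thread mutating _count.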
+        self._count -= 1
+        if self._count == 0:
+            self._pool.setHandler(None)
+            self._pool._listener = None
+            self._event.set()
+
+
+class _ReportHandler(object):
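+    """ Collects the final report each worker sends back when the pool
+        is closed. """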
+    def __init__(self, worker_count):
+        self.reports = [None] * worker_count
+        self._count = worker_count
+        self._received = 0
+        self._event = threading.Event()
+
+    def wait(self, timeout=None):
+        return self._event.wait(timeout)
+
+    def _handle(self, res):
+        wid, data = res
+        if wid < 0 or wid >= self._count:
+            logger.error("Ignoring report from unknown worker %d." % wid)
+            return
+
+        self._received += 1
+        self.reports[wid] = data
+
+        if self._received == self._count:
+            self._event.set()
+
+    def _handleError(self, res):
+        wid, data = res
+        logger.error("Worker %d failed to send its report." % wid)
+        logger.exception(data)
+
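For context, here is a minimal usage sketch of the new API. The `WordCounter` class and the job payloads are hypothetical, made up for illustration; only `IWorker`, `WorkerPool`, and their methods come from this changeset:

    from piecrust.workerpool import IWorker, WorkerPool

    class WordCounter(IWorker):
        def __init__(self, separator):
            # Receives the pool's `initargs` in each child process.
            self.separator = separator

        def initialize(self):
            # Runs once per child process; `wid` is set by the pool.
            print("Worker %d ready" % self.wid)

        def process(self, job):
            # The return value is handed to the master's callback.
            return len(job.split(self.separator))

        def getReport(self):
            # Sent back to the master when the pool is closed.
            return {'worker': self.wid}

    pool = WorkerPool(WordCounter, worker_count=2, initargs=(' ',))
    pool.setHandler(callback=lambda count: print("counted", count))
    result = pool.queueJobs(['one two three', 'four five'])
    result.wait()
    reports = pool.close()  # one report per worker, e.g. {'worker': 0}

Note that callbacks run on the pool's result-handler thread, not on the thread that queued the jobs.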