changeset 447:aefe70229fdd

bake: Commonize worker pool code between html and asset baking.

The new `workerpool` module defines a generic-ish worker pool. It's similar to Python's built-in multiprocessing pool, but with a simpler use-case (only one way to queue jobs) and support for workers to send a final "report" back to the master process, which we use here to collect timing information. The rest of the changes remove the duplicated pool-management code that's no longer needed.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 05 Jul 2015 00:09:41 -0700
parents 4cdf6c2157a0
children a17774094db8
files piecrust/baking/baker.py piecrust/baking/worker.py piecrust/processing/pipeline.py piecrust/processing/worker.py piecrust/workerpool.py
diffstat 5 files changed, 448 insertions(+), 425 deletions(-)
--- a/piecrust/baking/baker.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/baking/baker.py	Sun Jul 05 00:09:41 2015 -0700
@@ -103,17 +103,13 @@
         # Bake taxonomies.
         self._bakeTaxonomies(record, pool)
 
-        # All done with the workers.
-        self._terminateWorkerPool(pool)
-
-        # Get the timing information from the workers.
+        # All done with the workers. Close the pool and get timing reports.
+        reports = pool.close()
         record.current.timers = {}
-        for i in range(len(pool.workers)):
-            try:
-                timers = pool.results.get(True, 0.1)
-            except queue.Empty:
-                logger.error("Didn't get timing information from all workers.")
-                break
+        for i in range(len(reports)):
+            timers = reports[i]
+            if timers is None:
+                continue
 
             worker_name = 'BakeWorker_%d' % i
             record.current.timers[worker_name] = {}
@@ -214,41 +210,43 @@
                     (page_count, REALM_NAMES[realm].lower())))
 
     def _loadRealmPages(self, record, pool, factories):
+        def _handler(res):
+            # Create the record entry for this page.
+            record_entry = BakeRecordEntry(res.source_name, res.path)
+            record_entry.config = res.config
+            if res.errors:
+                record_entry.errors += res.errors
+                record.current.success = False
+                self._logErrors(res.path, res.errors)
+            record.addEntry(record_entry)
+
         logger.debug("Loading %d realm pages..." % len(factories))
         with format_timed_scope(logger,
                                 "loaded %d pages" % len(factories),
                                 level=logging.DEBUG, colored=False,
                                 timer_env=self.app.env,
                                 timer_category='LoadJob'):
-            for fac in factories:
-                job = BakeWorkerJob(
-                        JOB_LOAD,
-                        LoadJobPayload(fac))
-                pool.queue.put_nowait(job)
-
-            def _handler(res):
-                # Create the record entry for this page.
-                record_entry = BakeRecordEntry(res.source_name, res.path)
-                record_entry.config = res.config
-                if res.errors:
-                    record_entry.errors += res.errors
-                    record.current.success = False
-                    self._logErrors(res.path, res.errors)
-                record.addEntry(record_entry)
-
-            self._waitOnWorkerPool(
-                    pool,
-                    expected_result_count=len(factories),
-                    result_handler=_handler)
+            jobs = [
+                BakeWorkerJob(JOB_LOAD, LoadJobPayload(fac))
+                for fac in factories]
+            ar = pool.queueJobs(jobs, handler=_handler)
+            ar.wait()
 
     def _renderRealmPages(self, record, pool, factories):
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path)
+            if res.errors:
+                entry.errors += res.errors
+                record.current.success = False
+                self._logErrors(res.path, res.errors)
+
         logger.debug("Rendering %d realm pages..." % len(factories))
         with format_timed_scope(logger,
                                 "prepared %d pages" % len(factories),
                                 level=logging.DEBUG, colored=False,
                                 timer_env=self.app.env,
                                 timer_category='RenderFirstSubJob'):
-            expected_result_count = 0
+            jobs = []
             for fac in factories:
                 record_entry = record.getCurrentEntry(fac.path)
                 if record_entry.errors:
@@ -278,49 +276,38 @@
                 job = BakeWorkerJob(
                         JOB_RENDER_FIRST,
                         RenderFirstSubJobPayload(fac))
-                pool.queue.put_nowait(job)
-                expected_result_count += 1
+                jobs.append(job)
 
-            def _handler(res):
-                entry = record.getCurrentEntry(res.path)
-                if res.errors:
-                    entry.errors += res.errors
-                    record.current.success = False
-                    self._logErrors(res.path, res.errors)
-
-            self._waitOnWorkerPool(
-                    pool,
-                    expected_result_count=expected_result_count,
-                    result_handler=_handler)
+            ar = pool.queueJobs(jobs, handler=_handler)
+            ar.wait()
 
     def _bakeRealmPages(self, record, pool, realm, factories):
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
+            entry.subs = res.sub_entries
+            if res.errors:
+                entry.errors += res.errors
+                self._logErrors(res.path, res.errors)
+            if entry.has_any_error:
+                record.current.success = False
+            if entry.was_any_sub_baked:
+                record.current.baked_count[realm] += 1
+                record.dirty_source_names.add(entry.source_name)
+
         logger.debug("Baking %d realm pages..." % len(factories))
         with format_timed_scope(logger,
                                 "baked %d pages" % len(factories),
                                 level=logging.DEBUG, colored=False,
                                 timer_env=self.app.env,
                                 timer_category='BakeJob'):
-            expected_result_count = 0
+            jobs = []
             for fac in factories:
-                if self._queueBakeJob(record, pool, fac):
-                    expected_result_count += 1
+                job = self._makeBakeJob(record, fac)
+                if job is not None:
+                    jobs.append(job)
 
-            def _handler(res):
-                entry = record.getCurrentEntry(res.path, res.taxonomy_info)
-                entry.subs = res.sub_entries
-                if res.errors:
-                    entry.errors += res.errors
-                    self._logErrors(res.path, res.errors)
-                if entry.has_any_error:
-                    record.current.success = False
-                if entry.was_any_sub_baked:
-                    record.current.baked_count[realm] += 1
-                    record.dirty_source_names.add(entry.source_name)
-
-            self._waitOnWorkerPool(
-                    pool,
-                    expected_result_count=expected_result_count,
-                    result_handler=_handler)
+            ar = pool.queueJobs(jobs, handler=_handler)
+            ar.wait()
 
     def _bakeTaxonomies(self, record, pool):
         logger.debug("Baking taxonomy pages...")
@@ -400,8 +387,16 @@
         return buckets
 
     def _bakeTaxonomyBuckets(self, record, pool, buckets):
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
+            entry.subs = res.sub_entries
+            if res.errors:
+                entry.errors += res.errors
+            if entry.has_any_error:
+                record.current.success = False
+
         # Start baking those terms.
-        expected_result_count = 0
+        jobs = []
         for source_name, source_taxonomies in buckets.items():
             for tax_name, tt_info in source_taxonomies.items():
                 terms = tt_info.dirty_terms
@@ -435,21 +430,12 @@
                             fac.source.name, fac.path, tax_info)
                     record.addEntry(cur_entry)
 
-                    if self._queueBakeJob(record, pool, fac, tax_info):
-                        expected_result_count += 1
+                    job = self._makeBakeJob(record, fac, tax_info)
+                    if job is not None:
+                        jobs.append(job)
 
-        def _handler(res):
-            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
-            entry.subs = res.sub_entries
-            if res.errors:
-                entry.errors += res.errors
-            if entry.has_any_error:
-                record.current.success = False
-
-        self._waitOnWorkerPool(
-                pool,
-                expected_result_count=expected_result_count,
-                result_handler=_handler)
+        ar = pool.queueJobs(jobs, handler=_handler)
+        ar.wait()
 
         # Now we create bake entries for all the terms that were *not* dirty.
         # This is because otherwise, on the next incremental bake, we wouldn't
@@ -470,9 +456,9 @@
                     logger.debug("Taxonomy term '%s:%s' isn't used anymore." %
                                  (ti.taxonomy_name, ti.term))
 
-        return expected_result_count
+        return len(jobs)
 
-    def _queueBakeJob(self, record, pool, fac, tax_info=None):
+    def _makeBakeJob(self, record, fac, tax_info=None):
         # Get the previous (if any) and current entry for this page.
         pair = record.getPreviousAndCurrentEntries(fac.path, tax_info)
         assert pair is not None
@@ -483,7 +469,7 @@
         if cur_entry.errors:
             logger.debug("Ignoring %s because it had previous "
                          "errors." % fac.ref_spec)
-            return False
+            return None
 
         # Build the route metadata and find the appropriate route.
         page = fac.buildPage()
@@ -515,15 +501,14 @@
                         (fac.ref_spec, uri, override_entry.path))
                 logger.error(cur_entry.errors[-1])
             cur_entry.flags |= BakeRecordEntry.FLAG_OVERRIDEN
-            return False
+            return None
 
         job = BakeWorkerJob(
                 JOB_BAKE,
                 BakeJobPayload(fac, route_metadata, prev_entry,
                                record.dirty_source_names,
                                tax_info))
-        pool.queue.put_nowait(job)
-        return True
+        return job
 
     def _handleDeletetions(self, record):
         logger.debug("Handling deletions...")
@@ -544,78 +529,16 @@
             logger.error("  " + e)
 
     def _createWorkerPool(self):
-        import sys
-        from piecrust.baking.worker import BakeWorkerContext, worker_func
-
-        main_module = sys.modules['__main__']
-        is_profiling = os.path.basename(main_module.__file__) in [
-                'profile.py', 'cProfile.py']
-
-        pool = _WorkerPool()
-        for i in range(self.num_workers):
-            ctx = BakeWorkerContext(
-                    self.app.root_dir, self.app.cache.base_dir, self.out_dir,
-                    pool.queue, pool.results, pool.abort_event,
-                    force=self.force, debug=self.app.debug,
-                    is_profiling=is_profiling)
-            w = multiprocessing.Process(
-                    name='BakeWorker_%d' % i,
-                    target=worker_func, args=(i, ctx))
-            w.start()
-            pool.workers.append(w)
-        return pool
-
-    def _terminateWorkerPool(self, pool):
-        pool.abort_event.set()
-        for w in pool.workers:
-            w.join()
+        from piecrust.workerpool import WorkerPool
+        from piecrust.baking.worker import BakeWorkerContext, BakeWorker
 
-    def _waitOnWorkerPool(self, pool,
-                          expected_result_count=-1, result_handler=None):
-        assert result_handler is None or expected_result_count >= 0
-        abort_with_exception = None
-        try:
-            if result_handler is None:
-                pool.queue.join()
-            else:
-                got_count = 0
-                while got_count < expected_result_count:
-                    try:
-                        res = pool.results.get(True, 10)
-                    except queue.Empty:
-                        logger.error(
-                                "Got %d results, expected %d, and timed-out "
-                                "for 10 seconds. A worker might be stuck?" %
-                                (got_count, expected_result_count))
-                        abort_with_exception = Exception("Worker time-out.")
-                        break
-
-                    if isinstance(res, dict) and res.get('type') == 'error':
-                        abort_with_exception = Exception(
-                                'Worker critical error:\n' +
-                                '\n'.join(res['messages']))
-                        break
-
-                    got_count += 1
-                    result_handler(res)
-        except KeyboardInterrupt as kiex:
-            logger.warning("Bake aborted by user... "
-                           "waiting for workers to stop.")
-            abort_with_exception = kiex
-
-        if abort_with_exception:
-            pool.abort_event.set()
-            for w in pool.workers:
-                w.join(2)
-            raise abort_with_exception
-
-
-class _WorkerPool(object):
-    def __init__(self):
-        self.queue = multiprocessing.JoinableQueue()
-        self.results = multiprocessing.Queue()
-        self.abort_event = multiprocessing.Event()
-        self.workers = []
+        ctx = BakeWorkerContext(
+                self.app.root_dir, self.app.cache.base_dir, self.out_dir,
+                force=self.force, debug=self.app.debug)
+        pool = WorkerPool(
+                worker_class=BakeWorker,
+                initargs=(ctx,))
+        return pool
 
 
 class _TaxonomyTermsInfo(object):
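
A condensed sketch of the master-side flow the baker follows after this change: create the pool once, queue a batch of jobs with a result handler, wait on the returned AsyncResult, and finally close() the pool to collect each worker's report. The helper function below and its arguments are hypothetical, and the import location of the payload class is an assumption (it isn't visible in this diff).

    from piecrust.workerpool import WorkerPool
    from piecrust.baking.worker import (
            BakeWorkerContext, BakeWorker, BakeWorkerJob, JOB_LOAD,
            LoadJobPayload)  # LoadJobPayload assumed to live with the handlers

    def load_pages(app, out_dir, factories, handler):
        # Hypothetical, condensed version of Baker.bake()/_loadRealmPages().
        ctx = BakeWorkerContext(
                app.root_dir, app.cache.base_dir, out_dir,
                force=False, debug=app.debug)
        pool = WorkerPool(worker_class=BakeWorker, initargs=(ctx,))

        jobs = [BakeWorkerJob(JOB_LOAD, LoadJobPayload(fac))
                for fac in factories]
        ar = pool.queueJobs(jobs, handler=handler)
        ar.wait()

        # close() returns one report per worker (its timers), or None for
        # any worker whose report didn't arrive before the timeout.
        reports = pool.close()
        return reports
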
--- a/piecrust/baking/worker.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/baking/worker.py	Sun Jul 05 00:09:41 2015 -0700
@@ -1,5 +1,4 @@
 import time
-import queue
 import logging
 from piecrust.app import PieCrust
 from piecrust.baking.single import PageBaker, BakingError
@@ -7,45 +6,59 @@
         QualifiedPage, PageRenderingContext, render_page_segments)
 from piecrust.routing import create_route_metadata
 from piecrust.sources.base import PageFactory
+from piecrust.workerpool import IWorker
 
 
 logger = logging.getLogger(__name__)
 
 
-def worker_func(wid, ctx):
-    if ctx.is_profiling:
-        try:
-            import cProfile as profile
-        except ImportError:
-            import profile
-
-        ctx.is_profiling = False
-        profile.runctx('_real_worker_func(wid, ctx)',
-                       globals(), locals(),
-                       filename='BakeWorker-%d.prof' % wid)
-    else:
-        _real_worker_func(wid, ctx)
-
-
-def _real_worker_func(wid, ctx):
-    logger.debug("Worker %d booting up..." % wid)
-    w = BakeWorker(wid, ctx)
-    w.run()
-
-
 class BakeWorkerContext(object):
     def __init__(self, root_dir, sub_cache_dir, out_dir,
-                 work_queue, results, abort_event,
-                 force=False, debug=False, is_profiling=False):
+                 force=False, debug=False):
         self.root_dir = root_dir
         self.sub_cache_dir = sub_cache_dir
         self.out_dir = out_dir
-        self.work_queue = work_queue
-        self.results = results
-        self.abort_event = abort_event
         self.force = force
         self.debug = debug
-        self.is_profiling = is_profiling
+
+
+class BakeWorker(IWorker):
+    def __init__(self, ctx):
+        self.ctx = ctx
+        self.work_start_time = time.perf_counter()
+
+    def initialize(self):
+        # Create the app local to this worker.
+        app = PieCrust(self.ctx.root_dir, debug=self.ctx.debug)
+        app._useSubCacheDir(self.ctx.sub_cache_dir)
+        app.env.fs_cache_only_for_main_page = True
+        app.env.registerTimer("BakeWorker_%d_Total" % self.wid)
+        app.env.registerTimer("BakeWorkerInit")
+        app.env.registerTimer("JobReceive")
+        self.app = app
+
+        # Create the job handlers.
+        job_handlers = {
+                JOB_LOAD: LoadJobHandler(app, self.ctx),
+                JOB_RENDER_FIRST: RenderFirstSubJobHandler(app, self.ctx),
+                JOB_BAKE: BakeJobHandler(app, self.ctx)}
+        for jt, jh in job_handlers.items():
+            app.env.registerTimer(type(jh).__name__)
+        self.job_handlers = job_handlers
+
+        app.env.stepTimerSince("BakeWorkerInit", self.work_start_time)
+
+    def process(self, job):
+        handler = self.job_handlers[job.job_type]
+        with self.app.env.timerScope(type(handler).__name__):
+            return handler.handleJob(job)
+
+    def getReport(self):
+        self.app.env.stepTimerSince("BakeWorker_%d_Total" % self.wid,
+                                    self.work_start_time)
+        return {
+                'type': 'timers',
+                'data': self.app.env._timers}
 
 
 JOB_LOAD, JOB_RENDER_FIRST, JOB_BAKE = range(0, 3)
@@ -57,67 +70,6 @@
         self.payload = payload
 
 
-class BakeWorker(object):
-    def __init__(self, wid, ctx):
-        self.wid = wid
-        self.ctx = ctx
-
-    def run(self):
-        logger.debug("Working %d initializing..." % self.wid)
-        work_start_time = time.perf_counter()
-
-        # Create the app local to this worker.
-        app = PieCrust(self.ctx.root_dir, debug=self.ctx.debug)
-        app._useSubCacheDir(self.ctx.sub_cache_dir)
-        app.env.fs_cache_only_for_main_page = True
-        app.env.registerTimer("BakeWorker_%d_Total" % self.wid)
-        app.env.registerTimer("BakeWorkerInit")
-        app.env.registerTimer("JobReceive")
-
-        # Create the job handlers.
-        job_handlers = {
-                JOB_LOAD: LoadJobHandler(app, self.ctx),
-                JOB_RENDER_FIRST: RenderFirstSubJobHandler(app, self.ctx),
-                JOB_BAKE: BakeJobHandler(app, self.ctx)}
-        for jt, jh in job_handlers.items():
-            app.env.registerTimer(type(jh).__name__)
-
-        app.env.stepTimerSince("BakeWorkerInit", work_start_time)
-
-        # Start working!
-        aborted_with_exception = None
-        while not self.ctx.abort_event.is_set():
-            try:
-                with app.env.timerScope('JobReceive'):
-                    job = self.ctx.work_queue.get(True, 0.01)
-            except queue.Empty:
-                continue
-
-            try:
-                handler = job_handlers[job.job_type]
-                with app.env.timerScope(type(handler).__name__):
-                    handler.handleJob(job)
-            except Exception as ex:
-                self.ctx.abort_event.set()
-                aborted_with_exception = ex
-                logger.debug("[%d] Critical error, aborting." % self.wid)
-                if self.ctx.debug:
-                    logger.exception(ex)
-                break
-            finally:
-                self.ctx.work_queue.task_done()
-
-        if aborted_with_exception is not None:
-            msgs = _get_errors(aborted_with_exception)
-            self.ctx.results.put_nowait({'type': 'error', 'messages': msgs})
-
-        # Send our timers to the main process before exiting.
-        app.env.stepTimerSince("BakeWorker_%d_Total" % self.wid,
-                               work_start_time)
-        self.ctx.results.put_nowait({
-                'type': 'timers', 'data': app.env._timers})
-
-
 class JobHandler(object):
     def __init__(self, app, ctx):
         self.app = app
@@ -203,8 +155,7 @@
             result.errors = _get_errors(ex)
             if self.ctx.debug:
                 logger.exception(ex)
-
-        self.ctx.results.put_nowait(result)
+        return result
 
 
 class RenderFirstSubJobHandler(JobHandler):
@@ -231,8 +182,7 @@
             result.errors = _get_errors(ex)
             if self.ctx.debug:
                 logger.exception(ex)
-
-        self.ctx.results.put_nowait(result)
+        return result
 
 
 class BakeJobHandler(JobHandler):
@@ -272,5 +222,5 @@
             if self.ctx.debug:
                 logger.exception(ex)
 
-        self.ctx.results.put_nowait(result)
+        return result
 
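
With the queue plumbing gone, a bake worker only implements the IWorker interface from piecrust/workerpool.py. A toy implementation, just to show the three hooks and where their return values go (the class and its job format are made up):

    from piecrust.workerpool import IWorker

    class WordCountWorker(IWorker):
        # Toy worker: jobs are plain strings, results are word counts.
        def initialize(self):
            # Runs once in the child process; the pool has already set
            # self.wid by the time this is called.
            self.jobs_done = 0

        def process(self, job):
            # Called for each queued job; the return value is sent back to
            # the master and passed to the handler given to queueJobs().
            self.jobs_done += 1
            return (job, len(job.split()))

        def getReport(self):
            # Called once at shutdown; the dict ends up in pool.close()'s
            # return value, indexed by worker id.
            return {'wid': self.wid, 'jobs_done': self.jobs_done}
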
--- a/piecrust/processing/pipeline.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/processing/pipeline.py	Sun Jul 05 00:09:41 2015 -0700
@@ -2,7 +2,6 @@
 import os.path
 import re
 import time
-import queue
 import hashlib
 import logging
 import multiprocessing
@@ -12,16 +11,16 @@
         ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord,
         FLAG_PROCESSED)
 from piecrust.processing.worker import (
-        ProcessingWorkerContext, ProcessingWorkerJob,
-        worker_func, get_filtered_processors)
+        ProcessingWorkerJob,
+        get_filtered_processors)
 
 
 logger = logging.getLogger(__name__)
 
 
 class _ProcessingContext(object):
-    def __init__(self, pool, record, base_dir, mount_info):
-        self.pool = pool
+    def __init__(self, jobs, record, base_dir, mount_info):
+        self.jobs = jobs
         self.record = record
         self.base_dir = base_dir
         self.mount_info = mount_info
@@ -94,9 +93,6 @@
         self.ignore_patterns += make_re(
                 pipeline_ctx._additional_ignore_patterns)
 
-        # Create the worker pool.
-        pool = _WorkerPool()
-
         # Create the pipeline record.
         record = TransitionalProcessorPipelineRecord()
         record_cache = self.app.cache.getCache('proc')
@@ -132,19 +128,19 @@
                 for e in entry.errors:
                     logger.error("  " + e)
 
+        jobs = []
+        self._process(src_dir_or_file, record, jobs)
         pool = self._createWorkerPool()
-        expected_result_count = self._process(src_dir_or_file, pool, record)
-        self._waitOnWorkerPool(pool, expected_result_count, _handler)
-        self._terminateWorkerPool(pool)
+        ar = pool.queueJobs(jobs, handler=_handler)
+        ar.wait()
 
-        # Get timing information from the workers.
+        # Shutdown the workers and get timing information from them.
+        reports = pool.close()
         record.current.timers = {}
-        for i in range(len(pool.workers)):
-            try:
-                timers = pool.results.get(True, 0.1)
-            except queue.Empty:
-                logger.error("Didn't get timing information from all workers.")
-                break
+        for i in range(len(reports)):
+            timers = reports[i]
+            if timers is None:
+                continue
 
             worker_name = 'PipelineWorker_%d' % i
             record.current.timers[worker_name] = {}
@@ -185,9 +181,7 @@
 
         return record.detach()
 
-    def _process(self, src_dir_or_file, pool, record):
-        expected_result_count = 0
-
+    def _process(self, src_dir_or_file, record, jobs):
         if src_dir_or_file is not None:
             # Process only the given path.
             # Find out what mount point this is in.
@@ -203,28 +197,23 @@
                                 "mount point: %s" %
                                 (src_dir_or_file, known_roots))
 
-            ctx = _ProcessingContext(pool, record, base_dir, mount_info)
+            ctx = _ProcessingContext(jobs, record, base_dir, mount_info)
             logger.debug("Initiating processing pipeline on: %s" %
                          src_dir_or_file)
             if os.path.isdir(src_dir_or_file):
-                expected_result_count = self._processDirectory(
-                        ctx, src_dir_or_file)
+                self._processDirectory(ctx, src_dir_or_file)
             elif os.path.isfile(src_dir_or_file):
                 self._processFile(ctx, src_dir_or_file)
-                expected_result_count = 1
 
         else:
             # Process everything.
             for name, info in self.mounts.items():
                 path = info['path']
-                ctx = _ProcessingContext(pool, record, path, info)
+                ctx = _ProcessingContext(jobs, record, path, info)
                 logger.debug("Initiating processing pipeline on: %s" % path)
-                expected_result_count = self._processDirectory(ctx, path)
-
-        return expected_result_count
+                self._processDirectory(ctx, path)
 
     def _processDirectory(self, ctx, start_dir):
-        queued_count = 0
         for dirpath, dirnames, filenames in os.walk(start_dir):
             rel_dirpath = os.path.relpath(dirpath, start_dir)
             dirnames[:] = [d for d in dirnames
@@ -235,8 +224,6 @@
                 if re_matchany(filename, self.ignore_patterns, rel_dirpath):
                     continue
                 self._processFile(ctx, os.path.join(dirpath, filename))
-                queued_count += 1
-        return queued_count
 
     def _processFile(self, ctx, path):
         # TODO: handle overrides between mount-points.
@@ -250,79 +237,23 @@
 
         job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path,
                                   force=force_this)
-
-        logger.debug("Queuing: %s" % path)
-        ctx.pool.queue.put_nowait(job)
+        ctx.jobs.append(job)
 
     def _createWorkerPool(self):
-        import sys
-
-        main_module = sys.modules['__main__']
-        is_profiling = os.path.basename(main_module.__file__) in [
-                'profile.py', 'cProfile.py']
-
-        pool = _WorkerPool()
-        for i in range(self.num_workers):
-            ctx = ProcessingWorkerContext(
-                    self.app.root_dir, self.out_dir, self.tmp_dir,
-                    pool.queue, pool.results, pool.abort_event,
-                    self.force, self.app.debug)
-            ctx.is_profiling = is_profiling
-            ctx.enabled_processors = self.enabled_processors
-            ctx.additional_processors = self.additional_processors
-            w = multiprocessing.Process(
-                    name='PipelineWorker_%d' % i,
-                    target=worker_func, args=(i, ctx))
-            w.start()
-            pool.workers.append(w)
-        return pool
+        from piecrust.workerpool import WorkerPool
+        from piecrust.processing.worker import (
+                ProcessingWorkerContext, ProcessingWorker)
 
-    def _waitOnWorkerPool(self, pool, expected_result_count, result_handler):
-        abort_with_exception = None
-        try:
-            got_count = 0
-            while got_count < expected_result_count:
-                try:
-                    res = pool.results.get(True, 10)
-                except queue.Empty:
-                    logger.error(
-                            "Got %d results, expected %d, and timed-out "
-                            "for 10 seconds. A worker might be stuck?" %
-                            (got_count, expected_result_count))
-                    abort_with_exception = Exception("Worker time-out.")
-                    break
-
-                if isinstance(res, dict) and res.get('type') == 'error':
-                    abort_with_exception = Exception(
-                            'Worker critical error:\n' +
-                            '\n'.join(res['messages']))
-                    break
+        ctx = ProcessingWorkerContext(
+                self.app.root_dir, self.out_dir, self.tmp_dir,
+                self.force, self.app.debug)
+        ctx.enabled_processors = self.enabled_processors
+        ctx.additional_processors = self.additional_processors
 
-                got_count += 1
-                result_handler(res)
-        except KeyboardInterrupt as kiex:
-            logger.warning("Bake aborted by user... "
-                           "waiting for workers to stop.")
-            abort_with_exception = kiex
-
-        if abort_with_exception:
-            pool.abort_event.set()
-            for w in pool.workers:
-                w.join(2)
-            raise abort_with_exception
-
-    def _terminateWorkerPool(self, pool):
-        pool.abort_event.set()
-        for w in pool.workers:
-            w.join()
-
-
-class _WorkerPool(object):
-    def __init__(self):
-        self.queue = multiprocessing.JoinableQueue()
-        self.results = multiprocessing.Queue()
-        self.abort_event = multiprocessing.Event()
-        self.workers = []
+        pool = WorkerPool(
+                worker_class=ProcessingWorker,
+                initargs=(ctx,))
+        return pool
 
 
 def make_mount_infos(mounts, root_dir):
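
The pipeline no longer feeds a queue while walking the mounts: it appends ProcessingWorkerJob objects to a plain list (threaded through _ProcessingContext.jobs) and queues the whole batch at once. A simplified sketch of that accumulate-then-queue pattern, with the ignore-pattern and record handling left out (the collect_jobs helper is made up):

    import os

    from piecrust.processing.worker import ProcessingWorkerJob

    def collect_jobs(base_dir, mount_info, force=False):
        # Simplified stand-in for _process/_processDirectory/_processFile:
        # build the whole job list up front instead of feeding a queue.
        jobs = []
        for dirpath, dirnames, filenames in os.walk(base_dir):
            for filename in filenames:
                path = os.path.join(dirpath, filename)
                jobs.append(ProcessingWorkerJob(
                        base_dir, mount_info, path, force=force))
        return jobs

    # jobs = collect_jobs(mount_path, mount_info)
    # ar = pool.queueJobs(jobs, handler=_handler)
    # ar.wait()
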
--- a/piecrust/processing/worker.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/processing/worker.py	Sun Jul 05 00:09:41 2015 -0700
@@ -1,7 +1,6 @@
+import re
 import os.path
-import re
 import time
-import queue
 import logging
 from piecrust.app import PieCrust
 from piecrust.processing.base import PipelineContext
@@ -13,6 +12,7 @@
         ProcessingTreeError, ProcessorError,
         get_node_name_tree, print_node,
         STATE_DIRTY)
+from piecrust.workerpool import IWorker
 
 
 logger = logging.getLogger(__name__)
@@ -22,37 +22,12 @@
 re_ansicolors = re.compile('\033\\[\d+m')
 
 
-def worker_func(wid, ctx):
-    if ctx.is_profiling:
-        try:
-            import cProfile as profile
-        except ImportError:
-            import profile
-
-        ctx.is_profiling = False
-        profile.runctx('_real_worker_func(wid, ctx)',
-                       globals(), locals(),
-                       filename='PipelineWorker-%d.prof' % wid)
-    else:
-        _real_worker_func(wid, ctx)
-
-
-def _real_worker_func(wid, ctx):
-    logger.debug("Worker %d booting up..." % wid)
-    w = ProcessingWorker(wid, ctx)
-    w.run()
-
-
 class ProcessingWorkerContext(object):
     def __init__(self, root_dir, out_dir, tmp_dir,
-                 work_queue, results, abort_event,
                  force=False, debug=False):
         self.root_dir = root_dir
         self.out_dir = out_dir
         self.tmp_dir = tmp_dir
-        self.work_queue = work_queue
-        self.results = results
-        self.abort_event = abort_event
         self.force = force
         self.debug = debug
         self.is_profiling = False
@@ -77,23 +52,20 @@
         self.errors = None
 
 
-class ProcessingWorker(object):
-    def __init__(self, wid, ctx):
-        self.wid = wid
+class ProcessingWorker(IWorker):
+    def __init__(self, ctx):
         self.ctx = ctx
+        self.work_start_time = time.perf_counter()
 
-    def run(self):
-        logger.debug("Worker %d initializing..." % self.wid)
-        work_start_time = time.perf_counter()
-
+    def initialize(self):
         # Create the app local to this worker.
         app = PieCrust(self.ctx.root_dir, debug=self.ctx.debug)
-        app.env.fs_cache_only_for_main_page = True
         app.env.registerTimer("PipelineWorker_%d_Total" % self.wid)
         app.env.registerTimer("PipelineWorkerInit")
         app.env.registerTimer("JobReceive")
         app.env.registerTimer('BuildProcessingTree')
         app.env.registerTimer('RunProcessingTree')
+        self.app = app
 
         processors = app.plugin_loader.getProcessors()
         if self.ctx.enabled_processors:
@@ -108,9 +80,10 @@
                 app.env.registerTimer(proc.__class__.__name__)
                 proc.initialize(app)
                 processors.append(proc)
+        self.processors = processors
 
         # Invoke pre-processors.
-        pipeline_ctx = PipelineContext(self.wid, app, self.ctx.out_dir,
+        pipeline_ctx = PipelineContext(self.wid, self.app, self.ctx.out_dir,
                                        self.ctx.tmp_dir, self.ctx.force)
         for proc in processors:
             proc.onPipelineStart(pipeline_ctx)
@@ -119,52 +92,18 @@
         # patching the processors with some new ones.
         processors.sort(key=lambda p: p.priority)
 
-        app.env.stepTimerSince("PipelineWorkerInit", work_start_time)
-
-        aborted_with_exception = None
-        while not self.ctx.abort_event.is_set():
-            try:
-                with app.env.timerScope('JobReceive'):
-                    job = self.ctx.work_queue.get(True, 0.01)
-            except queue.Empty:
-                continue
+        app.env.stepTimerSince("PipelineWorkerInit", self.work_start_time)
 
-            try:
-                result = self._unsafeRun(app, processors, job)
-                self.ctx.results.put_nowait(result)
-            except Exception as ex:
-                self.ctx.abort_event.set()
-                aborted_with_exception = ex
-                logger.error("[%d] Critical error, aborting." % self.wid)
-                if self.ctx.debug:
-                    logger.exception(ex)
-                break
-            finally:
-                self.ctx.work_queue.task_done()
-
-        if aborted_with_exception is not None:
-            msgs = _get_errors(aborted_with_exception)
-            self.ctx.results.put_nowait({'type': 'error', 'messages': msgs})
-
-        # Invoke post-processors.
-        for proc in processors:
-            proc.onPipelineEnd(pipeline_ctx)
-
-        app.env.stepTimerSince("PipelineWorker_%d_Total" % self.wid,
-                               work_start_time)
-        self.ctx.results.put_nowait({
-                'type': 'timers', 'data': app.env._timers})
-
-    def _unsafeRun(self, app, processors, job):
+    def process(self, job):
         result = ProcessingWorkerResult(job.path)
 
         processors = get_filtered_processors(
-                processors, job.mount_info['processors'])
+                self.processors, job.mount_info['processors'])
 
         # Build the processing tree for this job.
         rel_path = os.path.relpath(job.path, job.base_dir)
         try:
-            with app.env.timerScope('BuildProcessingTree'):
+            with self.app.env.timerScope('BuildProcessingTree'):
                 builder = ProcessingTreeBuilder(processors)
                 tree_root = builder.build(rel_path)
                 result.flags |= FLAG_PREPARED
@@ -184,7 +123,7 @@
             tree_root.setState(STATE_DIRTY, True)
 
         try:
-            with app.env.timerScope('RunProcessingTree'):
+            with self.app.env.timerScope('RunProcessingTree'):
                 runner = ProcessingTreeRunner(
                         job.base_dir, self.ctx.tmp_dir, self.ctx.out_dir)
                 if runner.processSubTree(tree_root):
@@ -197,6 +136,19 @@
 
         return result
 
+    def getReport(self):
+        # Invoke post-processors.
+        pipeline_ctx = PipelineContext(self.wid, self.app, self.ctx.out_dir,
+                                       self.ctx.tmp_dir, self.ctx.force)
+        for proc in self.processors:
+            proc.onPipelineEnd(pipeline_ctx)
+
+        self.app.env.stepTimerSince("PipelineWorker_%d_Total" % self.wid,
+                                    self.work_start_time)
+        return {
+                'type': 'timers',
+                'data': self.app.env._timers}
+
 
 def get_filtered_processors(processors, authorized_names):
     if not authorized_names or authorized_names == 'all':
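
ProcessingWorker now does its one-time setup (instantiating processors, onPipelineStart) in initialize() and uses getReport() as the matching teardown hook (onPipelineEnd) before sending its timers back. A minimal sketch of that setup/teardown split, using a made-up sqlite-backed worker instead of PieCrust's processors:

    import sqlite3

    from piecrust.workerpool import IWorker

    class SqliteQueryWorker(IWorker):
        # Made-up worker: per-process setup in initialize(), matching
        # teardown in getReport() right before the worker exits.
        def __init__(self, db_path):
            self.db_path = db_path
            self.queries_run = 0

        def initialize(self):
            # Like onPipelineStart above: runs once in the child process.
            self.db = sqlite3.connect(self.db_path)

        def process(self, job):
            # Each job is assumed to be a SQL query string.
            self.queries_run += 1
            return self.db.execute(job).fetchall()

        def getReport(self):
            # Like onPipelineEnd above: last chance to clean up, then report.
            self.db.close()
            return {'type': 'stats', 'data': {'queries_run': self.queries_run}}

    # pool = WorkerPool(worker_class=SqliteQueryWorker,
    #                   initargs=('/tmp/data.db',))
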
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/piecrust/workerpool.py	Sun Jul 05 00:09:41 2015 -0700
@@ -0,0 +1,267 @@
+import os
+import sys
+import logging
+import threading
+import multiprocessing
+
+
+logger = logging.getLogger(__name__)
+
+
+class IWorker(object):
+    def initialize(self):
+        raise NotImplementedError()
+
+    def process(self, job):
+        raise NotImplementedError()
+
+    def getReport(self):
+        return None
+
+
+TASK_JOB = 0
+TASK_END = 1
+
+
+def worker_func(params):
+    if params.is_profiling:
+        try:
+            import cProfile as profile
+        except ImportError:
+            import profile
+
+        params.is_profiling = False
+        name = params.worker_class.__name__
+        profile.runctx('_real_worker_func(params)',
+                       globals(), locals(),
+                       filename='%s-%d.prof' % (name, params.wid))
+    else:
+        _real_worker_func(params)
+
+
+def _real_worker_func(params):
+    if hasattr(params.inqueue, '_writer'):
+        params.inqueue._writer.close()
+        params.outqueue._reader.close()
+
+    wid = params.wid
+    logger.debug("Worker %d initializing..." % wid)
+
+    w = params.worker_class(*params.initargs)
+    w.wid = wid
+    w.initialize()
+
+    get = params.inqueue.get
+    put = params.outqueue.put
+
+    completed = 0
+    while True:
+        try:
+            task = get()
+        except (EOFError, OSError):
+            logger.debug("Worker %d encountered connection problem." % wid)
+            break
+
+        task_type, task_data = task
+        if task_type == TASK_END:
+            logger.debug("Worker %d got end task, exiting." % wid)
+            try:
+                rep = (task_type, True, wid, (wid, w.getReport()))
+            except Exception as e:
+                if params.wrap_exception:
+                    e = multiprocessing.ExceptionWithTraceback(
+                            e, e.__traceback__)
+                rep = (task_type, False, wid, (wid, e))
+            put(rep)
+            break
+
+        try:
+            res = (task_type, True, wid, w.process(task_data))
+        except Exception as e:
+            if params.wrap_exception:
+                e = multiprocessing.ExceptionWithTraceback(e, e.__traceback__)
+            res = (task_type, False, wid, e)
+        put(res)
+
+        completed += 1
+
+    logger.debug("Worker %d completed %d tasks." % (wid, completed))
+
+
+class _WorkerParams(object):
+    def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
+                 wrap_exception=False, is_profiling=False):
+        self.wid = wid
+        self.inqueue = inqueue
+        self.outqueue = outqueue
+        self.worker_class = worker_class
+        self.initargs = initargs
+        self.wrap_exception = wrap_exception
+        self.is_profiling = is_profiling
+
+
+class WorkerPool(object):
+    def __init__(self, worker_class, worker_count=None, initargs=()):
+        worker_count = worker_count or os.cpu_count() or 1
+
+        self._task_queue = multiprocessing.SimpleQueue()
+        self._result_queue = multiprocessing.SimpleQueue()
+        self._quick_put = self._task_queue._writer.send
+        self._quick_get = self._result_queue._reader.recv
+
+        self._callback = None
+        self._error_callback = None
+        self._listener = None
+
+        main_module = sys.modules['__main__']
+        is_profiling = os.path.basename(main_module.__file__) in [
+                'profile.py', 'cProfile.py']
+
+        self._pool = []
+        for i in range(worker_count):
+            worker_params = _WorkerParams(
+                    i, self._task_queue, self._result_queue,
+                    worker_class, initargs,
+                    is_profiling=is_profiling)
+            w = multiprocessing.Process(target=worker_func,
+                                        args=(worker_params,))
+            w.name = w.name.replace('Process', 'PoolWorker')
+            w.daemon = True
+            w.start()
+            self._pool.append(w)
+
+        self._result_handler = threading.Thread(
+                target=WorkerPool._handleResults,
+                args=(self,))
+        self._result_handler.daemon = True
+        self._result_handler.start()
+
+        self._closed = False
+
+    def setHandler(self, callback=None, error_callback=None):
+        self._callback = callback
+        self._error_callback = error_callback
+
+    def queueJobs(self, jobs, handler=None):
+        if self._closed:
+            raise Exception("This worker pool has been closed.")
+        if self._listener is not None:
+            raise Exception("A previous job queue has not finished yet.")
+
+        if handler is not None:
+            self.setHandler(handler)
+
+        if not hasattr(jobs, '__len__'):
+            jobs = list(jobs)
+
+        res = AsyncResult(self, len(jobs))
+        if res._count == 0:
+            res._event.set()
+            return res
+
+        self._listener = res
+        for job in jobs:
+            self._quick_put((TASK_JOB, job))
+
+        return res
+
+    def close(self):
+        if self._listener is not None:
+            raise Exception("A previous job queue has not finished yet.")
+
+        logger.debug("Closing worker pool...")
+        handler = _ReportHandler(len(self._pool))
+        self._callback = handler._handle
+        for w in self._pool:
+            self._quick_put((TASK_END, None))
+        for w in self._pool:
+            w.join()
+
+        logger.debug("Waiting for reports...")
+        if not handler.wait(2):
+            missing = handler.reports.index(None)
+            logger.warning(
+                    "Didn't receive all worker reports before timeout. "
+                    "Missing report from worker %d." % missing)
+
+        logger.debug("Exiting result handler thread...")
+        self._result_queue.put(None)
+        self._result_handler.join()
+        self._closed = True
+
+        return handler.reports
+
+    @staticmethod
+    def _handleResults(pool):
+        while True:
+            try:
+                res = pool._quick_get()
+            except (EOFError, OSError):
+                logger.debug("Result handler thread encountered connection "
+                             "problem, exiting.")
+                return
+
+            if res is None:
+                logger.debug("Result handler exiting.")
+                break
+
+            task_type, success, wid, data = res
+            try:
+                if success and pool._callback:
+                    pool._callback(data)
+                elif not success and pool._error_callback:
+                    pool._error_callback(data)
+            except Exception as ex:
+                logger.exception(ex)
+
+            if task_type == TASK_JOB:
+                pool._listener._onTaskDone()
+
+
+class AsyncResult(object):
+    def __init__(self, pool, count):
+        self._pool = pool
+        self._count = count
+        self._event = threading.Event()
+
+    def ready(self):
+        return self._event.is_set()
+
+    def wait(self, timeout=None):
+        return self._event.wait(timeout)
+
+    def _onTaskDone(self):
+        self._count -= 1
+        if self._count == 0:
+            self._pool.setHandler(None)
+            self._pool._listener = None
+            self._event.set()
+
+
+class _ReportHandler(object):
+    def __init__(self, worker_count):
+        self.reports = [None] * worker_count
+        self._count = worker_count
+        self._received = 0
+        self._event = threading.Event()
+
+    def wait(self, timeout=None):
+        return self._event.wait(timeout)
+
+    def _handle(self, res):
+        wid, data = res
+        if wid < 0 or wid > self._count:
+            logger.error("Ignoring report from unknown worker %d." % wid)
+            return
+
+        self._received += 1
+        self.reports[wid] = data
+
+        if self._received == self._count:
+            self._event.set()
+
+    def _handleError(self, res):
+        wid, data = res
+        logger.error("Worker %d failed to send its report." % wid)
+        logger.exception(data)
+