diff piecrust/processing/pipeline.py @ 447:aefe70229fdd

bake: Commonize worker pool code between html and asset baking.

The `workerpool` package now defines a generic-ish worker pool. It's similar to the Python framework pool but with a simpler use-case (only one way to queue jobs) and support for workers to send a final "report" to the master process, which we use to get timing information here.

The rest of the changes basically remove a whole bunch of duplicated code that's not needed anymore.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 05 Jul 2015 00:09:41 -0700
parents 171dde4f61dc
children d90ccdf18156
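The generic pool mentioned in the description lives in piecrust/workerpool.py and is not part of this file's diff. Below is a self-contained, stdlib-only sketch of the pattern it implements -- a single queueJobs() entry point whose result can be waited on, plus a final per-worker report returned by close() -- inferred only from how the pipeline drives the pool in the hunks that follow. None of the names in the sketch are piecrust's; it is an illustration of the idea, not the actual module. A worker-side sketch follows the diff.

# A self-contained, stdlib-only sketch of the pool contract this commit
# describes: one way to queue jobs, a handler called per result, and a
# final per-worker "report" (timing data here) returned by close().
# This is NOT the piecrust.workerpool implementation -- names such as
# SimplePool, _worker_main and _AsyncResult are illustrative only.
import time
import multiprocessing

_SENTINEL = None


def _worker_main(wid, job_queue, result_queue):
    # Each worker keeps its own timing data so it can hand a report back
    # to the master process when it shuts down.
    timers = {'JobProcessing': 0.0}
    while True:
        job = job_queue.get()
        if job is _SENTINEL:
            break
        start = time.perf_counter()
        result = job * job                       # stand-in for real work
        timers['JobProcessing'] += time.perf_counter() - start
        result_queue.put(('result', result))
    # The final report, collected by SimplePool.close().
    result_queue.put(('report', timers))


class _AsyncResult(object):
    def __init__(self, results, count, handler):
        self._results = results
        self._count = count
        self._handler = handler

    def wait(self):
        # Pull exactly as many results as jobs were queued, invoking the
        # handler on each one.
        for _ in range(self._count):
            kind, data = self._results.get()
            assert kind == 'result'
            if self._handler is not None:
                self._handler(data)


class SimplePool(object):
    def __init__(self, num_workers=2):
        self._jobs = multiprocessing.Queue()
        self._results = multiprocessing.Queue()
        self._workers = [
                multiprocessing.Process(
                    target=_worker_main,
                    args=(i, self._jobs, self._results))
                for i in range(num_workers)]
        for w in self._workers:
            w.start()

    def queueJobs(self, jobs, handler=None):
        for j in jobs:
            self._jobs.put(j)
        return _AsyncResult(self._results, len(jobs), handler)

    def close(self):
        # Tell every worker to stop, then collect one report per worker.
        for _ in self._workers:
            self._jobs.put(_SENTINEL)
        reports = [self._results.get()[1] for _ in self._workers]
        for w in self._workers:
            w.join()
        return reports


if __name__ == '__main__':
    pool = SimplePool(num_workers=2)
    ar = pool.queueJobs([1, 2, 3, 4], handler=lambda r: print('got', r))
    ar.wait()
    print('worker reports:', pool.close())
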
--- a/piecrust/processing/pipeline.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/processing/pipeline.py	Sun Jul 05 00:09:41 2015 -0700
@@ -2,7 +2,6 @@
 import os.path
 import re
 import time
-import queue
 import hashlib
 import logging
 import multiprocessing
@@ -12,16 +11,16 @@
         ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord,
         FLAG_PROCESSED)
 from piecrust.processing.worker import (
-        ProcessingWorkerContext, ProcessingWorkerJob,
-        worker_func, get_filtered_processors)
+        ProcessingWorkerJob,
+        get_filtered_processors)
 
 
 logger = logging.getLogger(__name__)
 
 
 class _ProcessingContext(object):
-    def __init__(self, pool, record, base_dir, mount_info):
-        self.pool = pool
+    def __init__(self, jobs, record, base_dir, mount_info):
+        self.jobs = jobs
         self.record = record
         self.base_dir = base_dir
         self.mount_info = mount_info
@@ -94,9 +93,6 @@
         self.ignore_patterns += make_re(
                 pipeline_ctx._additional_ignore_patterns)
 
-        # Create the worker pool.
-        pool = _WorkerPool()
-
         # Create the pipeline record.
         record = TransitionalProcessorPipelineRecord()
         record_cache = self.app.cache.getCache('proc')
@@ -132,19 +128,19 @@
                 for e in entry.errors:
                     logger.error("  " + e)
 
+        jobs = []
+        self._process(src_dir_or_file, record, jobs)
         pool = self._createWorkerPool()
-        expected_result_count = self._process(src_dir_or_file, pool, record)
-        self._waitOnWorkerPool(pool, expected_result_count, _handler)
-        self._terminateWorkerPool(pool)
+        ar = pool.queueJobs(jobs, handler=_handler)
+        ar.wait()
 
-        # Get timing information from the workers.
+        # Shutdown the workers and get timing information from them.
+        reports = pool.close()
         record.current.timers = {}
-        for i in range(len(pool.workers)):
-            try:
-                timers = pool.results.get(True, 0.1)
-            except queue.Empty:
-                logger.error("Didn't get timing information from all workers.")
-                break
+        for i in range(len(reports)):
+            timers = reports[i]
+            if timers is None:
+                continue
 
             worker_name = 'PipelineWorker_%d' % i
             record.current.timers[worker_name] = {}
@@ -185,9 +181,7 @@
 
         return record.detach()
 
-    def _process(self, src_dir_or_file, pool, record):
-        expected_result_count = 0
-
+    def _process(self, src_dir_or_file, record, jobs):
         if src_dir_or_file is not None:
             # Process only the given path.
             # Find out what mount point this is in.
@@ -203,28 +197,23 @@
                                 "mount point: %s" %
                                 (src_dir_or_file, known_roots))
 
-            ctx = _ProcessingContext(pool, record, base_dir, mount_info)
+            ctx = _ProcessingContext(jobs, record, base_dir, mount_info)
             logger.debug("Initiating processing pipeline on: %s" %
                          src_dir_or_file)
             if os.path.isdir(src_dir_or_file):
-                expected_result_count = self._processDirectory(
-                        ctx, src_dir_or_file)
+                self._processDirectory(ctx, src_dir_or_file)
             elif os.path.isfile(src_dir_or_file):
                 self._processFile(ctx, src_dir_or_file)
-                expected_result_count = 1
 
         else:
             # Process everything.
             for name, info in self.mounts.items():
                 path = info['path']
-                ctx = _ProcessingContext(pool, record, path, info)
+                ctx = _ProcessingContext(jobs, record, path, info)
                 logger.debug("Initiating processing pipeline on: %s" % path)
-                expected_result_count = self._processDirectory(ctx, path)
-
-        return expected_result_count
+                self._processDirectory(ctx, path)
 
     def _processDirectory(self, ctx, start_dir):
-        queued_count = 0
         for dirpath, dirnames, filenames in os.walk(start_dir):
             rel_dirpath = os.path.relpath(dirpath, start_dir)
             dirnames[:] = [d for d in dirnames
@@ -235,8 +224,6 @@
                 if re_matchany(filename, self.ignore_patterns, rel_dirpath):
                     continue
                 self._processFile(ctx, os.path.join(dirpath, filename))
-                queued_count += 1
-        return queued_count
 
     def _processFile(self, ctx, path):
         # TODO: handle overrides between mount-points.
@@ -250,79 +237,23 @@
 
         job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path,
                                   force=force_this)
-
-        logger.debug("Queuing: %s" % path)
-        ctx.pool.queue.put_nowait(job)
+        ctx.jobs.append(job)
 
     def _createWorkerPool(self):
-        import sys
-
-        main_module = sys.modules['__main__']
-        is_profiling = os.path.basename(main_module.__file__) in [
-                'profile.py', 'cProfile.py']
-
-        pool = _WorkerPool()
-        for i in range(self.num_workers):
-            ctx = ProcessingWorkerContext(
-                    self.app.root_dir, self.out_dir, self.tmp_dir,
-                    pool.queue, pool.results, pool.abort_event,
-                    self.force, self.app.debug)
-            ctx.is_profiling = is_profiling
-            ctx.enabled_processors = self.enabled_processors
-            ctx.additional_processors = self.additional_processors
-            w = multiprocessing.Process(
-                    name='PipelineWorker_%d' % i,
-                    target=worker_func, args=(i, ctx))
-            w.start()
-            pool.workers.append(w)
-        return pool
+        from piecrust.workerpool import WorkerPool
+        from piecrust.processing.worker import (
+                ProcessingWorkerContext, ProcessingWorker)
 
-    def _waitOnWorkerPool(self, pool, expected_result_count, result_handler):
-        abort_with_exception = None
-        try:
-            got_count = 0
-            while got_count < expected_result_count:
-                try:
-                    res = pool.results.get(True, 10)
-                except queue.Empty:
-                    logger.error(
-                            "Got %d results, expected %d, and timed-out "
-                            "for 10 seconds. A worker might be stuck?" %
-                            (got_count, expected_result_count))
-                    abort_with_exception = Exception("Worker time-out.")
-                    break
-
-                if isinstance(res, dict) and res.get('type') == 'error':
-                    abort_with_exception = Exception(
-                            'Worker critical error:\n' +
-                            '\n'.join(res['messages']))
-                    break
+        ctx = ProcessingWorkerContext(
+                self.app.root_dir, self.out_dir, self.tmp_dir,
+                self.force, self.app.debug)
+        ctx.enabled_processors = self.enabled_processors
+        ctx.additional_processors = self.additional_processors
 
-                got_count += 1
-                result_handler(res)
-        except KeyboardInterrupt as kiex:
-            logger.warning("Bake aborted by user... "
-                           "waiting for workers to stop.")
-            abort_with_exception = kiex
-
-        if abort_with_exception:
-            pool.abort_event.set()
-            for w in pool.workers:
-                w.join(2)
-            raise abort_with_exception
-
-    def _terminateWorkerPool(self, pool):
-        pool.abort_event.set()
-        for w in pool.workers:
-            w.join()
-
-
-class _WorkerPool(object):
-    def __init__(self):
-        self.queue = multiprocessing.JoinableQueue()
-        self.results = multiprocessing.Queue()
-        self.abort_event = multiprocessing.Event()
-        self.workers = []
+        pool = WorkerPool(
+                worker_class=ProcessingWorker,
+                initargs=(ctx,))
+        return pool
 
 
 def make_mount_infos(mounts, root_dir):
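
For completeness, the worker side of the new contract is defined in piecrust/workerpool.py and piecrust/processing/worker.py, neither of which is shown in this file's diff. The sketch below is only a guess at its general shape, inferred from `WorkerPool(worker_class=ProcessingWorker, initargs=(ctx,))` and from the per-worker timer reports consumed above; the class and method names (SketchWorker, initialize, process, getReport) are assumptions for illustration, not confirmed by this changeset.

# Hypothetical worker-side shape, inferred from this diff only. The pool
# presumably instantiates worker_class(*initargs) in each child process;
# the method names below are assumptions used for illustration.
import time


class SketchWorker(object):
    def __init__(self, ctx):
        # `ctx` would be something like ProcessingWorkerContext: paths,
        # flags and processor lists needed to do the work.
        self.ctx = ctx
        self.timers = {}

    def initialize(self):
        # One-time setup in the child process (e.g. loading processors).
        pass

    def process(self, job):
        # Handle one job; the return value travels back to the master,
        # where it is passed to the handler given to queueJobs().
        start = time.perf_counter()
        result = {'path': getattr(job, 'path', None), 'errors': []}
        self.timers['Processing'] = (
                self.timers.get('Processing', 0.0) +
                time.perf_counter() - start)
        return result

    def getReport(self):
        # Returned to the master at shutdown; one such report per worker
        # is what pool.close() hands back and what the pipeline merges
        # into record.current.timers.
        return self.timers


if __name__ == '__main__':
    class _DummyJob(object):
        path = 'assets/style.less'

    w = SketchWorker(ctx=None)
    w.initialize()
    print(w.process(_DummyJob()))
    print(w.getReport())
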