diff piecrust/processing/pipeline.py @ 414:c4b3a7fd2f87

bake: Make pipeline processing multi-process. Not many changes here, as it's pretty straightforward, but there is an API change for processors so they know whether they're being initialized/disposed from the main process or from one of the workers. This makes it possible to separate global work that has side effects (e.g. creating a directory) from purely in-memory work (a short illustrative sketch follows the changeset header below).
author Ludovic Chabant <ludovic@chabant.com>
date Sat, 20 Jun 2015 19:20:30 -0700
parents
children 4a43d7015b75
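
The API change mentioned in the description shows up below in how the pipeline builds its PipelineContext: the main process constructs it with a worker id of -1 (for the onPipelineStart/onPipelineEnd hooks), while the workers are spawned with ids 0..N-1 through worker_func. A minimal, hypothetical sketch of a processor taking advantage of this follows; it assumes PipelineContext exposes the worker id and tmp_dir it is constructed with, that ctx.record is only attached on the main process (as done in run() below), and that processors derive from piecrust.processing.base.Processor (the class name and attributes are illustrative, not part of this changeset).

    import os

    from piecrust.processing.base import Processor


    class ScratchDirProcessor(Processor):
        def onPipelineStart(self, ctx):
            # Global, side-effectful setup should happen exactly once,
            # from the main process (worker id -1 in this changeset).
            if ctx.worker_id < 0:
                os.makedirs(os.path.join(ctx.tmp_dir, 'scratch'),
                            exist_ok=True)
            # Per-process, in-memory state is safe to set up anywhere.
            self._seen = set()

        def onPipelineEnd(self, ctx):
            # Only the main process gets the final record attached
            # ('pipeline_ctx.record = record.current' in run() below),
            # so summary work belongs behind the same guard.
            if ctx.worker_id < 0 and ctx.record is not None:
                self.total_processed = ctx.record.processed_count
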
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/piecrust/processing/pipeline.py	Sat Jun 20 19:20:30 2015 -0700
@@ -0,0 +1,351 @@
+import os
+import os.path
+import re
+import time
+import queue
+import hashlib
+import logging
+import multiprocessing
+from piecrust.chefutil import format_timed, format_timed_scope
+from piecrust.processing.base import PipelineContext
+from piecrust.processing.records import (
+        ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord,
+        FLAG_PROCESSED)
+from piecrust.processing.worker import (
+        ProcessingWorkerContext, ProcessingWorkerJob,
+        worker_func, get_filtered_processors)
+
+
+logger = logging.getLogger(__name__)
+
+
+class _ProcessingContext(object):
+    def __init__(self, pool, record, base_dir, mount_info):
+        self.pool = pool
+        self.record = record
+        self.base_dir = base_dir
+        self.mount_info = mount_info
+
+
+class ProcessorPipeline(object):
+    def __init__(self, app, out_dir, force=False):
+        assert app and out_dir
+        self.app = app
+        self.out_dir = out_dir
+        self.force = force
+
+        tmp_dir = app.sub_cache_dir
+        if not tmp_dir:
+            import tempfile
+            tmp_dir = os.path.join(tempfile.gettempdir(), 'piecrust')
+        self.tmp_dir = os.path.join(tmp_dir, 'proc')
+
+        baker_params = app.config.get('baker') or {}
+
+        assets_dirs = baker_params.get('assets_dirs', app.assets_dirs)
+        self.mounts = make_mount_infos(assets_dirs, self.app.root_dir)
+
+        self.num_workers = baker_params.get(
+                'workers', multiprocessing.cpu_count())
+
+        ignores = baker_params.get('ignore', [])
+        ignores += [
+                '_cache', '_counter',
+                'theme_info.yml',
+                '.DS_Store', 'Thumbs.db',
+                '.git*', '.hg*', '.svn']
+        self.ignore_patterns = make_re(ignores)
+        self.force_patterns = make_re(baker_params.get('force', []))
+
+        # These settings are mostly for unit-testing.
+        self.enabled_processors = None
+        self.additional_processors = None
+
+    def addIgnorePatterns(self, patterns):
+        self.ignore_patterns += make_re(patterns)
+
+    def run(self, src_dir_or_file=None, *,
+            delete=True, previous_record=None, save_record=True):
+        start_time = time.perf_counter()
+
+        # Get the list of processors for this run.
+        processors = self.app.plugin_loader.getProcessors()
+        if self.enabled_processors is not None:
+            logger.debug("Filtering processors to: %s" %
+                         self.enabled_processors)
+            processors = get_filtered_processors(processors,
+                                                 self.enabled_processors)
+        if self.additional_processors is not None:
+            logger.debug("Adding %s additional processors." %
+                         len(self.additional_processors))
+            for proc in self.additional_processors:
+                self.app.env.registerTimer(proc.__class__.__name__,
+                                           raise_if_registered=False)
+                proc.initialize(self.app)
+                processors.append(proc)
+
+        # Invoke pre-processors.
+        pipeline_ctx = PipelineContext(-1, self.app, self.out_dir,
+                                       self.tmp_dir, self.force)
+        for proc in processors:
+            proc.onPipelineStart(pipeline_ctx)
+
+        # Pre-processors can define additional ignore patterns.
+        self.ignore_patterns += make_re(
+                pipeline_ctx._additional_ignore_patterns)
+
+        # The worker pool itself is created further down, via
+        # _createWorkerPool(), just before the jobs are queued.
+
+        # Create the pipeline record.
+        record = TransitionalProcessorPipelineRecord()
+        record_cache = self.app.cache.getCache('proc')
+        record_name = (
+                hashlib.md5(self.out_dir.encode('utf8')).hexdigest() +
+                '.record')
+        if previous_record:
+            record.setPrevious(previous_record)
+        elif not self.force and record_cache.has(record_name):
+            with format_timed_scope(logger, 'loaded previous bake record',
+                                    level=logging.DEBUG, colored=False):
+                record.loadPrevious(record_cache.getCachePath(record_name))
+        logger.debug("Got %d entries in process record." %
+                     len(record.previous.entries))
+        record.current.success = True
+        record.current.processed_count = 0
+
+        # Work!
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path)
+            assert entry is not None
+            entry.flags |= res.flags
+            entry.proc_tree = res.proc_tree
+            entry.rel_outputs = res.rel_outputs
+            if res.errors:
+                entry.errors += res.errors
+                record.current.success = False
+            if entry.flags & FLAG_PROCESSED:
+                record.current.processed_count += 1
+
+        pool = self._createWorkerPool()
+        expected_result_count = self._process(src_dir_or_file, pool, record)
+        self._waitOnWorkerPool(pool, expected_result_count, _handler)
+        self._terminateWorkerPool(pool)
+
+        # Get timing information from the workers.
+        record.current.timers = {}
+        for _ in range(len(pool.workers)):
+            try:
+                timers = pool.results.get(True, 0.1)
+            except queue.Empty:
+                logger.error("Didn't get timing information from all workers.")
+                break
+
+            for name, val in timers['data'].items():
+                main_val = record.current.timers.setdefault(name, 0)
+                record.current.timers[name] = main_val + val
+
+        # Invoke post-processors.
+        pipeline_ctx.record = record.current
+        for proc in processors:
+            proc.onPipelineEnd(pipeline_ctx)
+
+        # Handle deletions.
+        if delete:
+            for path, reason in record.getDeletions():
+                logger.debug("Removing '%s': %s" % (path, reason))
+                try:
+                    os.remove(path)
+                except FileNotFoundError:
+                    pass
+                logger.info('[delete] %s' % path)
+
+        # Finalize the process record.
+        record.current.process_time = time.time()
+        record.current.out_dir = self.out_dir
+        record.collapseRecords()
+
+        # Save the process record.
+        if save_record:
+            with format_timed_scope(logger, 'saved bake record',
+                                    level=logging.DEBUG, colored=False):
+                record.saveCurrent(record_cache.getCachePath(record_name))
+
+        logger.info(format_timed(
+                start_time,
+                "processed %d assets." % record.current.processed_count))
+
+        return record.detach()
+
+    def _process(self, src_dir_or_file, pool, record):
+        expected_result_count = 0
+
+        if src_dir_or_file is not None:
+            # Process only the given path.
+            # Find out what mount point this is in.
+            for name, info in self.mounts.items():
+                path = info['path']
+                if src_dir_or_file.startswith(path):
+                    base_dir = path
+                    mount_info = info
+                    break
+            else:
+                known_roots = [i['path'] for i in self.mounts.values()]
+                raise Exception("Input path '%s' is not part of any known "
+                                "mount point: %s" %
+                                (src_dir_or_file, known_roots))
+
+            ctx = _ProcessingContext(pool, record, base_dir, mount_info)
+            logger.debug("Initiating processing pipeline on: %s" %
+                         src_dir_or_file)
+            if os.path.isdir(src_dir_or_file):
+                expected_result_count = self._processDirectory(
+                        ctx, src_dir_or_file)
+            elif os.path.isfile(src_dir_or_file):
+                self._processFile(ctx, src_dir_or_file)
+                expected_result_count = 1
+
+        else:
+            # Process everything.
+            for name, info in self.mounts.items():
+                path = info['path']
+                ctx = _ProcessingContext(pool, record, path, info)
+                logger.debug("Initiating processing pipeline on: %s" % path)
+                expected_result_count += self._processDirectory(ctx, path)
+
+        return expected_result_count
+
+    def _processDirectory(self, ctx, start_dir):
+        queued_count = 0
+        for dirpath, dirnames, filenames in os.walk(start_dir):
+            rel_dirpath = os.path.relpath(dirpath, start_dir)
+            dirnames[:] = [d for d in dirnames
+                           if not re_matchany(
+                               d, self.ignore_patterns, rel_dirpath)]
+
+            for filename in filenames:
+                if re_matchany(filename, self.ignore_patterns, rel_dirpath):
+                    continue
+                self._processFile(ctx, os.path.join(dirpath, filename))
+                queued_count += 1
+        return queued_count
+
+    def _processFile(self, ctx, path):
+        # TODO: handle overrides between mount-points.
+
+        entry = ProcessorPipelineRecordEntry(path)
+        ctx.record.addEntry(entry)
+
+        previous_entry = ctx.record.getPreviousEntry(path)
+        force_this = (self.force or previous_entry is None or
+                      not previous_entry.was_processed_successfully)
+
+        job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path,
+                                  force=force_this)
+
+        logger.debug("Queuing: %s" % path)
+        ctx.pool.queue.put_nowait(job)
+
+    def _createWorkerPool(self):
+        pool = _WorkerPool()
+        for i in range(self.num_workers):
+            ctx = ProcessingWorkerContext(
+                    self.app.root_dir, self.out_dir, self.tmp_dir,
+                    pool.queue, pool.results, pool.abort_event,
+                    self.force, self.app.debug)
+            ctx.enabled_processors = self.enabled_processors
+            ctx.additional_processors = self.additional_processors
+            w = multiprocessing.Process(
+                    name='Worker_%d' % i,
+                    target=worker_func, args=(i, ctx))
+            w.start()
+            pool.workers.append(w)
+        return pool
+
+    def _waitOnWorkerPool(self, pool, expected_result_count, result_handler):
+        abort_with_exception = None
+        try:
+            got_count = 0
+            while got_count < expected_result_count:
+                try:
+                    res = pool.results.get(True, 10)
+                except queue.Empty:
+                    logger.error(
+                            "Got %d results, expected %d, and timed out "
+                            "after 10 seconds. A worker may be stuck." %
+                            (got_count, expected_result_count))
+                    abort_with_exception = Exception("Worker time-out.")
+                    break
+
+                if isinstance(res, dict) and res.get('type') == 'error':
+                    abort_with_exception = Exception(
+                            'Worker critical error:\n' +
+                            '\n'.join(res['messages']))
+                    break
+
+                got_count += 1
+                result_handler(res)
+        except KeyboardInterrupt as kiex:
+            logger.warning("Bake aborted by user... "
+                           "waiting for workers to stop.")
+            abort_with_exception = kiex
+
+        if abort_with_exception:
+            pool.abort_event.set()
+            for w in pool.workers:
+                w.join(2)
+            raise abort_with_exception
+
+    def _terminateWorkerPool(self, pool):
+        pool.abort_event.set()
+        for w in pool.workers:
+            w.join()
+
+
+class _WorkerPool(object):
+    def __init__(self):
+        self.queue = multiprocessing.JoinableQueue()
+        self.results = multiprocessing.Queue()
+        self.abort_event = multiprocessing.Event()
+        self.workers = []
+
+
+def make_mount_infos(mounts, root_dir):
+    if isinstance(mounts, list):
+        mounts = {m: {} for m in mounts}
+
+    for name, info in mounts.items():
+        if not isinstance(info, dict):
+            raise Exception("Asset directory info for '%s' is not a "
+                            "dictionary." % name)
+        info.setdefault('processors', 'all -uglifyjs -cleancss')
+        info['path'] = os.path.join(root_dir, name)
+
+    return mounts
+
+
+def make_re(patterns):
+    re_patterns = []
+    for pat in patterns:
+        if len(pat) > 2 and pat[0] == '/' and pat[-1] == '/':
+            re_patterns.append(pat[1:-1])
+        else:
+            escaped_pat = (
+                    re.escape(pat)
+                    .replace(r'\*', r'[^/\\]*')
+                    .replace(r'\?', r'[^/\\]'))
+            re_patterns.append(escaped_pat)
+    return [re.compile(p) for p in re_patterns]
+
+
+def re_matchany(filename, patterns, dirname=None):
+    if dirname and dirname != '.':
+        filename = os.path.join(dirname, filename)
+
+    # skip patterns use a forward slash regardless of the platform.
+    filename = filename.replace('\\', '/')
+    for pattern in patterns:
+        if pattern.search(filename):
+            return True
+    return False
+