diff piecrust/baking/baker.py @ 854:08e02c2a2a1a

core: Keep refactoring, this time to prepare for generator sources.

- Make a few APIs simpler.
- Content pipelines create their own jobs, so that generator sources can
  keep aborting in `getContents`, but rely on their pipeline to generate
  pages for baking.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 04 Jun 2017 23:34:28 -0700
parents f070a4fc033c
children 448710d84121
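For context: before this change the baker itself built `BakeJob` objects from `source.getAllContents()`; after it, each pipeline exposes a `createJobs()` method and the baker simply queues whatever comes back (see `_bakeRealm` below). A generator source can therefore keep aborting in `getContents` while its pipeline builds jobs by other means. Below is a minimal sketch of a regular pipeline under this scheme; only `createJobs()`, `PIPELINE_NAME`, `PASS_NUM`, `RECORD_ENTRY_CLASS`, the `source@pipeline` record-name format and `source.getAllContents()` come from the diff, while `_Job`, `ListingPipeline` and the job fields are illustrative assumptions, not the real PieCrust API:

    class _Job:
        # Stand-in for the real bake job type; the field names mirror what
        # baker.py reads back (source_name, record_name, content_item, data).
        def __init__(self, source_name, record_name, content_item):
            self.source_name = source_name
            self.record_name = record_name
            self.content_item = content_item
            self.data = {}

    class ListingPipeline:
        PIPELINE_NAME = 'listing'     # hypothetical pipeline name
        PASS_NUM = 0                  # baked during the first pipeline pass
        RECORD_ENTRY_CLASS = None     # fall back to the generic RecordEntry

        def __init__(self, source):
            self.source = source

        def createJobs(self):
            # One job per content item exposed by this pipeline's source;
            # the baker just queues whatever list we return.
            record_name = '%s@%s' % (self.source.name, self.PIPELINE_NAME)
            return [_Job(self.source.name, record_name, item)
                    for item in self.source.getAllContents()]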
--- a/piecrust/baking/baker.py	Sun May 21 00:06:59 2017 -0700
+++ b/piecrust/baking/baker.py	Sun Jun 04 23:34:28 2017 -0700
@@ -2,11 +2,12 @@
 import os.path
 import hashlib
 import logging
-from piecrust.baking.worker import BakeJob
 from piecrust.chefutil import (
     format_timed_scope, format_timed)
 from piecrust.environment import ExecutionStats
-from piecrust.pipelines.base import PipelineContext
+from piecrust.pipelines.base import (
+    PipelineMergeRecordContext, PipelineManager,
+    get_pipeline_name_for_source)
 from piecrust.pipelines.records import (
     MultiRecordHistory, MultiRecord, RecordEntry,
     load_records)
@@ -31,16 +32,10 @@
         self.out_dir = out_dir
         self.force = force
 
-        self._pipeline_classes = {}
-        for pclass in app.plugin_loader.getPipelines():
-            self._pipeline_classes[pclass.PIPELINE_NAME] = pclass
-
         self.allowed_pipelines = allowed_pipelines
         if allowed_pipelines is None:
             self.allowed_pipelines = list(self._pipeline_classes.keys())
 
-        self._records = None
-
     def bake(self):
         start_time = time.perf_counter()
         logger.debug("  Bake Output: %s" % self.out_dir)
@@ -63,18 +58,19 @@
                 previous_records = load_records(records_path)
         else:
             previous_records = MultiRecord()
-        self._records = MultiRecord()
+        current_records = MultiRecord()
 
         # Figure out if we need to clean the cache because important things
         # have changed.
         is_cache_valid = self._handleCacheValidity(previous_records,
-                                                   self._records)
+                                                   current_records)
         if not is_cache_valid:
             previous_records = MultiRecord()
 
         # Create the bake records history which tracks what's up-to-date
         # or not since last time we baked to the given output folder.
-        record_histories = MultiRecordHistory(previous_records, self._records)
+        record_histories = MultiRecordHistory(
+            previous_records, current_records)
 
         # Pre-create all caches.
         for cache_name in ['app', 'baker', 'pages', 'renders']:
@@ -86,49 +82,54 @@
         # realm).
         #
         # Also, create and initialize each pipeline for each source.
-        sources_by_realm = {}
+        has_any_pp = False
+        ppmngr = PipelineManager(
+            self.app, self.out_dir, record_histories)
         for source in self.app.sources:
-            pname = source.config['pipeline']
+            pname = get_pipeline_name_for_source(source)
             if pname in self.allowed_pipelines:
-                srclist = sources_by_realm.setdefault(
-                    source.config['realm'], [])
-
-                pp = self._pipeline_classes[pname](source)
-
-                record_name = _get_record_name(source.name, pname)
-                record_history = record_histories.getHistory(record_name)
-                ppctx = PipelineContext(self.out_dir, record_history,
-                                        force=self.force)
-                pp.initialize(ppctx)
-
-                srclist.append((source, pp, ppctx))
+                ppinfo = ppmngr.createPipeline(source)
+                logger.debug(
+                    "Created pipeline '%s' for source: %s" %
+                    (ppinfo.pipeline.PIPELINE_NAME, source.name))
+                has_any_pp = True
             else:
                 logger.debug(
                     "Skip source '%s' because pipeline '%s' is ignored." %
                     (source.name, pname))
+        if not has_any_pp:
+            raise Exception("The website has no content sources, or the bake "
+                            "command was invoked with all pipelines filtered "
+                            "out. There's nothing to do.")
 
         # Create the worker processes.
-        pool = self._createWorkerPool(records_path)
+        pool_userdata = _PoolUserData(self, ppmngr, current_records)
+        pool = self._createWorkerPool(records_path, pool_userdata)
+        realm_list = [REALM_USER, REALM_THEME]
 
         # Bake the realms -- user first, theme second, so that a user item
         # can override a theme item.
-        realm_list = [REALM_USER, REALM_THEME]
-        for realm in realm_list:
-            srclist = sources_by_realm.get(realm)
-            if srclist is not None:
-                self._bakeRealm(pool, srclist)
+        # Do this as many times as we have pipeline passes left to do.
+        pp_by_pass_and_realm = {}
+        for ppinfo in ppmngr.getPipelines():
+            pp_by_realm = pp_by_pass_and_realm.setdefault(
+                ppinfo.pipeline.PASS_NUM, {})
+            pplist = pp_by_realm.setdefault(
+                ppinfo.pipeline.source.config['realm'], [])
+            pplist.append(ppinfo)
 
-        # Handle deletions.
-        for realm in realm_list:
-            srclist = sources_by_realm.get(realm)
-            if srclist is not None:
-                self._deleteStaleOutputs(pool, srclist)
+        for pp_pass in sorted(pp_by_pass_and_realm.keys()):
+            logger.debug("Pipelines pass %d" % pp_pass)
+            pp_by_realm = pp_by_pass_and_realm[pp_pass]
+            for realm in realm_list:
+                pplist = pp_by_realm.get(realm)
+                if pplist is not None:
+                    self._bakeRealm(pool, pplist)
 
-        # Collapse records.
-        for realm in realm_list:
-            srclist = sources_by_realm.get(realm)
-            if srclist is not None:
-                self._collapseRecords(srclist)
+        # Handle deletions, collapse records, etc.
+        ppmngr.buildHistoryDiffs()
+        ppmngr.deleteStaleOutputs()
+        ppmngr.collapseRecords()
 
         # All done with the workers. Close the pool and get reports.
         pool_stats = pool.close()
@@ -136,14 +137,10 @@
         for ps in pool_stats:
             if ps is not None:
                 total_stats.mergeStats(ps)
-        record_histories.current.stats = total_stats
+        current_records.stats = total_stats
 
         # Shutdown the pipelines.
-        for realm in realm_list:
-            srclist = sources_by_realm.get(realm)
-            if srclist is not None:
-                for _, pp, ppctx in srclist:
-                    pp.shutdown(ppctx)
+        ppmngr.shutdownPipelines()
 
         # Backup previous records.
         records_dir, records_fn = os.path.split(records_path)
@@ -164,16 +161,15 @@
         # Save the bake records.
         with format_timed_scope(logger, "saved bake records.",
                                 level=logging.DEBUG, colored=False):
-            record_histories.current.bake_time = time.time()
-            record_histories.current.out_dir = self.out_dir
-            record_histories.current.save(records_path)
+            current_records.bake_time = time.time()
+            current_records.out_dir = self.out_dir
+            current_records.save(records_path)
 
         # All done.
         self.app.config.set('baker/is_baking', False)
         logger.debug(format_timed(start_time, 'done baking'))
 
-        self._records = None
-        return record_histories.current
+        return current_records
 
     def _handleCacheValidity(self, previous_records, current_records):
         start_time = time.perf_counter()
@@ -216,40 +212,58 @@
                 start_time, "cache is assumed valid", colored=False))
             return True
 
-    def _bakeRealm(self, pool, srclist):
-        for source, pp, ppctx in srclist:
-            logger.debug("Queuing jobs for source '%s' using pipeline '%s'." %
-                         (source.name, pp.PIPELINE_NAME))
-            jobs = [BakeJob(source.name, item.spec, item.metadata)
-                    for item in source.getAllContents()]
+    def _bakeRealm(self, pool, pplist):
+        # Start with the first pass, where we iterate on the content sources'
+        # items and run jobs on those.
+        pool.userdata.cur_pass = 0
+        next_pass_jobs = {}
+        pool.userdata.next_pass_jobs = next_pass_jobs
+        for ppinfo in pplist:
+            src = ppinfo.source
+            pp = ppinfo.pipeline
+
+            logger.debug(
+                "Queuing jobs for source '%s' using pipeline '%s' (pass 0)." %
+                (src.name, pp.PIPELINE_NAME))
+
+            next_pass_jobs[src.name] = []
+            jobs = pp.createJobs()
             pool.queueJobs(jobs)
         pool.wait()
 
-    def _deleteStaleOutputs(self, pool, srclist):
-        for source, pp, ppctx in srclist:
-            ppctx.record_history.build()
+        # Now let's see if any job created a follow-up job. Let's keep
+        # processing those jobs as long as they create new ones.
+        pool.userdata.cur_pass = 1
+        while True:
+            had_any_job = False
+
+            # Make a copy of our next pass jobs and reset the list, so
+            # the first jobs to be processed don't mess it up as we're
+            # still iterating on it.
+            next_pass_jobs = pool.userdata.next_pass_jobs
+            pool.userdata.next_pass_jobs = {}
 
-            to_delete = pp.getDeletions(ppctx)
-            if to_delete is not None:
-                for path, reason in to_delete:
-                    logger.debug("Removing '%s': %s" % (path, reason))
-                    ppctx.current_record.deleted_out_paths.append(path)
-                    try:
-                        os.remove(path)
-                    except FileNotFoundError:
-                        pass
-                    logger.info('[delete] %s' % path)
+            for sn, jobs in next_pass_jobs.items():
+                if jobs:
+                    logger.debug(
+                        "Queuing jobs for source '%s' (pass %d)." %
+                        (sn, pool.userdata.cur_pass))
+                    pool.userdata.next_pass_jobs[sn] = []
+                    pool.queueJobs(jobs)
+                    had_any_job = True
 
-    def _collapseRecords(self, srclist):
-        for source, pp, ppctx in srclist:
-            pp.collapseRecords(ppctx)
+            if not had_any_job:
+                break
+
+            pool.wait()
+            pool.userdata.cur_pass += 1
 
     def _logErrors(self, item_spec, errors):
         logger.error("Errors found in %s:" % item_spec)
         for e in errors:
             logger.error("  " + e)
 
-    def _createWorkerPool(self, previous_records_path):
+    def _createWorkerPool(self, previous_records_path, pool_userdata):
         from piecrust.workerpool import WorkerPool
         from piecrust.baking.worker import BakeWorkerContext, BakeWorker
 
@@ -268,36 +282,59 @@
             worker_class=BakeWorker,
             initargs=(ctx,),
             callback=self._handleWorkerResult,
-            error_callback=self._handleWorkerError)
+            error_callback=self._handleWorkerError,
+            userdata=pool_userdata)
         return pool
 
-    def _handleWorkerResult(self, job, res):
-        record_name = _get_record_name(job.source_name, res.pipeline_name)
-        record = self._records.getRecord(record_name)
-        record.entries.append(res.record_entry)
+    def _handleWorkerResult(self, job, res, userdata):
+        cur_pass = userdata.cur_pass
+        record = userdata.records.getRecord(job.record_name)
+
+        if cur_pass == 0:
+            record.addEntry(res.record_entry)
+        else:
+            ppinfo = userdata.ppmngr.getPipeline(job.source_name)
+            ppmrctx = PipelineMergeRecordContext(
+                record, job, cur_pass)
+            ppinfo.pipeline.mergeRecordEntry(res.record_entry, ppmrctx)
+
+        npj = res.next_pass_job
+        if npj is not None:
+            npj.data['pass'] = cur_pass + 1
+            userdata.next_pass_jobs[job.source_name].append(npj)
+
         if not res.record_entry.success:
             record.success = False
-            self._records.success = False
-            self._logErrors(job.item_spec, res.record_entry.errors)
+            userdata.records.success = False
+            self._logErrors(job.content_item.spec, res.record_entry.errors)
+
+    def _handleWorkerError(self, job, exc_data, userdata):
+        cur_pass = userdata.cur_pass
+        record = userdata.records.getRecord(job.record_name)
 
-    def _handleWorkerError(self, job, exc_data):
-        e = RecordEntry()
-        e.item_spec = job.item_spec
-        e.errors.append(str(exc_data))
-
-        ppname = self.app.getSource(job.source_name).config['pipeline']
-        record_name = _get_record_name(job.source_name, ppname)
-        record_name = self._getRecordName(job)
-        record = self._records.getRecord(record_name)
-        record.entries.append(e)
+        if cur_pass == 0:
+            ppinfo = userdata.ppmngr.getPipeline(job.source_name)
+            entry_class = ppinfo.pipeline.RECORD_ENTRY_CLASS or RecordEntry
+            e = entry_class()
+            e.item_spec = job.content_item.spec
+            e.errors.append(str(exc_data))
+            record.addEntry(e)
+        else:
+            e = record.getEntry(job.content_item.spec)
+            e.errors.append(str(exc_data))
 
         record.success = False
-        self._records.success = False
+        userdata.records.success = False
 
-        self._logErrors(job.item_spec, e.errors)
+        self._logErrors(job.content_item.spec, e.errors)
         if self.app.debug:
             logger.error(exc_data.traceback)
 
 
-def _get_record_name(source_name, pipeline_name):
-    return '%s@%s' % (source_name, pipeline_name)
+class _PoolUserData:
+    def __init__(self, baker, ppmngr, current_records):
+        self.baker = baker
+        self.ppmngr = ppmngr
+        self.records = current_records
+        self.cur_pass = 0
+        self.next_pass_jobs = {}
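The other moving piece above is the multi-pass job flow: `_bakeRealm` queues each pipeline's pass-0 jobs, then `_handleWorkerResult` parks any `next_pass_job` returned by a worker under `userdata.next_pass_jobs[source]`, and `_bakeRealm` re-queues those on the following pass until no job produces further work. Here is a self-contained toy model of that loop; the `process()` stand-in and dict-based jobs are purely illustrative, whereas the real code runs jobs through the worker pool and the per-pipeline record merging shown above:

    # Toy model of the follow-up-pass loop in _bakeRealm; not PieCrust code.
    def process(job):
        # Pretend every pass-0 job for the 'assets' source spawns exactly one
        # follow-up job, and nothing spawns work after pass 1.
        if job['pass'] == 0 and job['source'] == 'assets':
            return {'source': 'assets', 'pass': 1}
        return None

    def bake_realm(jobs_by_source):
        cur_pass = 0
        while True:
            next_pass_jobs = {sn: [] for sn in jobs_by_source}
            had_any_job = False
            for sn, jobs in jobs_by_source.items():
                if not jobs:
                    continue
                had_any_job = True
                for job in jobs:
                    follow_up = process(job)   # worker + result callback
                    if follow_up is not None:
                        next_pass_jobs[sn].append(follow_up)
            if not had_any_job:
                break
            jobs_by_source = next_pass_jobs
            cur_pass += 1
        return cur_pass

    if __name__ == '__main__':
        passes = bake_realm({
            'pages': [{'source': 'pages', 'pass': 0}],
            'assets': [{'source': 'assets', 'pass': 0}]})
        print(passes)  # 2: passes 0 and 1 had jobs, pass 2 found none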