diff piecrust/baking/baker.py @ 411:e7b865f8f335

bake: Enable multiprocess baking.

Baking is now done by running a worker per CPU and sending jobs to them.
This changes several things across the codebase:

* Ability to not cache things related to pages other than the 'main' page
  (i.e. the page at the bottom of the execution stack).
* Decouple the baking process from the bake records, so only the main
  process keeps track of (and modifies) the bake record.
* Remove the need for 'batch page getters'; pages are now loaded directly
  from the page factories.

There are various smaller changes included here too, like support for
scoped performance timers that are saved with the bake record and can be
printed out to the console. Yes, I got carried away.

For testing, the in-memory 'mock' file-system doesn't work anymore since
we're spawning processes, so it is replaced by a 'tmpfs' file-system that
is saved in temporary files on disk and deleted after tests have run.
author Ludovic Chabant <ludovic@chabant.com>
date Fri, 12 Jun 2015 17:09:19 -0700
parents c2ca72fb7f0b
children 0e9a94b7fdfa
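
The change replaces the old thread-based BakeWorker with separate processes: a worker per CPU pulls jobs from a shared multiprocessing.JoinableQueue, pushes results (and, at shutdown, its timers) onto a results Queue, and watches an abort Event, while only the main process updates the bake record and merges the workers' timers. Below is a minimal, self-contained sketch of that pattern (not code from this changeset; squaring integers is a hypothetical stand-in for baking a page):

    import multiprocessing
    import queue
    import time


    def worker_func(wid, jobs, results, abort_event):
        # Pull jobs until the main process signals shutdown, then report
        # this worker's accumulated timers as a final result.
        timers = {'work': 0.0}
        while not abort_event.is_set():
            try:
                job = jobs.get(True, 0.1)
            except queue.Empty:
                continue
            start = time.perf_counter()
            results.put_nowait((wid, job * job))  # stand-in for baking a page
            timers['work'] += time.perf_counter() - start
            jobs.task_done()
        results.put_nowait(timers)


    if __name__ == '__main__':
        jobs = multiprocessing.JoinableQueue()
        results = multiprocessing.Queue()
        abort_event = multiprocessing.Event()

        # One worker process per CPU, like the baker's new default.
        workers = []
        for i in range(multiprocessing.cpu_count()):
            w = multiprocessing.Process(target=worker_func,
                                        args=(i, jobs, results, abort_event))
            w.start()
            workers.append(w)

        job_count = 10
        for i in range(job_count):
            jobs.put_nowait(i)

        # Only the main process consumes results, the same way only the
        # main process touches the bake record.
        for _ in range(job_count):
            print(results.get(True, 10))

        # Shut the workers down and merge their timers, as the baker does
        # before saving the bake record.
        abort_event.set()
        total_timers = {}
        for _ in range(len(workers)):
            for name, val in results.get(True, 10).items():
                total_timers[name] = total_timers.get(name, 0.0) + val
        for w in workers:
            w.join()
        print(total_timers)
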
--- a/piecrust/baking/baker.py	Sat May 30 15:41:52 2015 -0700
+++ b/piecrust/baking/baker.py	Fri Jun 12 17:09:19 2015 -0700
@@ -1,13 +1,18 @@
+import copy
 import time
 import os.path
+import queue
 import hashlib
 import logging
-import threading
+import multiprocessing
 from piecrust.baking.records import (
-        TransitionalBakeRecord, BakeRecordPageEntry)
-from piecrust.baking.scheduler import BakeScheduler
-from piecrust.baking.single import (BakingError, PageBaker)
-from piecrust.chefutil import format_timed, log_friendly_exception
+        BakeRecordEntry, TransitionalBakeRecord, TaxonomyInfo, FirstRenderInfo)
+from piecrust.baking.worker import (
+        BakeWorkerJob, LoadJobPayload, RenderFirstSubJobPayload,
+        BakeJobPayload,
+        JOB_LOAD, JOB_RENDER_FIRST, JOB_BAKE)
+from piecrust.chefutil import (
+        format_timed_scope, format_timed)
 from piecrust.sources.base import (
         REALM_NAMES, REALM_USER, REALM_THEME)
 
@@ -21,7 +26,8 @@
         self.app = app
         self.out_dir = out_dir
         self.force = force
-        self.num_workers = app.config.get('baker/workers', 4)
+        self.num_workers = app.config.get('baker/workers',
+                                          multiprocessing.cpu_count())
 
         # Remember what taxonomy pages we should skip
         # (we'll bake them repeatedly later with each taxonomy term)
@@ -29,8 +35,8 @@
         logger.debug("Gathering taxonomy page paths:")
         for tax in self.app.taxonomies:
             for src in self.app.sources:
-                path = tax.resolvePagePath(src.name)
-                if path is not None:
+                tax_page_ref = tax.getPageRef(src)
+                for path in tax_page_ref.possible_paths:
                     self.taxonomy_pages.append(path)
                     logger.debug(" - %s" % path)
 
@@ -39,7 +45,7 @@
         logger.debug("  Root URL: %s" % self.app.config.get('site/root'))
 
         # Get into bake mode.
-        start_time = time.clock()
+        start_time = time.perf_counter()
         self.app.config.set('baker/is_baking', True)
         self.app.env.base_asset_url_format = '%uri%'
 
@@ -52,35 +58,59 @@
         record_cache = self.app.cache.getCache('baker')
         record_id = hashlib.md5(self.out_dir.encode('utf8')).hexdigest()
         record_name = record_id + '.record'
+        previous_record_path = None
         if not self.force and record_cache.has(record_name):
-            t = time.clock()
-            record.loadPrevious(record_cache.getCachePath(record_name))
-            logger.debug(format_timed(
-                    t, 'loaded previous bake record',
-                    colored=False))
+            with format_timed_scope(logger, "loaded previous bake record",
+                                    level=logging.DEBUG, colored=False):
+                previous_record_path = record_cache.getCachePath(record_name)
+                record.loadPrevious(previous_record_path)
         record.current.success = True
 
         # Figure out if we need to clean the cache because important things
         # have changed.
         self._handleCacheValidity(record)
 
+        # Pre-create all caches.
+        for cache_name in ['app', 'baker', 'pages', 'renders']:
+            self.app.cache.getCache(cache_name)
+
         # Gather all sources by realm -- we're going to bake each realm
-        # separately so we can handle "overlaying" (i.e. one realm overrides
-        # another realm's pages).
+        # separately so we can handle "overriding" (i.e. one realm overrides
+        # another realm's pages, like the user realm overriding the theme
+        # realm).
         sources_by_realm = {}
         for source in self.app.sources:
             srclist = sources_by_realm.setdefault(source.realm, [])
             srclist.append(source)
 
+        # Create the worker processes.
+        pool = self._createWorkerPool()
+
         # Bake the realms.
         realm_list = [REALM_USER, REALM_THEME]
         for realm in realm_list:
             srclist = sources_by_realm.get(realm)
             if srclist is not None:
-                self._bakeRealm(record, realm, srclist)
+                self._bakeRealm(record, pool, realm, srclist)
 
         # Bake taxonomies.
-        self._bakeTaxonomies(record)
+        self._bakeTaxonomies(record, pool)
+
+        # All done with the workers.
+        self._terminateWorkerPool(pool)
+
+        # Get the timing information from the workers.
+        record.current.timers = {}
+        for _ in range(len(pool.workers)):
+            try:
+                timers = pool.results.get(True, 0.1)
+            except queue.Empty:
+                logger.error("Didn't get timing information from all workers.")
+                break
+
+            for name, val in timers.items():
+                main_val = record.current.timers.setdefault(name, 0)
+                record.current.timers[name] = main_val + val
 
         # Delete files from the output.
         self._handleDeletetions(record)
@@ -98,11 +128,11 @@
                 os.rename(record_path, record_path_next)
 
         # Save the bake record.
-        t = time.clock()
-        record.current.bake_time = time.time()
-        record.current.out_dir = self.out_dir
-        record.saveCurrent(record_cache.getCachePath(record_name))
-        logger.debug(format_timed(t, 'saved bake record', colored=False))
+        with format_timed_scope(logger, "saved bake record.",
+                                level=logging.DEBUG, colored=False):
+            record.current.bake_time = time.time()
+            record.current.out_dir = self.out_dir
+            record.saveCurrent(record_cache.getCachePath(record_name))
 
         # All done.
         self.app.config.set('baker/is_baking', False)
@@ -111,7 +141,7 @@
         return record.detach()
 
     def _handleCacheValidity(self, record):
-        start_time = time.clock()
+        start_time = time.perf_counter()
 
         reason = None
         if self.force:
@@ -152,41 +182,138 @@
                     start_time, "cache is assumed valid",
                     colored=False))
 
-    def _bakeRealm(self, record, realm, srclist):
-        # Gather all page factories from the sources and queue them
-        # for the workers to pick up. Just skip taxonomy pages for now.
-        logger.debug("Baking realm %s" % REALM_NAMES[realm])
-        pool, queue, abort = self._createWorkerPool(record, self.num_workers)
+    def _bakeRealm(self, record, pool, realm, srclist):
+        start_time = time.perf_counter()
+        try:
+            all_factories = []
+            for source in srclist:
+                factories = source.getPageFactories()
+                all_factories += [f for f in factories
+                                  if f.path not in self.taxonomy_pages]
+
+            self._loadRealmPages(record, pool, all_factories)
+            self._renderRealmPages(record, pool, all_factories)
+            self._bakeRealmPages(record, pool, all_factories)
+        finally:
+            page_count = len(all_factories)
+            logger.info(format_timed(
+                    start_time,
+                    "baked %d %s pages" %
+                    (page_count, REALM_NAMES[realm].lower())))
+
+    def _loadRealmPages(self, record, pool, factories):
+        with format_timed_scope(logger,
+                                "loaded %d pages" % len(factories),
+                                level=logging.DEBUG, colored=False):
+            for fac in factories:
+                job = BakeWorkerJob(
+                        JOB_LOAD,
+                        LoadJobPayload(fac))
+                pool.queue.put_nowait(job)
 
-        for source in srclist:
-            factories = source.getPageFactories()
+            def _handler(res):
+                # Create the record entry for this page.
+                record_entry = BakeRecordEntry(res.source_name, res.path)
+                record_entry.config = res.config
+                if res.errors:
+                    record_entry.errors += res.errors
+                    record.current.success = False
+                record.addEntry(record_entry)
+
+            self._waitOnWorkerPool(
+                    pool,
+                    expected_result_count=len(factories),
+                    result_handler=_handler)
+
+    def _renderRealmPages(self, record, pool, factories):
+        with format_timed_scope(logger,
+                                "prepared %d pages" % len(factories),
+                                level=logging.DEBUG, colored=False):
+            expected_result_count = 0
             for fac in factories:
-                if fac.path in self.taxonomy_pages:
-                    logger.debug(
-                            "Skipping taxonomy page: %s:%s" %
-                            (source.name, fac.ref_spec))
+                record_entry = record.getCurrentEntry(fac.path)
+                if record_entry.errors:
+                    logger.debug("Ignoring %s because it had previous "
+                                 "errors." % fac.ref_spec)
+                    continue
+
+                # Make sure the source and the route exist for this page,
+                # otherwise we add errors to the record entry and we'll skip
+                # this page for the rest of the bake.
+                source = self.app.getSource(fac.source.name)
+                if source is None:
+                    record_entry.errors.append(
+                            "Can't get source for page: %s" % fac.ref_spec)
+                    logger.error(record_entry.errors[-1])
                     continue
 
-                entry = BakeRecordPageEntry(fac.source.name, fac.rel_path,
-                                            fac.path)
-                record.addEntry(entry)
-
-                route = self.app.getRoute(source.name, fac.metadata,
+                route = self.app.getRoute(fac.source.name, fac.metadata,
                                           skip_taxonomies=True)
                 if route is None:
-                    entry.errors.append(
+                    record_entry.errors.append(
                             "Can't get route for page: %s" % fac.ref_spec)
-                    logger.error(entry.errors[-1])
+                    logger.error(record_entry.errors[-1])
                     continue
 
-                queue.addJob(BakeWorkerJob(fac, route, entry))
+                # All good, queue the job.
+                job = BakeWorkerJob(
+                        JOB_RENDER_FIRST,
+                        RenderFirstSubJobPayload(fac))
+                pool.queue.put_nowait(job)
+                expected_result_count += 1
+
+            def _handler(res):
+                entry = record.getCurrentEntry(res.path)
+
+                entry.first_render_info = FirstRenderInfo()
+                entry.first_render_info.used_assets = res.used_assets
+                entry.first_render_info.used_pagination = \
+                    res.used_pagination
+                entry.first_render_info.pagination_has_more = \
+                    res.pagination_has_more
+
+                if res.errors:
+                    entry.errors += res.errors
+                    record.current.success = False
+
+            self._waitOnWorkerPool(
+                    pool,
+                    expected_result_count=expected_result_count,
+                    result_handler=_handler)
 
-        success = self._waitOnWorkerPool(pool, abort)
-        record.current.success &= success
+    def _bakeRealmPages(self, record, pool, factories):
+        with format_timed_scope(logger,
+                                "baked %d pages" % len(factories),
+                                level=logging.DEBUG, colored=False):
+            expected_result_count = 0
+            for fac in factories:
+                if self._queueBakeJob(record, pool, fac):
+                    expected_result_count += 1
+
+            def _handler(res):
+                entry = record.getCurrentEntry(res.path, res.taxonomy_info)
+                entry.bake_info = res.bake_info
+                if res.errors:
+                    entry.errors += res.errors
+                if entry.has_any_error:
+                    record.current.success = False
 
-    def _bakeTaxonomies(self, record):
-        logger.debug("Baking taxonomies")
+            self._waitOnWorkerPool(
+                    pool,
+                    expected_result_count=expected_result_count,
+                    result_handler=_handler)
 
+    def _bakeTaxonomies(self, record, pool):
+        with format_timed_scope(logger, 'built taxonomy buckets',
+                                level=logging.DEBUG, colored=False):
+            buckets = self._buildTaxonomyBuckets(record)
+
+        start_time = time.perf_counter()
+        page_count = self._bakeTaxonomyBuckets(record, pool, buckets)
+        logger.info(format_timed(start_time,
+                                 "baked %d taxonomy pages." % page_count))
+
+    def _buildTaxonomyBuckets(self, record):
         # Let's see all the taxonomy terms for which we must bake a
         # listing page... first, pre-populate our big map of used terms.
         # For each source name, we have a list of taxonomies, and for each
@@ -250,8 +377,11 @@
             if not tt_info.dirty_terms.isdisjoint(set(terms)):
                 tt_info.dirty_terms.add(terms)
 
+        return buckets
+
+    def _bakeTaxonomyBuckets(self, record, pool, buckets):
         # Start baking those terms.
-        pool, queue, abort = self._createWorkerPool(record, self.num_workers)
+        expected_result_count = 0
         for source_name, source_taxonomies in buckets.items():
             for tax_name, tt_info in source_taxonomies.items():
                 terms = tt_info.dirty_terms
@@ -262,8 +392,8 @@
                         "Baking '%s' for source '%s': %s" %
                         (tax_name, source_name, terms))
                 tax = self.app.getTaxonomy(tax_name)
-                route = self.app.getTaxonomyRoute(tax_name, source_name)
-                tax_page_ref = tax.getPageRef(source_name)
+                source = self.app.getSource(source_name)
+                tax_page_ref = tax.getPageRef(source)
                 if not tax_page_ref.exists:
                     logger.debug(
                             "No taxonomy page found at '%s', skipping." %
@@ -273,19 +403,33 @@
                 logger.debug(
                         "Using taxonomy page: %s:%s" %
                         (tax_page_ref.source_name, tax_page_ref.rel_path))
+                fac = tax_page_ref.getFactory()
+
                 for term in terms:
-                    fac = tax_page_ref.getFactory()
                     logger.debug(
                             "Queuing: %s [%s=%s]" %
                             (fac.ref_spec, tax_name, term))
-                    entry = BakeRecordPageEntry(
-                            fac.source.name, fac.rel_path, fac.path,
-                            (tax_name, term, source_name))
-                    record.addEntry(entry)
-                    queue.addJob(BakeWorkerJob(fac, route, entry))
+                    tax_info = TaxonomyInfo(tax_name, source_name, term)
+
+                    cur_entry = BakeRecordEntry(
+                            fac.source.name, fac.path, tax_info)
+                    record.addEntry(cur_entry)
+
+                    if self._queueBakeJob(record, pool, fac, tax_info):
+                        expected_result_count += 1
 
-        success = self._waitOnWorkerPool(pool, abort)
-        record.current.success &= success
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
+            entry.bake_info = res.bake_info
+            if res.errors:
+                entry.errors += res.errors
+            if entry.has_any_error:
+                record.current.success = False
+
+        self._waitOnWorkerPool(
+                pool,
+                expected_result_count=expected_result_count,
+                result_handler=_handler)
 
         # Now we create bake entries for all the terms that were *not* dirty.
         # This is because otherwise, on the next incremental bake, we wouldn't
@@ -296,16 +440,71 @@
             # current version.
             if (prev_entry and prev_entry.taxonomy_info and
                     not cur_entry):
-                sn = prev_entry.source_name
-                tn, tt, tsn = prev_entry.taxonomy_info
-                tt_info = buckets[tsn][tn]
-                if tt in tt_info.all_terms:
+                ti = prev_entry.taxonomy_info
+                tt_info = buckets[ti.source_name][ti.taxonomy_name]
+                if ti.term in tt_info.all_terms:
                     logger.debug("Creating unbaked entry for taxonomy "
-                                 "term '%s:%s'." % (tn, tt))
+                                 "term '%s:%s'." % (ti.taxonomy_name, ti.term))
                     record.collapseEntry(prev_entry)
                 else:
                     logger.debug("Taxonomy term '%s:%s' isn't used anymore." %
-                                 (tn, tt))
+                                 (ti.taxonomy_name, ti.term))
+
+        return expected_result_count
+
+    def _queueBakeJob(self, record, pool, fac, tax_info=None):
+        # Get the previous (if any) and current entry for this page.
+        pair = record.getPreviousAndCurrentEntries(fac.path, tax_info)
+        assert pair is not None
+        prev_entry, cur_entry = pair
+        assert cur_entry is not None
+
+        # Ignore if there were errors in the previous passes.
+        if cur_entry.errors:
+            logger.debug("Ignoring %s because it had previous "
+                         "errors." % fac.ref_spec)
+            return False
+
+        # Build the route metadata and find the appropriate route.
+        route_metadata = copy.deepcopy(fac.metadata)
+        if tax_info is not None:
+            tax = self.app.getTaxonomy(tax_info.taxonomy_name)
+            route = self.app.getTaxonomyRoute(tax_info.taxonomy_name,
+                                              tax_info.source_name)
+
+            slugified_term = route.slugifyTaxonomyTerm(tax_info.term)
+            route_metadata[tax.term_name] = slugified_term
+        else:
+            route = self.app.getRoute(fac.source.name, route_metadata,
+                                      skip_taxonomies=True)
+        assert route is not None
+
+        # Figure out if this page is overridden by another previously
+        # baked page. This happens, for example, when the user has
+        # made a page that has the same path/URL as a theme page.
+        page = fac.buildPage()
+        uri = route.getUri(route_metadata, provider=page)
+        override_entry = record.getOverrideEntry(page.path, uri)
+        if override_entry is not None:
+            override_source = self.app.getSource(
+                    override_entry.source_name)
+            if override_source.realm == fac.source.realm:
+                cur_entry.errors.append(
+                        "Page '%s' maps to URL '%s' but is overridden "
+                        "by page '%s'." %
+                        (fac.ref_spec, uri, override_entry.path))
+                logger.error(cur_entry.errors[-1])
+            cur_entry.flags |= BakeRecordEntry.FLAG_OVERRIDEN
+            return False
+
+        job = BakeWorkerJob(
+                JOB_BAKE,
+                BakeJobPayload(fac, route_metadata, prev_entry,
+                               cur_entry.first_render_info,
+                               record.dirty_source_names,
+                               tax_info))
+        pool.queue.put_nowait(job)
+        return True
 
     def _handleDeletetions(self, record):
         for path, reason in record.getDeletions():
@@ -318,139 +517,66 @@
                 # by the user.
                 pass
 
-    def _createWorkerPool(self, record, pool_size=4):
-        pool = []
-        queue = BakeScheduler(record)
-        abort = threading.Event()
-        for i in range(pool_size):
+    def _createWorkerPool(self):
+        from piecrust.baking.worker import BakeWorkerContext, worker_func
+
+        pool = _WorkerPool()
+        for i in range(self.num_workers):
             ctx = BakeWorkerContext(
-                    self.app, self.out_dir, self.force,
-                    record, queue, abort)
-            worker = BakeWorker(i, ctx)
-            pool.append(worker)
-        return pool, queue, abort
+                    self.app.root_dir, self.out_dir,
+                    pool.queue, pool.results, pool.abort_event,
+                    force=self.force, debug=self.app.debug)
+            w = multiprocessing.Process(
+                    target=worker_func, args=(i, ctx))
+            w.start()
+            pool.workers.append(w)
+        return pool
+
+    def _terminateWorkerPool(self, pool):
+        pool.abort_event.set()
+        for w in pool.workers:
+            w.join()
 
-    def _waitOnWorkerPool(self, pool, abort):
-        for w in pool:
-            w.start()
+    def _waitOnWorkerPool(self, pool,
+                          expected_result_count=-1, result_handler=None):
+        assert result_handler is None or expected_result_count >= 0
+        abort_with_exception = None
+        try:
+            if result_handler is None:
+                pool.queue.join()
+            else:
+                got_count = 0
+                while got_count < expected_result_count:
+                    try:
+                        res = pool.results.get(True, 10)
+                    except queue.Empty:
+                        logger.error(
+                                "Got %d results, expected %d, and timed out "
+                                "after 10 seconds. A worker might be stuck?" %
+                                (got_count, expected_result_count))
+                        abort_with_exception = Exception("Worker time-out.")
+                        break
 
-        success = True
-        try:
-            for w in pool:
-                w.join()
-                success &= w.success
-        except KeyboardInterrupt:
+                    got_count += 1
+                    result_handler(res)
+        except KeyboardInterrupt as kiex:
             logger.warning("Bake aborted by user... "
                            "waiting for workers to stop.")
-            abort.set()
-            for w in pool:
-                w.join()
-            raise
-
-        if abort.is_set():
-            excs = [w.abort_exception for w in pool
-                    if w.abort_exception is not None]
-            logger.error("Baking was aborted due to %s error(s):" % len(excs))
-            if self.app.debug:
-                for e in excs:
-                    logger.exception(e)
-            else:
-                for e in excs:
-                    log_friendly_exception(logger, e)
-            raise BakingError("Baking was aborted due to errors.")
-
-        return success
-
+            abort_with_exception = kiex
 
-class BakeWorkerContext(object):
-    def __init__(self, app, out_dir, force, record, work_queue,
-                 abort_event):
-        self.app = app
-        self.out_dir = out_dir
-        self.force = force
-        self.record = record
-        self.work_queue = work_queue
-        self.abort_event = abort_event
-
-
-class BakeWorkerJob(object):
-    def __init__(self, factory, route, record_entry):
-        self.factory = factory
-        self.route = route
-        self.record_entry = record_entry
-
-    @property
-    def source(self):
-        return self.factory.source
+        if abort_with_exception:
+            pool.abort_event.set()
+            for w in pool.workers:
+                w.join(2)
+            raise abort_with_exception
 
 
-class BakeWorker(threading.Thread):
-    def __init__(self, wid, ctx):
-        super(BakeWorker, self).__init__(name=('worker%d' % wid))
-        self.wid = wid
-        self.ctx = ctx
-        self.abort_exception = None
-        self.success = True
-        self._page_baker = PageBaker(
-                ctx.app, ctx.out_dir, ctx.force,
-                ctx.record)
-
-    def run(self):
-        while(not self.ctx.abort_event.is_set()):
-            try:
-                job = self.ctx.work_queue.getNextJob(wait_timeout=1)
-                if job is None:
-                    logger.debug(
-                            "[%d] No more work... shutting down." %
-                            self.wid)
-                    break
-                success = self._unsafeRun(job)
-                logger.debug("[%d] Done with page." % self.wid)
-                self.ctx.work_queue.onJobFinished(job)
-                self.success &= success
-            except Exception as ex:
-                self.ctx.abort_event.set()
-                self.abort_exception = ex
-                self.success = False
-                logger.debug("[%d] Critical error, aborting." % self.wid)
-                if self.ctx.app.debug:
-                    logger.exception(ex)
-                break
-
-    def _unsafeRun(self, job):
-        start_time = time.clock()
-
-        entry = job.record_entry
-        try:
-            self._page_baker.bake(job.factory, job.route, entry)
-        except BakingError as ex:
-            logger.debug("Got baking error. Adding it to the record.")
-            while ex:
-                entry.errors.append(str(ex))
-                ex = ex.__cause__
-
-        has_error = False
-        for e in entry.getAllErrors():
-            has_error = True
-            logger.error(e)
-        if has_error:
-            return False
-
-        if entry.was_any_sub_baked:
-            first_sub = entry.subs[0]
-
-            friendly_uri = first_sub.out_uri
-            if friendly_uri == '':
-                friendly_uri = '[main page]'
-
-            friendly_count = ''
-            if entry.num_subs > 1:
-                friendly_count = ' (%d pages)' % entry.num_subs
-            logger.info(format_timed(
-                    start_time, '[%d] %s%s' %
-                    (self.wid, friendly_uri, friendly_count)))
-
-        return True
+class _WorkerPool(object):
+    def __init__(self):
+        self.queue = multiprocessing.JoinableQueue()
+        self.results = multiprocessing.Queue()
+        self.abort_event = multiprocessing.Event()
+        self.workers = []
 
 
 class _TaxonomyTermsInfo(object):
@@ -463,3 +589,4 @@
 
     def __repr__(self):
         return 'dirty:%s, all:%s' % (self.dirty_terms, self.all_terms)
+
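
A recurring piece of this diff is the new format_timed_scope helper from piecrust.chefutil, used to time a block of work and log it (optionally uncolored, at debug level). Its implementation isn't shown in this file; the following is only a rough sketch of what such a context manager could look like, with the message formatting and coloring being assumptions:

    import contextlib
    import logging
    import time


    @contextlib.contextmanager
    def format_timed_scope(logger, message, level=logging.INFO, colored=True):
        # Time the enclosed block and log the elapsed time on exit, so
        # callers don't have to thread explicit start/stop timestamps around.
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            text = '%s (%.1f ms)' % (message, elapsed_ms)
            if colored:
                # Assumed ANSI green, in the style of chef's timing output.
                text = '\033[92m%s\033[0m' % text
            logger.log(level, text)

With that shape, `with format_timed_scope(logger, "saved bake record", level=logging.DEBUG, colored=False): ...` replaces the old pattern of capturing a start time and calling format_timed() manually, which is the substitution this changeset makes throughout the baker.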