diff piecrust/baking/baker.py @ 447:aefe70229fdd

bake: Commonize worker pool code between html and asset baking. The `workerpool` package now defines a generic-ish worker pool. It's similar to the pool in Python's `multiprocessing` module, but with a simpler use-case (only one way to queue jobs) and support for workers sending a final "report" to the master process, which we use here to get timing information. The rest of the changes remove a whole bunch of duplicated code that's no longer needed.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 05 Jul 2015 00:09:41 -0700
parents 21e26ed867b6
children 838f3964f400
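Below is a minimal sketch of the master-side call pattern this changeset introduces, pieced together from the hunks that follow. The `bake_with_pool()` wrapper is hypothetical and only illustrates the flow; the actual `WorkerPool` implementation lives in `piecrust.workerpool` and is not part of this diff, so the exact semantics of `queueJobs()` and `close()` are assumed from how `baker.py` calls them here.

    # Sketch only: the new pool usage, as exercised by baker.py in this changeset.
    from piecrust.workerpool import WorkerPool
    from piecrust.baking.worker import BakeWorkerContext, BakeWorker

    def bake_with_pool(app, out_dir, jobs, force=False):
        # Workers are configured once via `initargs`, instead of hand-rolling
        # multiprocessing queues, events and processes as the old code did.
        ctx = BakeWorkerContext(
                app.root_dir, app.cache.base_dir, out_dir,
                force=force, debug=app.debug)
        pool = WorkerPool(worker_class=BakeWorker, initargs=(ctx,))

        results = []

        def _handler(res):
            # Called on the master for each job result, replacing the old
            # `_waitOnWorkerPool()` polling loop.
            results.append(res)

        # Queue all jobs at once and block until they have been processed.
        ar = pool.queueJobs(jobs, handler=_handler)
        ar.wait()

        # `close()` shuts the workers down and returns one final "report" per
        # worker (possibly None), which the baker uses for timing information.
        reports = pool.close()
        return results, reports
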
--- a/piecrust/baking/baker.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/baking/baker.py	Sun Jul 05 00:09:41 2015 -0700
@@ -103,17 +103,13 @@
         # Bake taxonomies.
         self._bakeTaxonomies(record, pool)
 
-        # All done with the workers.
-        self._terminateWorkerPool(pool)
-
-        # Get the timing information from the workers.
+        # All done with the workers. Close the pool and get timing reports.
+        reports = pool.close()
         record.current.timers = {}
-        for i in range(len(pool.workers)):
-            try:
-                timers = pool.results.get(True, 0.1)
-            except queue.Empty:
-                logger.error("Didn't get timing information from all workers.")
-                break
+        for i in range(len(reports)):
+            timers = reports[i]
+            if timers is None:
+                continue
 
             worker_name = 'BakeWorker_%d' % i
             record.current.timers[worker_name] = {}
@@ -214,41 +210,43 @@
                     (page_count, REALM_NAMES[realm].lower())))
 
     def _loadRealmPages(self, record, pool, factories):
+        def _handler(res):
+            # Create the record entry for this page.
+            record_entry = BakeRecordEntry(res.source_name, res.path)
+            record_entry.config = res.config
+            if res.errors:
+                record_entry.errors += res.errors
+                record.current.success = False
+                self._logErrors(res.path, res.errors)
+            record.addEntry(record_entry)
+
         logger.debug("Loading %d realm pages..." % len(factories))
         with format_timed_scope(logger,
                                 "loaded %d pages" % len(factories),
                                 level=logging.DEBUG, colored=False,
                                 timer_env=self.app.env,
                                 timer_category='LoadJob'):
-            for fac in factories:
-                job = BakeWorkerJob(
-                        JOB_LOAD,
-                        LoadJobPayload(fac))
-                pool.queue.put_nowait(job)
-
-            def _handler(res):
-                # Create the record entry for this page.
-                record_entry = BakeRecordEntry(res.source_name, res.path)
-                record_entry.config = res.config
-                if res.errors:
-                    record_entry.errors += res.errors
-                    record.current.success = False
-                    self._logErrors(res.path, res.errors)
-                record.addEntry(record_entry)
-
-            self._waitOnWorkerPool(
-                    pool,
-                    expected_result_count=len(factories),
-                    result_handler=_handler)
+            jobs = [
+                BakeWorkerJob(JOB_LOAD, LoadJobPayload(fac))
+                for fac in factories]
+            ar = pool.queueJobs(jobs, handler=_handler)
+            ar.wait()
 
     def _renderRealmPages(self, record, pool, factories):
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path)
+            if res.errors:
+                entry.errors += res.errors
+                record.current.success = False
+                self._logErrors(res.path, res.errors)
+
         logger.debug("Rendering %d realm pages..." % len(factories))
         with format_timed_scope(logger,
                                 "prepared %d pages" % len(factories),
                                 level=logging.DEBUG, colored=False,
                                 timer_env=self.app.env,
                                 timer_category='RenderFirstSubJob'):
-            expected_result_count = 0
+            jobs = []
             for fac in factories:
                 record_entry = record.getCurrentEntry(fac.path)
                 if record_entry.errors:
@@ -278,49 +276,38 @@
                 job = BakeWorkerJob(
                         JOB_RENDER_FIRST,
                         RenderFirstSubJobPayload(fac))
-                pool.queue.put_nowait(job)
-                expected_result_count += 1
+                jobs.append(job)
 
-            def _handler(res):
-                entry = record.getCurrentEntry(res.path)
-                if res.errors:
-                    entry.errors += res.errors
-                    record.current.success = False
-                    self._logErrors(res.path, res.errors)
-
-            self._waitOnWorkerPool(
-                    pool,
-                    expected_result_count=expected_result_count,
-                    result_handler=_handler)
+            ar = pool.queueJobs(jobs, handler=_handler)
+            ar.wait()
 
     def _bakeRealmPages(self, record, pool, realm, factories):
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
+            entry.subs = res.sub_entries
+            if res.errors:
+                entry.errors += res.errors
+                self._logErrors(res.path, res.errors)
+            if entry.has_any_error:
+                record.current.success = False
+            if entry.was_any_sub_baked:
+                record.current.baked_count[realm] += 1
+                record.dirty_source_names.add(entry.source_name)
+
         logger.debug("Baking %d realm pages..." % len(factories))
         with format_timed_scope(logger,
                                 "baked %d pages" % len(factories),
                                 level=logging.DEBUG, colored=False,
                                 timer_env=self.app.env,
                                 timer_category='BakeJob'):
-            expected_result_count = 0
+            jobs = []
             for fac in factories:
-                if self._queueBakeJob(record, pool, fac):
-                    expected_result_count += 1
+                job = self._makeBakeJob(record, fac)
+                if job is not None:
+                    jobs.append(job)
 
-            def _handler(res):
-                entry = record.getCurrentEntry(res.path, res.taxonomy_info)
-                entry.subs = res.sub_entries
-                if res.errors:
-                    entry.errors += res.errors
-                    self._logErrors(res.path, res.errors)
-                if entry.has_any_error:
-                    record.current.success = False
-                if entry.was_any_sub_baked:
-                    record.current.baked_count[realm] += 1
-                    record.dirty_source_names.add(entry.source_name)
-
-            self._waitOnWorkerPool(
-                    pool,
-                    expected_result_count=expected_result_count,
-                    result_handler=_handler)
+            ar = pool.queueJobs(jobs, handler=_handler)
+            ar.wait()
 
     def _bakeTaxonomies(self, record, pool):
         logger.debug("Baking taxonomy pages...")
@@ -400,8 +387,16 @@
         return buckets
 
     def _bakeTaxonomyBuckets(self, record, pool, buckets):
+        def _handler(res):
+            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
+            entry.subs = res.sub_entries
+            if res.errors:
+                entry.errors += res.errors
+            if entry.has_any_error:
+                record.current.success = False
+
         # Start baking those terms.
-        expected_result_count = 0
+        jobs = []
         for source_name, source_taxonomies in buckets.items():
             for tax_name, tt_info in source_taxonomies.items():
                 terms = tt_info.dirty_terms
@@ -435,21 +430,12 @@
                             fac.source.name, fac.path, tax_info)
                     record.addEntry(cur_entry)
 
-                    if self._queueBakeJob(record, pool, fac, tax_info):
-                        expected_result_count += 1
+                    job = self._makeBakeJob(record, fac, tax_info)
+                    if job is not None:
+                        jobs.append(job)
 
-        def _handler(res):
-            entry = record.getCurrentEntry(res.path, res.taxonomy_info)
-            entry.subs = res.sub_entries
-            if res.errors:
-                entry.errors += res.errors
-            if entry.has_any_error:
-                record.current.success = False
-
-        self._waitOnWorkerPool(
-                pool,
-                expected_result_count=expected_result_count,
-                result_handler=_handler)
+        ar = pool.queueJobs(jobs, handler=_handler)
+        ar.wait()
 
         # Now we create bake entries for all the terms that were *not* dirty.
         # This is because otherwise, on the next incremental bake, we wouldn't
@@ -470,9 +456,9 @@
                     logger.debug("Taxonomy term '%s:%s' isn't used anymore." %
                                  (ti.taxonomy_name, ti.term))
 
-        return expected_result_count
+        return len(jobs)
 
-    def _queueBakeJob(self, record, pool, fac, tax_info=None):
+    def _makeBakeJob(self, record, fac, tax_info=None):
         # Get the previous (if any) and current entry for this page.
         pair = record.getPreviousAndCurrentEntries(fac.path, tax_info)
         assert pair is not None
@@ -483,7 +469,7 @@
         if cur_entry.errors:
             logger.debug("Ignoring %s because it had previous "
                          "errors." % fac.ref_spec)
-            return False
+            return None
 
         # Build the route metadata and find the appropriate route.
         page = fac.buildPage()
@@ -515,15 +501,14 @@
                         (fac.ref_spec, uri, override_entry.path))
                 logger.error(cur_entry.errors[-1])
             cur_entry.flags |= BakeRecordEntry.FLAG_OVERRIDEN
-            return False
+            return None
 
         job = BakeWorkerJob(
                 JOB_BAKE,
                 BakeJobPayload(fac, route_metadata, prev_entry,
                                record.dirty_source_names,
                                tax_info))
-        pool.queue.put_nowait(job)
-        return True
+        return job
 
     def _handleDeletetions(self, record):
         logger.debug("Handling deletions...")
@@ -544,78 +529,16 @@
             logger.error("  " + e)
 
     def _createWorkerPool(self):
-        import sys
-        from piecrust.baking.worker import BakeWorkerContext, worker_func
-
-        main_module = sys.modules['__main__']
-        is_profiling = os.path.basename(main_module.__file__) in [
-                'profile.py', 'cProfile.py']
-
-        pool = _WorkerPool()
-        for i in range(self.num_workers):
-            ctx = BakeWorkerContext(
-                    self.app.root_dir, self.app.cache.base_dir, self.out_dir,
-                    pool.queue, pool.results, pool.abort_event,
-                    force=self.force, debug=self.app.debug,
-                    is_profiling=is_profiling)
-            w = multiprocessing.Process(
-                    name='BakeWorker_%d' % i,
-                    target=worker_func, args=(i, ctx))
-            w.start()
-            pool.workers.append(w)
-        return pool
-
-    def _terminateWorkerPool(self, pool):
-        pool.abort_event.set()
-        for w in pool.workers:
-            w.join()
+        from piecrust.workerpool import WorkerPool
+        from piecrust.baking.worker import BakeWorkerContext, BakeWorker
 
-    def _waitOnWorkerPool(self, pool,
-                          expected_result_count=-1, result_handler=None):
-        assert result_handler is None or expected_result_count >= 0
-        abort_with_exception = None
-        try:
-            if result_handler is None:
-                pool.queue.join()
-            else:
-                got_count = 0
-                while got_count < expected_result_count:
-                    try:
-                        res = pool.results.get(True, 10)
-                    except queue.Empty:
-                        logger.error(
-                                "Got %d results, expected %d, and timed-out "
-                                "for 10 seconds. A worker might be stuck?" %
-                                (got_count, expected_result_count))
-                        abort_with_exception = Exception("Worker time-out.")
-                        break
-
-                    if isinstance(res, dict) and res.get('type') == 'error':
-                        abort_with_exception = Exception(
-                                'Worker critical error:\n' +
-                                '\n'.join(res['messages']))
-                        break
-
-                    got_count += 1
-                    result_handler(res)
-        except KeyboardInterrupt as kiex:
-            logger.warning("Bake aborted by user... "
-                           "waiting for workers to stop.")
-            abort_with_exception = kiex
-
-        if abort_with_exception:
-            pool.abort_event.set()
-            for w in pool.workers:
-                w.join(2)
-            raise abort_with_exception
-
-
-class _WorkerPool(object):
-    def __init__(self):
-        self.queue = multiprocessing.JoinableQueue()
-        self.results = multiprocessing.Queue()
-        self.abort_event = multiprocessing.Event()
-        self.workers = []
+        ctx = BakeWorkerContext(
+                self.app.root_dir, self.app.cache.base_dir, self.out_dir,
+                force=self.force, debug=self.app.debug)
+        pool = WorkerPool(
+                worker_class=BakeWorker,
+                initargs=(ctx,))
+        return pool
 
 
 class _TaxonomyTermsInfo(object):