diff piecrust/baking/baker.py @ 85:3471ffa059b2

Add a `BakeScheduler` to handle build dependencies. Add unit-tests.
author Ludovic Chabant <ludovic@chabant.com>
date Wed, 03 Sep 2014 17:27:50 -0700
parents 2fec3ee1298f
children e88e330eb8dc
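In this change, the baker's worker pool no longer pulls work from a plain `Queue`: jobs now go through a `BakeScheduler`, which holds a page back while another pending or active job belongs to a source that this page used during the previous bake. The calling protocol is small: the baker queues jobs with `addJob()`, each worker pulls with `getNextJob()` (which returns `None` when there is nothing to hand out) and reports completion with `onJobFinished()` so that dependent jobs can become ready. The sketch below illustrates that protocol only; `bake_with_scheduler`, `factories_and_routes` and `bake_one` are hypothetical names, and the import assumes both classes remain importable from `piecrust.baking.baker`, where this diff uses them.

    from piecrust.baking.baker import BakeScheduler, BakeWorkerJob

    def bake_with_scheduler(record, factories_and_routes, bake_one):
        # Producer side (the Baker): queue every page up front.
        scheduler = BakeScheduler(record)
        for fac, route in factories_and_routes:
            scheduler.addJob(BakeWorkerJob(fac, route))

        # Consumer side (each BakeWorker): pull jobs until nothing is left,
        # and report completion so that jobs waiting on this page's source
        # can become ready.
        while True:
            job = scheduler.getNextJob()
            if job is None:
                break
            bake_one(job)                 # hypothetical per-page work
            scheduler.onJobFinished(job)

Note that, as the diff shows, the workers are now started in `_waitOnWorkerPool()`, only after all jobs have been queued.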
--- a/piecrust/baking/baker.py	Wed Sep 03 17:26:38 2014 -0700
+++ b/piecrust/baking/baker.py	Wed Sep 03 17:27:50 2014 -0700
@@ -6,7 +6,6 @@
 import logging
 import threading
 import urllib.request, urllib.error, urllib.parse
-from queue import Queue, Empty
 from piecrust.baking.records import TransitionalBakeRecord, BakeRecordPageEntry
 from piecrust.chefutil import format_timed, log_friendly_exception
 from piecrust.data.filters import (PaginationFilter, HasFilterClause,
@@ -27,7 +26,6 @@
         self.out_dir = out_dir
         self.force = force
         self.record = record
-        self.force = force
         self.copy_assets = copy_assets
         self.site_root = app.config.get('site/root')
         self.pretty_urls = app.config.get('site/pretty_urls')
@@ -136,8 +134,9 @@
         cur_record_entry = BakeRecordPageEntry(page)
         cur_record_entry.taxonomy_name = taxonomy_name
         cur_record_entry.taxonomy_term = taxonomy_term
-        prev_record_entry = self.record.getPreviousEntry(page, taxonomy_name,
-                taxonomy_term)
+        prev_record_entry = self.record.getPreviousEntry(
+                factory.source.name, factory.rel_path,
+                taxonomy_name, taxonomy_term)
 
         logger.debug("Baking '%s'..." % uri)
         while has_more_subs:
@@ -351,6 +350,7 @@
                 reason = "templates modified"
 
         if reason is not None:
+            # We have to bake everything from scratch.
             cache_dir = self.app.cache.getCacheDir('baker')
             if os.path.isdir(cache_dir):
                 logger.debug("Cleaning baker cache: %s" % cache_dir)
@@ -382,7 +382,7 @@
                     continue
 
                 logger.debug("Queuing: %s" % fac.ref_spec)
-                queue.put_nowait(BakeWorkerJob(fac, route))
+                queue.addJob(BakeWorkerJob(fac, route))
 
         self._waitOnWorkerPool(pool, abort)
 
@@ -469,7 +469,7 @@
                             {tax.term_name: term})
                     logger.debug("Queuing: %s [%s, %s]" %
                             (fac.ref_spec, tax_name, term))
-                    queue.put_nowait(
+                    queue.addJob(
                             BakeWorkerJob(fac, route, tax_name, term))
 
         self._waitOnWorkerPool(pool, abort)
@@ -487,23 +487,24 @@
 
     def _createWorkerPool(self, record, pool_size=4):
         pool = []
-        queue = Queue()
+        queue = BakeScheduler(record)
         abort = threading.Event()
         for i in range(pool_size):
             ctx = BakeWorkerContext(self.app, self.out_dir, self.force,
                     record, queue, abort)
             worker = BakeWorker(i, ctx)
-            worker.start()
             pool.append(worker)
         return pool, queue, abort
 
     def _waitOnWorkerPool(self, pool, abort):
         for w in pool:
+            w.start()
+        for w in pool:
             w.join()
         if abort.is_set():
             excs = [w.abort_exception for w in pool
                     if w.abort_exception is not None]
-            logger.error("%s errors" % len(excs))
+            logger.error("Baking was aborted due to %s error(s):" % len(excs))
             if self.app.debug:
                 for e in excs:
                     logger.exception(e)
@@ -513,6 +514,86 @@
             raise Exception("Baking was aborted due to errors.")
 
 
+class BakeScheduler(object):
+    _EMPTY = object()
+    _WAIT = object()
+
+    def __init__(self, record, jobs=None):
+        self.record = record
+        self.jobs = list(jobs) if jobs is not None else []
+        self._active_jobs = []
+        self._lock = threading.Lock()
+        self._added_event = threading.Event()
+        self._done_event = threading.Event()
+
+    def addJob(self, job):
+        logger.debug("Adding job '%s:%s' to scheduler." % (
+            job.factory.source.name, job.factory.rel_path))
+        with self._lock:
+            self.jobs.append(job)
+        self._added_event.set()
+
+    def onJobFinished(self, job):
+        logger.debug("Removing job '%s:%s' from scheduler." % (
+            job.factory.source.name, job.factory.rel_path))
+        with self._lock:
+            self._active_jobs.remove(job)
+        self._done_event.set()
+
+    def getNextJob(self, timeout=None):
+        self._added_event.clear()
+        self._done_event.clear()
+        job = self._doGetNextJob()
+        while job in (self._EMPTY, self._WAIT):
+            if timeout is None:
+                return None
+            if job == self._EMPTY:
+                logger.debug("Waiting for a new job to be added...")
+                res = self._added_event.wait(timeout)
+            elif job == self._WAIT:
+                logger.debug("Waiting for a job to be finished...")
+                res = self._done_event.wait(timeout)
+            if not res:
+                logger.debug("Timed-out. No job found.")
+                return None
+            job = self._doGetNextJob()
+        return job
+
+    def _doGetNextJob(self):
+        with self._lock:
+            if len(self.jobs) == 0:
+                return self._EMPTY
+
+            job = self.jobs.pop(0)
+            first_job = job
+            while not self._isJobReady(job):
+                logger.debug("Job '%s:%s' isn't ready yet." % (
+                        job.factory.source.name, job.factory.rel_path))
+                self.jobs.append(job)
+                if self.jobs[0] == first_job:
+                    # None of the jobs are ready yet; leave them queued and wait.
+                    return self._WAIT
+                job = self.jobs.pop(0)
+
+            logger.debug("Job '%s:%s' is ready to go, moving to active "
+                    "queue." % (job.factory.source.name, job.factory.rel_path))
+            self._active_jobs.append(job)
+            return job
+
+    def _isJobReady(self, job):
+        e = self.record.getPreviousEntry(job.factory.source.name,
+                job.factory.rel_path)
+        if not e:
+            return True
+        for sn in e.used_source_names:
+            if any(j.factory.source.name == sn for j in self.jobs):
+                return False
+            if any(j.factory.source.name == sn
+                    for j in self._active_jobs):
+                return False
+        return True
+
+
 class BakeWorkerContext(object):
     def __init__(self, app, out_dir, force, record, work_queue,
             abort_event):
@@ -547,16 +628,15 @@
 
     def run(self):
         while(not self.ctx.abort_event.is_set()):
-            try:
-                job = self.ctx.work_queue.get(True, 0.1)
-            except Empty:
+            job = self.ctx.work_queue.getNextJob()
+            if job is None:
                 logger.debug("[%d] No more work... shutting down." % self.wid)
                 break
 
             try:
                 self._unsafeRun(job)
                 logger.debug("[%d] Done with page." % self.wid)
-                self.ctx.work_queue.task_done()
+                self.ctx.work_queue.onJobFinished(job)
             except Exception as ex:
                 self.ctx.abort_event.set()
                 self.abort_exception = ex
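The changeset message mentions unit tests for the new scheduler; they live in a separate test file and are not part of this diff. As a rough illustration of the behaviour such a test needs to cover, here is a minimal sketch of a dependency-ordering check. The `Fake*` helpers and the test itself are hypothetical stand-ins written against the interfaces used above (`factory.source.name`, `factory.rel_path`, `record.getPreviousEntry()`), not the actual tests from the changeset.

    import unittest

    from piecrust.baking.baker import BakeScheduler


    class FakeSource(object):
        def __init__(self, name):
            self.name = name


    class FakeFactory(object):
        # Hypothetical stand-in for a page factory: just enough attributes
        # (source.name, rel_path) for the scheduler's checks and logging.
        def __init__(self, source_name, rel_path):
            self.source = FakeSource(source_name)
            self.rel_path = rel_path


    class FakeJob(object):
        # Hypothetical stand-in for BakeWorkerJob.
        def __init__(self, source_name, rel_path):
            self.factory = FakeFactory(source_name, rel_path)


    class FakeRecordEntry(object):
        def __init__(self, used_source_names):
            self.used_source_names = used_source_names


    class FakeRecord(object):
        # Hypothetical record: maps (source name, rel_path) to the entry
        # from the previous bake, answering getPreviousEntry() the way
        # the scheduler calls it above.
        def __init__(self, entries):
            self.entries = entries

        def getPreviousEntry(self, source_name, rel_path,
                             taxonomy_name=None, taxonomy_term=None):
            return self.entries.get((source_name, rel_path))


    class TestBakeScheduler(unittest.TestCase):
        def test_dependent_job_waits_for_its_source(self):
            # Last time around, 'theme:index.html' used pages from the
            # 'posts' source, so it must be held back while a 'posts'
            # job is still pending.
            record = FakeRecord({
                ('theme', 'index.html'): FakeRecordEntry(['posts']),
            })
            scheduler = BakeScheduler(record)
            dependent = FakeJob('theme', 'index.html')
            dependency = FakeJob('posts', 'first-post.md')
            scheduler.addJob(dependent)
            scheduler.addJob(dependency)

            # The 'posts' job comes out first even though it was added last.
            first = scheduler.getNextJob()
            self.assertIs(first, dependency)
            scheduler.onJobFinished(first)

            # Only now is the dependent job handed out.
            self.assertIs(scheduler.getNextJob(), dependent)

Only the small slice of the record interface exercised by the scheduler needs to be faked here; everything else about real bake records is irrelevant to the ordering behaviour.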