diff piecrust/workerpool.py @ 852:4850f8c21b6e

core: Start of the big refactor for PieCrust 3.0.

* Everything is a `ContentSource`, including assets directories.
* Most content sources are subclasses of the base file-system source.
* A source is processed by a "pipeline", and there are 2 built-in pipelines, one for assets and one for pages. The asset pipeline is vaguely functional, but the page pipeline is completely broken right now.
* Rewrite the baking process as just running appropriate pipelines on each content item. This should allow for better parallelization.
author Ludovic Chabant <ludovic@chabant.com>
date Wed, 17 May 2017 00:11:48 -0700
parents c62d83e17abf
children 08e02c2a2a1a
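As a quick orientation for the reworked API in the diff below, here is a minimal usage sketch (not part of this changeset), based only on the `WorkerPool` interface visible in this revision; `EchoWorker`, the job values and the callbacks are made-up illustrations.

from piecrust.workerpool import IWorker, WorkerPool

class EchoWorker(IWorker):
    # Hypothetical worker: just doubles each job value.
    def initialize(self):
        pass

    def process(self, job):
        return job * 2

def on_result(job, result):
    # Success callback now receives both the original job and its result.
    print("job %r -> %r" % (job, result))

def on_error(job, exc_data):
    # Failure callback receives a WorkerExceptionData carrying the worker id.
    print("job %r failed in worker %d: %s" % (job, exc_data.wid, exc_data))

if __name__ == '__main__':
    pool = WorkerPool(EchoWorker,
                      callback=on_result,
                      error_callback=on_error)
    pool.queueJobs([1, 2, 3])
    pool.wait()
    reports = pool.close()  # list of per-worker ExecutionStats
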
--- a/piecrust/workerpool.py	Sat Apr 29 21:42:22 2017 -0700
+++ b/piecrust/workerpool.py	Wed May 17 00:11:48 2017 -0700
@@ -2,37 +2,61 @@
 import os
 import sys
 import time
-import zlib
-import queue
+import pickle
 import logging
-import itertools
 import threading
+import traceback
 import multiprocessing
 from piecrust import fastpickle
+from piecrust.environment import ExecutionStats
 
 
 logger = logging.getLogger(__name__)
 
 use_fastqueue = True
+use_fastpickle = False
 
 
 class IWorker(object):
+    """ Interface for a pool worker.
+    """
     def initialize(self):
         raise NotImplementedError()
 
     def process(self, job):
         raise NotImplementedError()
 
-    def getReport(self, pool_reports):
+    def getStats(self):
         return None
 
     def shutdown(self):
         pass
 
 
+class WorkerExceptionData:
+    def __init__(self, wid):
+        super().__init__()
+        self.wid = wid
+        t, v, tb = sys.exc_info()
+        self.type = t
+        self.value = '\n'.join(_get_errors(v))
+        self.traceback = ''.join(traceback.format_exception(t, v, tb))
+
+    def __str__(self):
+        return str(self.value)
+
+
+def _get_errors(ex):
+    errors = []
+    while ex is not None:
+        msg = str(ex)
+        errors.append(msg)
+        ex = ex.__cause__
+    return errors
+
+
 TASK_JOB = 0
-TASK_BATCH = 1
-TASK_END = 2
+TASK_END = 1
 
 
 def worker_func(params):
@@ -52,6 +76,12 @@
 
 
 def _real_worker_func(params):
+    wid = params.wid
+
+    stats = ExecutionStats()
+    stats.registerTimer('WorkerInit')
+    init_start_time = time.perf_counter()
+
     # In a context where `multiprocessing` is using the `spawn` forking model,
     # the new process doesn't inherit anything, so we lost all our logging
     # configuration here. Let's set it up again.
@@ -60,7 +90,6 @@
         from piecrust.main import _pre_parse_chef_args
         _pre_parse_chef_args(sys.argv[1:])
 
-    wid = params.wid
     logger.debug("Worker %d initializing..." % wid)
 
     # We don't need those.
@@ -78,67 +107,49 @@
         params.outqueue.put(None)
         return
 
-    use_threads = False
-    if use_threads:
-        # Create threads to read/write the jobs and results from/to the
-        # main arbitrator process.
-        local_job_queue = queue.Queue()
-        reader_thread = threading.Thread(
-                target=_job_queue_reader,
-                args=(params.inqueue.get, local_job_queue),
-                name="JobQueueReaderThread")
-        reader_thread.start()
-
-        local_result_queue = queue.Queue()
-        writer_thread = threading.Thread(
-                target=_job_results_writer,
-                args=(local_result_queue, params.outqueue.put),
-                name="JobResultWriterThread")
-        writer_thread.start()
-
-        get = local_job_queue.get
-        put = local_result_queue.put_nowait
-    else:
-        get = params.inqueue.get
-        put = params.outqueue.put
+    stats.stepTimerSince('WorkerInit', init_start_time)
 
     # Start pumping!
     completed = 0
     time_in_get = 0
     time_in_put = 0
+    get = params.inqueue.get
+    put = params.outqueue.put
+
     while True:
         get_start_time = time.perf_counter()
         task = get()
         time_in_get += (time.perf_counter() - get_start_time)
 
         task_type, task_data = task
+
+        # End task... gather stats to send back to the main process.
         if task_type == TASK_END:
             logger.debug("Worker %d got end task, exiting." % wid)
-            wprep = {
-                    'WorkerTaskGet': time_in_get,
-                    'WorkerResultPut': time_in_put}
+            stats.registerTimer('WorkerTaskGet', time=time_in_get)
+            stats.registerTimer('WorkerResultPut', time=time_in_put)
             try:
-                rep = (task_type, True, wid, (wid, w.getReport(wprep)))
+                stats.mergeStats(w.getStats())
+                rep = (task_type, task_data, True, wid, (wid, stats))
             except Exception as e:
-                logger.debug("Error getting report: %s" % e)
-                if params.wrap_exception:
-                    e = multiprocessing.ExceptionWithTraceback(
-                            e, e.__traceback__)
-                rep = (task_type, False, wid, (wid, e))
+                logger.debug(
+                    "Error getting report, sending exception to main process:")
+                logger.debug(traceback.format_exc())
+                we = WorkerExceptionData(wid)
+                rep = (task_type, task_data, False, wid, (wid, we))
             put(rep)
             break
 
-        if task_type == TASK_JOB:
-            task_data = (task_data,)
-
-        for t in task_data:
+        # Job task... just do it.
+        elif task_type == TASK_JOB:
             try:
-                res = (TASK_JOB, True, wid, w.process(t))
+                res = (task_type, task_data, True, wid, w.process(task_data))
             except Exception as e:
-                if params.wrap_exception:
-                    e = multiprocessing.ExceptionWithTraceback(
-                            e, e.__traceback__)
-                res = (TASK_JOB, False, wid, e)
+                logger.debug(
+                    "Error processing job, sending exception to main process:")
+                logger.debug(traceback.format_exc())
+                we = WorkerExceptionData(wid)
+                res = (task_type, task_data, False, wid, we)
 
             put_start_time = time.perf_counter()
             put(res)
@@ -146,62 +157,28 @@
 
             completed += 1
 
-    if use_threads:
-        logger.debug("Worker %d waiting for reader/writer threads." % wid)
-        local_result_queue.put_nowait(None)
-        reader_thread.join()
-        writer_thread.join()
+        else:
+            raise Exception("Unknown task type: %s" % task_type)
 
     w.shutdown()
-
     logger.debug("Worker %d completed %d tasks." % (wid, completed))
 
 
-def _job_queue_reader(getter, out_queue):
-    while True:
-        try:
-            task = getter()
-        except (EOFError, OSError):
-            logger.debug("Worker encountered connection problem.")
-            break
-
-        out_queue.put_nowait(task)
-
-        if task[0] == TASK_END:
-            # Done reading jobs from the main process.
-            logger.debug("Got end task, exiting task queue reader thread.")
-            break
-
-
-def _job_results_writer(in_queue, putter):
-    while True:
-        res = in_queue.get()
-        if res is not None:
-            putter(res)
-            in_queue.task_done()
-        else:
-            # Got sentinel. Exit.
-            in_queue.task_done()
-            break
-    logger.debug("Exiting result queue writer thread.")
-
-
-class _WorkerParams(object):
+class _WorkerParams:
     def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
-                 wrap_exception=False, is_profiling=False):
+                 is_profiling=False):
         self.wid = wid
         self.inqueue = inqueue
         self.outqueue = outqueue
         self.worker_class = worker_class
         self.initargs = initargs
-        self.wrap_exception = wrap_exception
         self.is_profiling = is_profiling
 
 
-class WorkerPool(object):
-    def __init__(self, worker_class, initargs=(),
-                 worker_count=None, batch_size=None,
-                 wrap_exception=False):
+class WorkerPool:
+    def __init__(self, worker_class, initargs=(), *,
+                 callback=None, error_callback=None,
+                 worker_count=None, batch_size=None):
         worker_count = worker_count or os.cpu_count() or 1
 
         if use_fastqueue:
@@ -215,22 +192,22 @@
             self._quick_put = self._task_queue._writer.send
             self._quick_get = self._result_queue._reader.recv
 
+        self._callback = callback
+        self._error_callback = error_callback
         self._batch_size = batch_size
-        self._callback = None
-        self._error_callback = None
-        self._listener = None
+        self._jobs_left = 0
+        self._event = threading.Event()
 
         main_module = sys.modules['__main__']
         is_profiling = os.path.basename(main_module.__file__) in [
-                'profile.py', 'cProfile.py']
+            'profile.py', 'cProfile.py']
 
         self._pool = []
         for i in range(worker_count):
             worker_params = _WorkerParams(
-                    i, self._task_queue, self._result_queue,
-                    worker_class, initargs,
-                    wrap_exception=wrap_exception,
-                    is_profiling=is_profiling)
+                i, self._task_queue, self._result_queue,
+                worker_class, initargs,
+                is_profiling=is_profiling)
             w = multiprocessing.Process(target=worker_func,
                                         args=(worker_params,))
             w.name = w.name.replace('Process', 'PoolWorker')
@@ -239,66 +216,35 @@
             self._pool.append(w)
 
         self._result_handler = threading.Thread(
-                target=WorkerPool._handleResults,
-                args=(self,))
+            target=WorkerPool._handleResults,
+            args=(self,))
         self._result_handler.daemon = True
         self._result_handler.start()
 
         self._closed = False
 
-    def setHandler(self, callback=None, error_callback=None):
-        self._callback = callback
-        self._error_callback = error_callback
-
-    def queueJobs(self, jobs, handler=None, chunk_size=None):
+    def queueJobs(self, jobs):
         if self._closed:
             raise Exception("This worker pool has been closed.")
-        if self._listener is not None:
-            raise Exception("A previous job queue has not finished yet.")
 
-        if any([not p.is_alive() for p in self._pool]):
-            raise Exception("Some workers have prematurely exited.")
-
-        if handler is not None:
-            self.setHandler(handler)
-
-        if not hasattr(jobs, '__len__'):
-            jobs = list(jobs)
-        job_count = len(jobs)
-
-        res = AsyncResult(self, job_count)
-        if res._count == 0:
-            res._event.set()
-            return res
+        for job in jobs:
+            self._jobs_left += 1
+            self._quick_put((TASK_JOB, job))
 
-        self._listener = res
-
-        if chunk_size is None:
-            chunk_size = self._batch_size
-        if chunk_size is None:
-            chunk_size = max(1, job_count // 50)
-            logger.debug("Using chunk size of %d" % chunk_size)
+        if self._jobs_left > 0:
+            self._event.clear()
 
-        if chunk_size is None or chunk_size == 1:
-            for job in jobs:
-                self._quick_put((TASK_JOB, job))
-        else:
-            it = iter(jobs)
-            while True:
-                batch = tuple([i for i in itertools.islice(it, chunk_size)])
-                if not batch:
-                    break
-                self._quick_put((TASK_BATCH, batch))
-
-        return res
+    def wait(self, timeout=None):
+        return self._event.wait(timeout)
 
     def close(self):
-        if self._listener is not None:
+        if self._jobs_left > 0 or not self._event.is_set():
             raise Exception("A previous job queue has not finished yet.")
 
         logger.debug("Closing worker pool...")
         handler = _ReportHandler(len(self._pool))
         self._callback = handler._handle
+        self._error_callback = handler._handleError
         for w in self._pool:
             self._quick_put((TASK_END, None))
         for w in self._pool:
@@ -308,8 +254,8 @@
         if not handler.wait(2):
             missing = handler.reports.index(None)
             logger.warning(
-                    "Didn't receive all worker reports before timeout. "
-                    "Missing report from worker %d." % missing)
+                "Didn't receive all worker reports before timeout. "
+                "Missing report from worker %d." % missing)
 
         logger.debug("Exiting result handler thread...")
         self._result_queue.put(None)
@@ -318,6 +264,11 @@
 
         return handler.reports
 
+    def _onTaskDone(self):
+        self._jobs_left -= 1
+        if self._jobs_left == 0:
+            self._event.set()
+
     @staticmethod
     def _handleResults(pool):
         while True:
@@ -332,44 +283,26 @@
                 logger.debug("Result handler exiting.")
                 break
 
-            task_type, success, wid, data = res
+            task_type, task_data, success, wid, data = res
             try:
-                if success and pool._callback:
-                    pool._callback(data)
-                elif not success:
+                if success:
+                    if pool._callback:
+                        pool._callback(task_data, data)
+                else:
                     if pool._error_callback:
-                        pool._error_callback(data)
+                        pool._error_callback(task_data, data)
                     else:
-                        logger.error("Got error data:")
+                        logger.error(
+                            "Worker %d failed to process a job:" % wid)
                         logger.error(data)
             except Exception as ex:
                 logger.exception(ex)
 
             if task_type == TASK_JOB:
-                pool._listener._onTaskDone()
+                pool._onTaskDone()
 
 
-class AsyncResult(object):
-    def __init__(self, pool, count):
-        self._pool = pool
-        self._count = count
-        self._event = threading.Event()
-
-    def ready(self):
-        return self._event.is_set()
-
-    def wait(self, timeout=None):
-        return self._event.wait(timeout)
-
-    def _onTaskDone(self):
-        self._count -= 1
-        if self._count == 0:
-            self._pool.setHandler(None)
-            self._pool._listener = None
-            self._event.set()
-
-
-class _ReportHandler(object):
+class _ReportHandler:
     def __init__(self, worker_count):
         self.reports = [None] * worker_count
         self._count = worker_count
@@ -379,7 +312,7 @@
     def wait(self, timeout=None):
         return self._event.wait(timeout)
 
-    def _handle(self, res):
+    def _handle(self, job, res):
         wid, data = res
         if wid < 0 or wid > self._count:
             logger.error("Ignoring report from unknown worker %d." % wid)
@@ -391,13 +324,12 @@
         if self._received == self._count:
             self._event.set()
 
-    def _handleError(self, res):
-        wid, data = res
-        logger.error("Worker %d failed to send its report." % wid)
-        logger.exception(data)
+    def _handleError(self, job, res):
+        logger.error("Worker %d failed to send its report." % res.wid)
+        logger.error(res)
 
 
-class FastQueue(object):
+class FastQueue:
     def __init__(self):
         self._reader, self._writer = multiprocessing.Pipe(duplex=False)
         self._rlock = multiprocessing.Lock()
@@ -429,11 +361,11 @@
                 self._rbuf.write(e.args[0])
 
         self._rbuf.seek(0)
-        return self._unpickle(self._rbuf, bufsize)
+        return _unpickle(self._rbuf, bufsize)
 
     def put(self, obj):
         self._wbuf.seek(0)
-        self._pickle(obj, self._wbuf)
+        _pickle(obj, self._wbuf)
         size = self._wbuf.tell()
 
         self._wbuf.seek(0)
@@ -441,9 +373,27 @@
             with self._wbuf.getbuffer() as b:
                 self._writer.send_bytes(b, 0, size)
 
-    def _pickle(self, obj, buf):
-        fastpickle.pickle_intob(obj, buf)
+
+def _pickle_fast(obj, buf):
+    fastpickle.pickle_intob(obj, buf)
+
+
+def _unpickle_fast(buf, bufsize):
+    return fastpickle.unpickle_fromb(buf, bufsize)
+
+
+def _pickle_default(obj, buf):
+    pickle.dump(obj, buf)
 
-    def _unpickle(self, buf, bufsize):
-        return fastpickle.unpickle_fromb(buf, bufsize)
+
+def _unpickle_default(buf, bufsize):
+    return pickle.load(buf)
+
 
+if use_fastpickle:
+    _pickle = _pickle_fast
+    _unpickle = _unpickle_fast
+else:
+    _pickle = _pickle_default
+    _unpickle = _unpickle_default
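
For reference, a minimal sketch (not part of this changeset) of the pickle-into-a-buffer round trip that the module-level helpers are selected for; the dictionary payload is made up, and the default (non-fastpickle) pair is used.

import io
import pickle

def _pickle_default(obj, buf):
    pickle.dump(obj, buf)

def _unpickle_default(buf, bufsize):
    return pickle.load(buf)

# Serialize into a reusable buffer, note the size, then rewind and read back,
# mirroring how FastQueue.put() measures the pickled size before send_bytes().
buf = io.BytesIO()
_pickle_default({'wid': 0, 'job': 'example'}, buf)
size = buf.tell()

buf.seek(0)
assert _unpickle_default(buf, size) == {'wid': 0, 'job': 'example'}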
+