changeset 971:5485a11591ec

internal: Worker pool improvements

- Slightly more correct multi-processing code.
- Handle critical worker failures.
- Correctly set up workers when run from `py.test`.
- Add support for using `msgpack` or `marshall` instead of `pickle`.
author Ludovic Chabant <ludovic@chabant.com>
date Tue, 17 Oct 2017 01:00:55 -0700
parents 660250c95246
children bbf5a96b56db
files piecrust/workerpool.py
diffstat 1 files changed, 174 insertions(+), 62 deletions(-) [+]
--- a/piecrust/workerpool.py	Mon Oct 09 21:07:00 2017 -0700
+++ b/piecrust/workerpool.py	Tue Oct 17 01:00:55 2017 -0700
@@ -2,19 +2,20 @@
 import os
 import sys
 import time
-import pickle
 import logging
 import threading
 import traceback
 import multiprocessing
-from piecrust import fastpickle
 from piecrust.environment import ExecutionStats
 
 
 logger = logging.getLogger(__name__)
 
 use_fastqueue = False
+
 use_fastpickle = False
+use_msgpack = False
+use_marshall = False
 
 
 class IWorker(object):
@@ -57,6 +58,8 @@
 
 TASK_JOB = 0
 TASK_END = 1
+_TASK_ABORT_WORKER = 10
+_CRITICAL_WORKER_ERROR = 11
 
 
 def worker_func(params):
@@ -76,16 +79,46 @@
 
 
 def _real_worker_func(params):
+    try:
+        _real_worker_func_unsafe(params)
+    except Exception as ex:
+        logger.exception(ex)
+        msg = ("CRITICAL ERROR IN WORKER %d\n%s" % (params.wid, str(ex)))
+        params.outqueue.put((
+            _CRITICAL_WORKER_ERROR, None, False, params.wid, msg))
+
+
+def _real_worker_func_unsafe(params):
     wid = params.wid
 
     stats = ExecutionStats()
     stats.registerTimer('WorkerInit')
     init_start_time = time.perf_counter()
 
+    # If we are unit-testing, we haven't set up the logging environment
+    # yet, since the executable is `py.test`. We need to translate our
+    # test logging arguments into something Chef can understand.
+    if params.is_unit_testing:
+        import argparse
+        parser = argparse.ArgumentParser()
+        # This is adapted from our `conftest.py`.
+        parser.add_argument('--log-debug', action='store_true')
+        parser.add_argument('--log-file')
+        res, _ = parser.parse_known_args(sys.argv[1:])
+
+        chef_args = []
+        if res.log_debug:
+            chef_args.append('--debug')
+        if res.log_file:
+            chef_args += ['--log', res.log_file]
+
+        from piecrust.main import _pre_parse_chef_args
+        _pre_parse_chef_args(chef_args)
+
     # In a context where `multiprocessing` is using the `spawn` forking model,
     # the new process doesn't inherit anything, so we lost all our logging
     # configuration here. Let's set it up again.
-    if (hasattr(multiprocessing, 'get_start_method') and
+    elif (hasattr(multiprocessing, 'get_start_method') and
             multiprocessing.get_start_method() == 'spawn'):
         from piecrust.main import _pre_parse_chef_args
         _pre_parse_chef_args(sys.argv[1:])
@@ -107,10 +140,9 @@
     try:
         w.initialize()
     except Exception as ex:
-        logger.error("Working failed to initialize:")
+        logger.error("Worker %d failed to initialize." % wid)
         logger.exception(ex)
-        params.outqueue.put(None)
-        return
+        raise
 
     stats.stepTimerSince('WorkerInit', init_start_time)
 
@@ -128,8 +160,25 @@
 
         task_type, task_data = task
 
+        # Job task... just do it.
+        if task_type == TASK_JOB:
+            try:
+                res = (task_type, task_data, True, wid, w.process(task_data))
+            except Exception as e:
+                logger.debug(
+                    "Error processing job, sending exception to main process:")
+                logger.debug(traceback.format_exc())
+                we = WorkerExceptionData(wid)
+                res = (task_type, task_data, False, wid, we)
+
+            put_start_time = time.perf_counter()
+            put(res)
+            time_in_put += (time.perf_counter() - put_start_time)
+
+            completed += 1
+
         # End task... gather stats to send back to the main process.
-        if task_type == TASK_END:
+        elif task_type == TASK_END:
             logger.debug("Worker %d got end task, exiting." % wid)
             stats.registerTimer('WorkerTaskGet', time=time_in_get)
             stats.registerTimer('WorkerResultPut', time=time_in_put)
@@ -145,39 +194,34 @@
             put(rep)
             break
 
-        # Job task... just do it.
-        elif task_type == TASK_JOB:
-            try:
-                res = (task_type, task_data, True, wid, w.process(task_data))
-            except Exception as e:
-                logger.debug(
-                    "Error processing job, sending exception to main process:")
-                logger.debug(traceback.format_exc())
-                we = WorkerExceptionData(wid)
-                res = (task_type, task_data, False, wid, we)
-
-            put_start_time = time.perf_counter()
-            put(res)
-            time_in_put += (time.perf_counter() - put_start_time)
-
-            completed += 1
+        # Emergency abort.
+        elif task_type == _TASK_ABORT_WORKER:
+            logger.debug("Worker %d got abort signal." % wid)
+            break
 
         else:
             raise Exception("Unknown task type: %s" % task_type)
 
-    w.shutdown()
+    try:
+        w.shutdown()
+    except Exception as e:
+        logger.error("Worker %s failed to shutdown.")
+        logger.exception(e)
+        raise
+
     logger.debug("Worker %d completed %d tasks." % (wid, completed))
 
 
 class _WorkerParams:
     def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
-                 is_profiling=False):
+                 is_profiling=False, is_unit_testing=False):
         self.wid = wid
         self.inqueue = inqueue
         self.outqueue = outqueue
         self.worker_class = worker_class
         self.initargs = initargs
         self.is_profiling = is_profiling
+        self.is_unit_testing = is_unit_testing
 
 
 class WorkerPool:
@@ -204,18 +248,25 @@
         self._error_callback = error_callback
         self._batch_size = batch_size
         self._jobs_left = 0
+        self._lock_jobs_left = threading.Lock()
+        self._lock_workers = threading.Lock()
         self._event = threading.Event()
+        self._error_on_join = None
+        self._closed = False
 
         main_module = sys.modules['__main__']
         is_profiling = os.path.basename(main_module.__file__) in [
             'profile.py', 'cProfile.py']
+        is_unit_testing = os.path.basename(main_module.__file__) in [
+            'py.test']
 
         self._pool = []
         for i in range(worker_count):
             worker_params = _WorkerParams(
                 i, self._task_queue, self._result_queue,
                 worker_class, initargs,
-                is_profiling=is_profiling)
+                is_profiling=is_profiling,
+                is_unit_testing=is_unit_testing)
             w = multiprocessing.Process(target=worker_func,
                                         args=(worker_params,))
             w.name = w.name.replace('Process', 'PoolWorker')
@@ -229,33 +280,45 @@
         self._result_handler.daemon = True
         self._result_handler.start()
 
-        self._closed = False
+    def queueJobs(self, jobs):
+        if self._closed:
+            if self._error_on_join:
+                raise self._error_on_join
+            raise Exception("This worker pool has been closed.")
 
-    def queueJobs(self, jobs):
+        jobs = list(jobs)
+        new_job_count = len(jobs)
+        if new_job_count > 0:
+            with self._lock_jobs_left:
+                self._jobs_left += new_job_count
+
+            self._event.clear()
+            for job in jobs:
+                self._quick_put((TASK_JOB, job))
+
+    def wait(self, timeout=None):
         if self._closed:
             raise Exception("This worker pool has been closed.")
 
-        for job in jobs:
-            self._jobs_left += 1
-            self._quick_put((TASK_JOB, job))
-
-        if self._jobs_left > 0:
-            self._event.clear()
-
-    def wait(self, timeout=None):
-        return self._event.wait(timeout)
+        ret = self._event.wait(timeout)
+        if self._error_on_join:
+            raise self._error_on_join
+        return ret
 
     def close(self):
+        if self._closed:
+            raise Exception("This worker pool has been closed.")
         if self._jobs_left > 0 or not self._event.is_set():
             raise Exception("A previous job queue has not finished yet.")
 
         logger.debug("Closing worker pool...")
-        handler = _ReportHandler(len(self._pool))
+        live_workers = list(filter(lambda w: w is not None, self._pool))
+        handler = _ReportHandler(len(live_workers))
         self._callback = handler._handle
         self._error_callback = handler._handleError
-        for w in self._pool:
+        for w in live_workers:
             self._quick_put((TASK_END, None))
-        for w in self._pool:
+        for w in live_workers:
             w.join()
 
         logger.debug("Waiting for reports...")
@@ -272,9 +335,26 @@
 
         return handler.reports
 
+    def _onResultHandlerCriticalError(self, wid):
+        logger.error("Result handler received a critical error from "
+                     "worker %d." % wid)
+        with self._lock_workers:
+            self._pool[wid] = None
+            if all(map(lambda w: w is None, self._pool)):
+                logger.error("All workers have died!")
+                self._closed = True
+                self._error_on_join = Exception("All workers have died!")
+                self._event.set()
+                return False
+
+        return True
+
     def _onTaskDone(self):
-        self._jobs_left -= 1
-        if self._jobs_left == 0:
+        with self._lock_jobs_left:
+            left = self._jobs_left - 1
+            self._jobs_left = left
+
+        if left == 0:
             self._event.set()
 
     @staticmethod
@@ -290,7 +370,7 @@
 
             if res is None:
                 logger.debug("Result handler exiting.")
-                break
+                return
 
             task_type, task_data, success, wid, data = res
             try:
@@ -298,12 +378,19 @@
                     if pool._callback:
                         pool._callback(task_data, data, userdata)
                 else:
-                    if pool._error_callback:
-                        pool._error_callback(task_data, data, userdata)
+                    if task_type == _CRITICAL_WORKER_ERROR:
+                        logger.error(data)
+                        do_continue = pool._onResultHandlerCriticalError(wid)
+                        if not do_continue:
+                            logger.debug("Aborting result handling thread.")
+                            return
                     else:
-                        logger.error(
-                            "Worker %d failed to process a job:" % wid)
-                        logger.error(data)
+                        if pool._error_callback:
+                            pool._error_callback(task_data, data, userdata)
+                        else:
+                            logger.error(
+                                "Worker %d failed to process a job:" % wid)
+                            logger.error(data)
             except Exception as ex:
                 logger.exception(ex)
 
@@ -383,26 +470,51 @@
                 self._writer.send_bytes(b, 0, size)
 
 
-def _pickle_fast(obj, buf):
-    fastpickle.pickle_intob(obj, buf)
-
-
-def _unpickle_fast(buf, bufsize):
-    return fastpickle.unpickle_fromb(buf, bufsize)
-
+if use_fastpickle:
+    from piecrust import fastpickle
 
-def _pickle_default(obj, buf):
-    pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)
-
+    def _pickle_fast(obj, buf):
+        fastpickle.pickle_intob(obj, buf)
 
-def _unpickle_default(buf, bufsize):
-    return pickle.load(buf)
+    def _unpickle_fast(buf, bufsize):
+        return fastpickle.unpickle_fromb(buf, bufsize)
 
-
-if use_fastpickle:
     _pickle = _pickle_fast
     _unpickle = _unpickle_fast
+
+elif use_msgpack:
+    import msgpack
+
+    def _pickle_msgpack(obj, buf):
+        msgpack.pack(obj, buf)
+
+    def _unpickle_msgpack(buf, bufsize):
+        return msgpack.unpack(buf)
+
+    _pickle = _pickle_msgpack
+    _unpickle = _unpickle_msgpack
+
+elif use_marshall:
+    import marshal
+
+    def _pickle_marshal(obj, buf):
+        marshal.dump(obj, buf)
+
+    def _unpickle_marshal(buf, bufsize):
+        return marshal.load(buf)
+
+    _pickle = _pickle_marshal
+    _unpickle = _unpickle_marshal
+
 else:
+    import pickle
+
+    def _pickle_default(obj, buf):
+        pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)
+
+    def _unpickle_default(buf, bufsize):
+        return pickle.load(buf)
+
     _pickle = _pickle_default
     _unpickle = _unpickle_default
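
A rough usage sketch (not part of the changeset): whichever serialization backend gets selected at import time via the module-level flags (`use_fastpickle`, `use_msgpack`, `use_marshall`, or the `pickle` default) must expose the same two-function interface, `_pickle(obj, buf)` to write an object into a binary buffer and `_unpickle(buf, bufsize)` to read it back. The snippet below round-trips an illustrative job payload through the default `pickle` backend using an `io.BytesIO` buffer; the payload contents and variable names are assumptions for demonstration only.

    import io
    import pickle

    def _pickle_default(obj, buf):
        # Write `obj` into the writable binary buffer `buf`.
        pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)

    def _unpickle_default(buf, bufsize):
        # Read one object back from the readable binary buffer `buf`.
        return pickle.load(buf)

    # Hypothetical job payload, stand-in for whatever the pool sends to workers.
    job = {'path': 'pages/foo.md', 'pass': 1}

    buf = io.BytesIO()
    _pickle_default(job, buf)
    size = buf.tell()

    buf.seek(0)
    assert _unpickle_default(buf, size) == job

The same round-trip shape applies to the `msgpack` and `marshal` branches, which is why the pool only ever calls the `_pickle`/`_unpickle` aliases rather than a specific backend.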