piecrust2 changeset 971:5485a11591ec
internal: Worker pool improvements
- Slightly more correct multi-processing code.
- Handle critical worker failures.
- Correctly setup workers when run from `py.test`.
- Add support for using `msgpack` or `marshall` instead of `pickle`.
author       Ludovic Chabant <ludovic@chabant.com>
date         Tue, 17 Oct 2017 01:00:55 -0700
parents      660250c95246
children     bbf5a96b56db
files        piecrust/workerpool.py
diffstat     1 files changed, 174 insertions(+), 62 deletions(-)
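
Note: the "Handle critical worker failures" bullet comes down to a small protocol between the worker processes and the result-handler thread in the parent: the worker body is wrapped in a catch-all, a failure is reported as a special message on the result queue, and the pool stores the error so that `wait()` re-raises it instead of blocking forever. Below is a minimal, self-contained sketch of that pattern under those assumptions; the names (`CRITICAL_ERROR`, `run_worker`, `Pool`) are illustrative and are not the piecrust API, which appears in the diff that follows.

    # Minimal sketch of the "critical worker failure" pattern: a worker that
    # dies reports the failure instead of vanishing, and the parent surfaces
    # the stored error from wait(). Names are illustrative only.
    import multiprocessing
    import threading

    CRITICAL_ERROR = 11


    def run_worker(wid, outqueue):
        try:
            raise RuntimeError("worker %d blew up during initialization" % wid)
        except Exception as ex:
            # Report the failure; the parent's result-handler thread will
            # see this tuple and mark the worker as dead.
            outqueue.put((CRITICAL_ERROR, wid, str(ex)))


    class Pool:
        def __init__(self, worker_count=2):
            self._results = multiprocessing.Queue()
            self._event = threading.Event()
            self._error_on_join = None
            self._procs = [
                multiprocessing.Process(target=run_worker,
                                        args=(i, self._results))
                for i in range(worker_count)]
            for p in self._procs:
                p.start()
            threading.Thread(target=self._handle_results, daemon=True).start()

        def _handle_results(self):
            dead = 0
            while dead < len(self._procs):
                msg_type, wid, msg = self._results.get()
                if msg_type == CRITICAL_ERROR:
                    dead += 1
            # All workers died: remember the error and wake up any waiter.
            self._error_on_join = Exception("All workers have died!")
            self._event.set()

        def wait(self, timeout=None):
            ret = self._event.wait(timeout)
            if self._error_on_join:
                raise self._error_on_join
            return ret


    if __name__ == '__main__':
        pool = Pool()
        try:
            pool.wait(timeout=5)
        except Exception as ex:
            print("pool failed:", ex)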
--- a/piecrust/workerpool.py	Mon Oct 09 21:07:00 2017 -0700
+++ b/piecrust/workerpool.py	Tue Oct 17 01:00:55 2017 -0700
@@ -2,19 +2,20 @@
 import os
 import sys
 import time
-import pickle
 import logging
 import threading
 import traceback
 import multiprocessing

-from piecrust import fastpickle
 from piecrust.environment import ExecutionStats


 logger = logging.getLogger(__name__)

 use_fastqueue = False
+
 use_fastpickle = False
+use_msgpack = False
+use_marshall = False


 class IWorker(object):
@@ -57,6 +58,8 @@

 TASK_JOB = 0
 TASK_END = 1
+_TASK_ABORT_WORKER = 10
+_CRITICAL_WORKER_ERROR = 11


 def worker_func(params):
@@ -76,16 +79,46 @@


 def _real_worker_func(params):
+    try:
+        _real_worker_func_unsafe(params)
+    except Exception as ex:
+        logger.exception(ex)
+        msg = ("CRITICAL ERROR IN WORKER %d\n%s" % (params.wid, str(ex)))
+        params.outqueue.put((
+            _CRITICAL_WORKER_ERROR, None, False, params.wid, msg))
+
+
+def _real_worker_func_unsafe(params):
     wid = params.wid

     stats = ExecutionStats()
     stats.registerTimer('WorkerInit')
     init_start_time = time.perf_counter()

+    # If we are unit-testing, we didn't setup all the logging environment
+    # yet, since the executable is `py.test`. We need to translate our
+    # test logging arguments into something Chef can understand.
+    if params.is_unit_testing:
+        import argparse
+        parser = argparse.ArgumentParser()
+        # This is adapted from our `conftest.py`.
+        parser.add_argument('--log-debug', action='store_true')
+        parser.add_argument('--log-file')
+        res, _ = parser.parse_known_args(sys.argv[1:])
+
+        chef_args = []
+        if res.log_debug:
+            chef_args.append('--debug')
+        if res.log_file:
+            chef_args += ['--log', res.log_file]
+
+        from piecrust.main import _pre_parse_chef_args
+        _pre_parse_chef_args(chef_args)
+
     # In a context where `multiprocessing` is using the `spawn` forking model,
     # the new process doesn't inherit anything, so we lost all our logging
     # configuration here. Let's set it up again.
-    if (hasattr(multiprocessing, 'get_start_method') and
+    elif (hasattr(multiprocessing, 'get_start_method') and
             multiprocessing.get_start_method() == 'spawn'):
         from piecrust.main import _pre_parse_chef_args
         _pre_parse_chef_args(sys.argv[1:])
@@ -107,10 +140,9 @@
     try:
         w.initialize()
     except Exception as ex:
-        logger.error("Working failed to initialize:")
+        logger.error("Worker %d failed to initialize." % wid)
         logger.exception(ex)
-        params.outqueue.put(None)
-        return
+        raise

     stats.stepTimerSince('WorkerInit', init_start_time)

@@ -128,8 +160,25 @@

         task_type, task_data = task

+        # Job task... just do it.
+        if task_type == TASK_JOB:
+            try:
+                res = (task_type, task_data, True, wid, w.process(task_data))
+            except Exception as e:
+                logger.debug(
+                    "Error processing job, sending exception to main process:")
+                logger.debug(traceback.format_exc())
+                we = WorkerExceptionData(wid)
+                res = (task_type, task_data, False, wid, we)
+
+            put_start_time = time.perf_counter()
+            put(res)
+            time_in_put += (time.perf_counter() - put_start_time)
+
+            completed += 1
+
         # End task... gather stats to send back to the main process.
-        if task_type == TASK_END:
+        elif task_type == TASK_END:
             logger.debug("Worker %d got end task, exiting." % wid)
             stats.registerTimer('WorkerTaskGet', time=time_in_get)
             stats.registerTimer('WorkerResultPut', time=time_in_put)
@@ -145,39 +194,34 @@
             put(rep)
             break

-        # Job task... just do it.
-        elif task_type == TASK_JOB:
-            try:
-                res = (task_type, task_data, True, wid, w.process(task_data))
-            except Exception as e:
-                logger.debug(
-                    "Error processing job, sending exception to main process:")
-                logger.debug(traceback.format_exc())
-                we = WorkerExceptionData(wid)
-                res = (task_type, task_data, False, wid, we)
-
-            put_start_time = time.perf_counter()
-            put(res)
-            time_in_put += (time.perf_counter() - put_start_time)
-
-            completed += 1
+        # Emergy abort.
+        elif task_type == _TASK_ABORT_WORKER:
+            logger.debug("Worker %d got abort signal." % wid)
+            break

         else:
             raise Exception("Unknown task type: %s" % task_type)

-    w.shutdown()
+    try:
+        w.shutdown()
+    except Exception as e:
+        logger.error("Worker %s failed to shutdown.")
+        logger.exception(e)
+        raise
+
     logger.debug("Worker %d completed %d tasks." % (wid, completed))


 class _WorkerParams:
     def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
-                 is_profiling=False):
+                 is_profiling=False, is_unit_testing=False):
         self.wid = wid
         self.inqueue = inqueue
         self.outqueue = outqueue
         self.worker_class = worker_class
         self.initargs = initargs
         self.is_profiling = is_profiling
+        self.is_unit_testing = is_unit_testing


 class WorkerPool:
@@ -204,18 +248,25 @@
         self._error_callback = error_callback
         self._batch_size = batch_size
         self._jobs_left = 0
+        self._lock_jobs_left = threading.Lock()
+        self._lock_workers = threading.Lock()
         self._event = threading.Event()
+        self._error_on_join = None
+        self._closed = False

         main_module = sys.modules['__main__']
         is_profiling = os.path.basename(main_module.__file__) in [
             'profile.py', 'cProfile.py']
+        is_unit_testing = os.path.basename(main_module.__file__) in [
+            'py.test']

         self._pool = []
         for i in range(worker_count):
             worker_params = _WorkerParams(
                 i, self._task_queue, self._result_queue,
                 worker_class, initargs,
-                is_profiling=is_profiling)
+                is_profiling=is_profiling,
+                is_unit_testing=is_unit_testing)
             w = multiprocessing.Process(target=worker_func,
                                         args=(worker_params,))
             w.name = w.name.replace('Process', 'PoolWorker')
@@ -229,33 +280,45 @@
         self._result_handler.daemon = True
         self._result_handler.start()

-        self._closed = False
+    def queueJobs(self, jobs):
+        if self._closed:
+            if self._error_on_join:
+                raise self._error_on_join
+            raise Exception("This worker pool has been closed.")

-    def queueJobs(self, jobs):
+        jobs = list(jobs)
+        new_job_count = len(jobs)
+        if new_job_count > 0:
+            with self._lock_jobs_left:
+                self._jobs_left += new_job_count
+
+            self._event.clear()
+            for job in jobs:
+                self._quick_put((TASK_JOB, job))
+
+    def wait(self, timeout=None):
         if self._closed:
             raise Exception("This worker pool has been closed.")

-        for job in jobs:
-            self._jobs_left += 1
-            self._quick_put((TASK_JOB, job))
-
-        if self._jobs_left > 0:
-            self._event.clear()
-
-    def wait(self, timeout=None):
-        return self._event.wait(timeout)
+        ret = self._event.wait(timeout)
+        if self._error_on_join:
+            raise self._error_on_join
+        return ret

     def close(self):
+        if self._closed:
+            raise Exception("This worker pool has been closed.")
         if self._jobs_left > 0 or not self._event.is_set():
             raise Exception("A previous job queue has not finished yet.")

         logger.debug("Closing worker pool...")
-        handler = _ReportHandler(len(self._pool))
+        live_workers = list(filter(lambda w: w is not None, self._pool))
+        handler = _ReportHandler(len(live_workers))
         self._callback = handler._handle
         self._error_callback = handler._handleError
-        for w in self._pool:
+        for w in live_workers:
             self._quick_put((TASK_END, None))
-        for w in self._pool:
+        for w in live_workers:
             w.join()

         logger.debug("Waiting for reports...")
@@ -272,9 +335,26 @@

         return handler.reports

+    def _onResultHandlerCriticalError(self, wid):
+        logger.error("Result handler received a critical error from "
+                     "worker %d." % wid)
+        with self._lock_workers:
+            self._pool[wid] = None
+            if all(map(lambda w: w is None, self._pool)):
+                logger.error("All workers have died!")
+                self._closed = True
+                self._error_on_join = Exception("All workers have died!")
+                self._event.set()
+                return False
+
+        return True
+
     def _onTaskDone(self):
-        self._jobs_left -= 1
-        if self._jobs_left == 0:
+        with self._lock_jobs_left:
+            left = self._jobs_left - 1
+            self._jobs_left = left
+
+        if left == 0:
             self._event.set()

     @staticmethod
@@ -290,7 +370,7 @@

             if res is None:
                 logger.debug("Result handler exiting.")
-                break
+                return

             task_type, task_data, success, wid, data = res
             try:
@@ -298,12 +378,19 @@
                     if pool._callback:
                         pool._callback(task_data, data, userdata)
                 else:
-                    if pool._error_callback:
-                        pool._error_callback(task_data, data, userdata)
+                    if task_type == _CRITICAL_WORKER_ERROR:
+                        logger.error(data)
+                        do_continue = pool._onResultHandlerCriticalError(wid)
+                        if not do_continue:
+                            logger.debug("Aborting result handling thread.")
+                            return
                     else:
-                        logger.error(
-                            "Worker %d failed to process a job:" % wid)
-                        logger.error(data)
+                        if pool._error_callback:
+                            pool._error_callback(task_data, data, userdata)
+                        else:
+                            logger.error(
+                                "Worker %d failed to process a job:" % wid)
+                            logger.error(data)
             except Exception as ex:
                 logger.exception(ex)
@@ -383,26 +470,51 @@
             self._writer.send_bytes(b, 0, size)


-def _pickle_fast(obj, buf):
-    fastpickle.pickle_intob(obj, buf)
-
-
-def _unpickle_fast(buf, bufsize):
-    return fastpickle.unpickle_fromb(buf, bufsize)
-
+if use_fastpickle:
+    from piecrust import fastpickle

-def _pickle_default(obj, buf):
-    pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)
-
+    def _pickle_fast(obj, buf):
+        fastpickle.pickle_intob(obj, buf)

-def _unpickle_default(buf, bufsize):
-    return pickle.load(buf)
+    def _unpickle_fast(buf, bufsize):
+        return fastpickle.unpickle_fromb(buf, bufsize)

-
-if use_fastpickle:
     _pickle = _pickle_fast
     _unpickle = _unpickle_fast
+
+elif use_msgpack:
+    import msgpack
+
+    def _pickle_msgpack(obj, buf):
+        msgpack.pack(obj, buf)
+
+    def _unpickle_msgpack(buf, bufsize):
+        return msgpack.unpack(buf)
+
+    _pickle = _pickle_msgpack
+    _unpickle = _unpickle_msgpack
+
+elif use_marshall:
+    import marshal
+
+    def _pickle_marshal(obj, buf):
+        marshal.dump(obj, buf)
+
+    def _unpickle_marshal(buf, bufsize):
+        return marshal.load(buf)
+
+    _pickle = _pickle_marshal
+    _unpickle = _unpickle_marshal
+
 else:
+    import pickle
+
+    def _pickle_default(obj, buf):
+        pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)
+
+    def _unpickle_default(buf, bufsize):
+        return pickle.load(buf)
+
     _pickle = _pickle_default
     _unpickle = _unpickle_default
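
Note: for the last bullet of the commit message, whichever backend is selected is exposed through the same two module-level hooks, `_pickle(obj, buf)` and `_unpickle(buf, bufsize)`, which write to and read from a file-like buffer. The standalone sketch below mirrors that selection pattern; the flag names match the diff, everything else (including the demo payload) is illustrative. A likely reason `pickle` stays the default is that `marshal`, and `msgpack` without custom type hooks, only round-trip plain built-in values, while job payloads may contain arbitrary objects.

    # Sketch of the pluggable serializer selection added by this changeset.
    # The flags mirror the module-level ones in piecrust/workerpool.py; the
    # payload in the demo is made up.
    import io

    use_msgpack = False
    use_marshall = False

    if use_msgpack:
        import msgpack

        def _pickle(obj, buf):
            msgpack.pack(obj, buf)

        def _unpickle(buf, bufsize):
            return msgpack.unpack(buf)

    elif use_marshall:
        import marshal

        def _pickle(obj, buf):
            marshal.dump(obj, buf)

        def _unpickle(buf, bufsize):
            return marshal.load(buf)

    else:
        import pickle

        def _pickle(obj, buf):
            pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)

        def _unpickle(buf, bufsize):
            return pickle.load(buf)


    if __name__ == '__main__':
        # Round-trip a fake job payload through whichever backend is active.
        buf = io.BytesIO()
        _pickle({'path': '/posts/foo.md', 'pass_num': 1}, buf)
        buf.seek(0)
        print(_unpickle(buf, buf.getbuffer().nbytes))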