view piecrust/workerpool.py @ 991:1857dbd4580f

bake: Fix bugs introduced by bake optimizations, of course. - Make the execution stats JSON-serializable. - Re-add ability to differentiate between sources used during segment rendering and during layout rendering. Fixes problems with cache invalidation of pages that use other sources. - Make taxonomy-related stuff JSON-serializable.
author Ludovic Chabant <ludovic@chabant.com>
date Mon, 20 Nov 2017 23:06:47 -0800
parents 8adc27285d93
children 09c3d415d9e5
line wrap: on
line source

import io
import os
import sys
import time
import logging
import threading
import traceback
import multiprocessing
from piecrust.environment import ExecutionStats


logger = logging.getLogger(__name__)

# Transport/serialization switches, all off by default. With everything off,
# the pool uses `multiprocessing.SimpleQueue` and the serializer backend is
# plain `pickle`. The serializer switches are read once, at import time.
use_fastqueue = use_fastpickle = use_msgpack = use_marshall = use_json = False


class IWorker:
    """ Interface for a pool worker.

    The pool instantiates the worker class inside each worker process and
    assigns its `wid` before calling `initialize()`.
    """
    def initialize(self):
        """Prepare the worker; must be overridden."""
        raise NotImplementedError()

    def process(self, job):
        """Handle one job and return its result; must be overridden."""
        raise NotImplementedError()

    def getStats(self):
        """Return stats to merge into the worker's report (none by default)."""
        return None

    def shutdown(self):
        """Release resources when the worker exits (no-op by default)."""


def _get_worker_exception_data(wid):
    """Describe the exception currently being handled as a plain dict.

    Must be called from inside an `except` block, since it reads
    `sys.exc_info()`. The dict contains only strings and the worker id,
    so it can be sent back to the main process.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    type_name = str(exc_type)
    messages = '\n'.join(_get_errors(exc_value))
    formatted = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb))
    return dict(wid=wid, type=type_name, value=messages, traceback=formatted)


def _get_errors(ex):
    errors = []
    while ex is not None:
        msg = str(ex)
        errors.append(msg)
        ex = ex.__cause__
    return errors


# Task types: the first element of each message put on the pool's queues.
TASK_JOB, TASK_JOB_BATCH, TASK_END = 0, 1, 2
# Private task/result types used only inside this module.
_TASK_ABORT_WORKER, _CRITICAL_WORKER_ERROR = 10, 11


def worker_func(params):
    """Entry point of a pool worker process.

    When the parent was started under `profile`/`cProfile`, the worker's
    run is profiled too and saved to `<WorkerClassName>-<wid>.prof`.
    """
    if not params.is_profiling:
        _real_worker_func(params)
        return

    try:
        import cProfile as profile
    except ImportError:
        import profile

    # Clear the flag before the profiled run starts.
    # NOTE(review): presumably so the profiled code path doesn't try to
    # set up profiling again — confirm.
    params.is_profiling = False
    class_name = params.worker_class.__name__
    out_path = '%s-%d.prof' % (class_name, params.wid)
    profile.runctx('_real_worker_func(params)',
                   globals(), locals(),
                   filename=out_path)


def _real_worker_func(params):
    """Run the worker loop and report any fatal error to the main process.

    If the worker loop raises, the exception is logged and a
    `_CRITICAL_WORKER_ERROR` message is put on `params.outqueue`, so the
    pool can mark this worker as dead.
    """
    try:
        _real_worker_func_unsafe(params)
    except Exception as ex:
        logger.exception(ex)
        msg = ("CRITICAL ERROR IN WORKER %d\n%s" % (params.wid, str(ex)))
        # Use the same shape as every other result message,
        # `(task_type, wid, [(task_data, data, success), ...])`, which is
        # what `WorkerPool._handleResults` unpacks. A differently-shaped
        # tuple would crash the pool's result-handler thread.
        params.outqueue.put((
            _CRITICAL_WORKER_ERROR, params.wid, [(None, msg, False)]))


def _pre_parse_pytest_args():
    """Configure Chef logging in a worker from pytest's logging options.

    When unit-testing, the worker sees pytest's command line rather than
    Chef's. This maps our test logging flags to their Chef equivalents,
    drops all root logger handlers, and hands the result to Chef's
    argument pre-parser.
    """
    import argparse

    # Mirrors the options declared in our `conftest.py`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--log-debug', action='store_true')
    parser.add_argument('--log-file')
    known, _unknown = parser.parse_known_args(sys.argv[1:])

    chef_args = ['--debug'] if known.log_debug else []
    if known.log_file:
        chef_args.extend(['--log', known.log_file])

    # Start from a clean slate so Chef can install its own handlers.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)

    from piecrust.main import _pre_parse_chef_args
    _pre_parse_chef_args(chef_args)


def _real_worker_func_unsafe(params):
    """Run one worker: set it up, then process tasks until told to stop.

    Tasks are read from `params.inqueue` as `(task_type, task_data)`.
    Results are written to `params.outqueue` as
    `(task_type, wid, [(task_data, data, success), ...])`. Here `data` is
    the job's return value on success, or the dict built by
    `_get_worker_exception_data()` on failure. On `TASK_END`, the single
    result carries `(wid, stats_data)` (or `(wid, exception_data)` if the
    stats couldn't be gathered), and the loop exits.

    Raises:
        Whatever the worker's `initialize()` or `shutdown()` raises, and
        `Exception` for an unknown task type.
    """
    wid = params.wid

    stats = ExecutionStats()
    stats.registerTimer('WorkerInit')
    init_start_time = time.perf_counter()

    # In a context where `multiprocessing` is using the `spawn` forking model,
    # the new process doesn't inherit anything, so we lost all our logging
    # configuration here. Let's set it up again.
    if (hasattr(multiprocessing, 'get_start_method') and
            multiprocessing.get_start_method() == 'spawn'):
        if not params.is_unit_testing:
            from piecrust.main import _pre_parse_chef_args
            _pre_parse_chef_args(sys.argv[1:])
        else:
            _pre_parse_pytest_args()
    elif params.is_unit_testing:
        _pre_parse_pytest_args()

    from piecrust.main import ColoredFormatter
    root_logger = logging.getLogger()
    root_logger.handlers[0].setFormatter(ColoredFormatter(
        ('[W-%d]' % wid) + '[%(name)s] %(message)s'))

    logger.debug("Worker %d initializing..." % wid)

    # We don't need those: this process only reads tasks and writes results.
    params.inqueue._writer.close()
    params.outqueue._reader.close()

    # Initialize the underlying worker class.
    w = params.worker_class(*params.initargs)
    w.wid = wid
    try:
        w.initialize()
    except Exception as ex:
        logger.error("Worker %d failed to initialize." % wid)
        logger.exception(ex)
        raise

    stats.stepTimerSince('WorkerInit', init_start_time)

    # Start pumping!
    completed = 0
    time_in_get = 0
    time_in_put = 0
    get = params.inqueue.get
    put = params.outqueue.put

    while True:
        get_start_time = time.perf_counter()
        task = get()
        time_in_get += (time.perf_counter() - get_start_time)

        task_type, task_data = task

        # Job task(s)... just do it.
        if task_type == TASK_JOB or task_type == TASK_JOB_BATCH:

            task_data_list = task_data
            if task_type == TASK_JOB:
                task_data_list = [task_data]

            result_list = []

            for td in task_data_list:
                try:
                    res = w.process(td)
                    result_list.append((td, res, True))
                except Exception:
                    logger.debug(
                        "Error processing job, sending exception to main process:")
                    logger.debug(traceback.format_exc())
                    # Send the exception data itself. The job is already
                    # the first item of the result tuple.
                    we = _get_worker_exception_data(wid)
                    result_list.append((td, we, False))

            # All results of a batch go back in a single message.
            res = (task_type, wid, result_list)
            put_start_time = time.perf_counter()
            put(res)
            time_in_put += (time.perf_counter() - put_start_time)

            completed += len(task_data_list)

        # End task... gather stats to send back to the main process.
        elif task_type == TASK_END:
            logger.debug("Worker %d got end task, exiting." % wid)
            stats.registerTimer('WorkerTaskGet', time=time_in_get)
            stats.registerTimer('WorkerResultPut', time=time_in_put)
            try:
                stats.mergeStats(w.getStats())
                stats_data = stats.toData()
                rep = (task_type, wid, [(task_data, (wid, stats_data), True)])
            except Exception:
                logger.debug(
                    "Error getting report, sending exception to main process:")
                logger.debug(traceback.format_exc())
                we = _get_worker_exception_data(wid)
                rep = (task_type, wid, [(task_data, (wid, we), False)])
            put(rep)
            break

        # Emergency abort.
        elif task_type == _TASK_ABORT_WORKER:
            logger.debug("Worker %d got abort signal." % wid)
            break

        else:
            raise Exception("Unknown task type: %s" % task_type)

    try:
        w.shutdown()
    except Exception as e:
        logger.error("Worker %d failed to shutdown." % wid)
        logger.exception(e)
        raise

    logger.debug("Worker %d completed %d tasks." % (wid, completed))


class _WorkerParams:
    def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
                 is_profiling=False, is_unit_testing=False):
        self.wid = wid
        self.inqueue = inqueue
        self.outqueue = outqueue
        self.worker_class = worker_class
        self.initargs = initargs
        self.is_profiling = is_profiling
        self.is_unit_testing = is_unit_testing


class WorkerPool:
    """A pool of worker processes fed through a shared task queue.

    Each worker process runs `worker_func` and instantiates
    `worker_class(*initargs)`. Jobs are queued with `queueJobs()`, either
    one message per job or in batches of `batch_size` jobs. A daemon thread
    reads results and calls `callback(job, result, userdata)` on success or
    `error_callback(job, error_data, userdata)` on failure.
    """
    def __init__(self, worker_class, initargs=(), *,
                 callback=None, error_callback=None,
                 worker_count=None, batch_size=None,
                 userdata=None):
        self.userdata = userdata

        worker_count = worker_count or os.cpu_count() or 1

        if use_fastqueue:
            self._task_queue = FastQueue()
            self._result_queue = FastQueue()
        else:
            self._task_queue = multiprocessing.SimpleQueue()
            self._result_queue = multiprocessing.SimpleQueue()
        self._quick_put = self._task_queue.put
        self._quick_get = self._result_queue.get

        self._callback = callback
        self._error_callback = error_callback
        self._batch_size = batch_size
        self._jobs_left = 0
        self._lock_jobs_left = threading.Lock()
        self._lock_workers = threading.Lock()
        self._event = threading.Event()
        self._error_on_join = None
        self._closed = False

        # Detect profiling/unit-testing from the main script's file name.
        # `__main__` has no `__file__` in an interactive interpreter.
        main_module = sys.modules['__main__']
        main_file = os.path.basename(
            getattr(main_module, '__file__', None) or '')
        is_profiling = main_file in ['profile.py', 'cProfile.py']
        is_unit_testing = main_file in ['py.test']

        self._pool = []
        for i in range(worker_count):
            worker_params = _WorkerParams(
                i, self._task_queue, self._result_queue,
                worker_class, initargs,
                is_profiling=is_profiling,
                is_unit_testing=is_unit_testing)
            w = multiprocessing.Process(target=worker_func,
                                        args=(worker_params,))
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            self._pool.append(w)

        self._result_handler = threading.Thread(
            target=WorkerPool._handleResults,
            args=(self,))
        self._result_handler.daemon = True
        self._result_handler.start()

    def queueJobs(self, jobs):
        """Send `jobs` to the workers; use `wait()` to block until done.

        Raises the pool's fatal error (e.g. all workers died), or a generic
        `Exception`, if the pool is closed.
        """
        if self._closed:
            if self._error_on_join:
                raise self._error_on_join
            raise Exception("This worker pool has been closed.")

        jobs = list(jobs)
        new_job_count = len(jobs)
        if new_job_count > 0:
            with self._lock_jobs_left:
                self._jobs_left += new_job_count

            self._event.clear()
            bs = self._batch_size
            if not bs:
                for job in jobs:
                    self._quick_put((TASK_JOB, job))
            else:
                cur_offset = 0
                while cur_offset < new_job_count:
                    next_batch_idx = min(cur_offset + bs, new_job_count)
                    job_batch = jobs[cur_offset:next_batch_idx]
                    self._quick_put((TASK_JOB_BATCH, job_batch))
                    cur_offset = next_batch_idx
        else:
            with self._lock_jobs_left:
                done = (self._jobs_left == 0)
            if done:
                self._event.set()

    def wait(self, timeout=None):
        """Block until all queued jobs are done, or `timeout` expires.

        :returns: `True` if all jobs completed, `False` on timeout.
        :raises: the pool's fatal error if all workers died.
        """
        if self._closed:
            # Same as `queueJobs`: surface the fatal error if there is one.
            if self._error_on_join:
                raise self._error_on_join
            raise Exception("This worker pool has been closed.")

        ret = self._event.wait(timeout)
        if self._error_on_join:
            raise self._error_on_join
        return ret

    def close(self):
        """Stop all live workers and return their stats reports.

        :returns: a list with one entry per live worker, indexed by worker
            id; entries are `None` for workers whose report didn't arrive.
        """
        if self._closed:
            raise Exception("This worker pool has been closed.")
        if self._jobs_left > 0:
            raise Exception("A previous job queue has not finished yet.")
        if not self._event.is_set():
            raise Exception("A previous job queue hasn't been cleared.")

        logger.debug("Closing worker pool...")
        # `_pool` entries are set to `None` by the result thread when a
        # worker dies, so read it under the same lock.
        with self._lock_workers:
            live_workers = [w for w in self._pool if w is not None]
        handler = _ReportHandler(len(live_workers))
        self._callback = handler._handle
        self._error_callback = handler._handleError
        for w in live_workers:
            self._quick_put((TASK_END, None))
        for w in live_workers:
            w.join()

        logger.debug("Waiting for reports...")
        if not handler.wait(2):
            # A report may still arrive right after the timeout, so only
            # list the ones that are really missing.
            missing = [i for i, r in enumerate(handler.reports) if r is None]
            if missing:
                logger.warning(
                    "Didn't receive all worker reports before timeout. "
                    "Missing report from worker(s) %s." %
                    ', '.join(str(i) for i in missing))

        logger.debug("Exiting result handler thread...")
        self._result_queue.put(None)
        self._result_handler.join()
        self._closed = True

        return handler.reports

    def _onResultHandlerCriticalError(self, wid):
        """Mark worker `wid` as dead.

        :returns: `False` once every worker is dead (the pool is then
            closed and waiters are released), `True` otherwise.
        """
        logger.error("Result handler received a critical error from "
                     "worker %d." % wid)
        with self._lock_workers:
            self._pool[wid] = None
            if all(map(lambda w: w is None, self._pool)):
                logger.error("All workers have died!")
                self._closed = True
                self._error_on_join = Exception("All workers have died!")
                self._event.set()
                return False

        return True

    def _onTaskDone(self, count=1):
        """Account for `count` finished jobs; wake waiters when none are left."""
        with self._lock_jobs_left:
            left = self._jobs_left - count
            self._jobs_left = left

        if left == 0:
            self._event.set()

    @staticmethod
    def _handleResults(pool):
        """Result-handler thread: dispatch worker results to the callbacks.

        Each message is `(task_type, wid, [(task_data, data, success), ...])`,
        or `None` to make this thread exit.
        """
        userdata = pool.userdata
        while True:
            try:
                res = pool._quick_get()
            except (EOFError, OSError):
                logger.debug("Result handler thread encountered connection "
                             "problem, exiting.")
                return

            if res is None:
                logger.debug("Result handler exiting.")
                return

            try:
                task_type, wid, res_data_list = res
            except (TypeError, ValueError):
                # Don't let one malformed message kill this thread: `wait()`
                # would then never return.
                logger.error("Result handler received a malformed result: "
                             "%r" % (res,))
                continue

            for res_data in res_data_list:
                try:
                    task_data, data, success = res_data
                    if success:
                        if pool._callback:
                            pool._callback(task_data, data, userdata)
                    else:
                        if task_type == _CRITICAL_WORKER_ERROR:
                            logger.error(data)
                            do_continue = pool._onResultHandlerCriticalError(wid)
                            if not do_continue:
                                logger.debug("Aborting result handling thread.")
                                return
                        else:
                            if pool._error_callback:
                                pool._error_callback(task_data, data, userdata)
                            else:
                                logger.error(
                                    "Worker %d failed to process a job:" % wid)
                                logger.error(data)
                except Exception as ex:
                    logger.exception(ex)

            # `queueJobs` counts individual jobs, so every job result counts,
            # whether the job was sent alone or as part of a batch.
            if task_type == TASK_JOB or task_type == TASK_JOB_BATCH:
                pool._onTaskDone(len(res_data_list))


class _ReportHandler:
    def __init__(self, worker_count):
        self.reports = [None] * worker_count
        self._count = worker_count
        self._received = 0
        self._lock = threading.Lock()
        self._event = threading.Event()

    def wait(self, timeout=None):
        return self._event.wait(timeout)

    def _handle(self, job, res, _):
        wid, data = res
        if wid < 0 or wid > self._count:
            logger.error("Ignoring report from unknown worker %d." % wid)
            return

        stats = ExecutionStats()
        stats.fromData(data)

        with self._lock:
            self.reports[wid] = stats
            self._received += 1
            if self._received == self._count:
                self._event.set()

    def _handleError(self, job, res, _):
        logger.error("Worker %d failed to send its report." % res.wid)
        logger.error(res)


class FastQueue:
    """A one-way queue between processes built on a `multiprocessing.Pipe`.

    Objects are serialized with the module-level `_pickle`/`_unpickle`
    backend into reusable in-memory buffers and sent as raw bytes. Reads
    and writes are each guarded by a `multiprocessing.Lock`.
    """
    def __init__(self):
        self._reader, self._writer = multiprocessing.Pipe(duplex=False)
        self._rlock = multiprocessing.Lock()
        self._wlock = multiprocessing.Lock()
        self._initBuffers()
        self._initSerializer()

    def _initBuffers(self):
        # NOTE(review): `BytesIO.truncate()` appears to only ever shrink a
        # buffer, so these calls likely pre-allocate nothing. The buffers
        # then grow as data is written. Confirm before relying on a size.
        self._rbuf = io.BytesIO()
        self._rbuf.truncate(256)
        self._wbuf = io.BytesIO()
        self._wbuf.truncate(256)

    def _initSerializer(self):
        # Hook for serializer backends that need per-queue state; the
        # msgpack backend replaces it at import time.
        pass

    def __getstate__(self):
        # Only the pipe ends and locks are shared; buffers and serializer
        # state are rebuilt in each process.
        return (self._reader, self._writer, self._rlock, self._wlock)

    def __setstate__(self, state):
        (self._reader, self._writer, self._rlock, self._wlock) = state
        self._initBuffers()
        # Serializer state isn't pickled either. Without this, e.g. the
        # msgpack packer/unpacker would be missing in spawned processes.
        self._initSerializer()

    def get(self):
        """Receive and deserialize the next object."""
        with self._rlock:
            self._rbuf.seek(0)
            try:
                with self._rbuf.getbuffer() as b:
                    bufsize = self._reader.recv_bytes_into(b)
            except multiprocessing.BufferTooShort as e:
                # The message didn't fit: `e.args[0]` holds all of it, so
                # copy it into our buffer (which grows as needed).
                bufsize = len(e.args[0])
                self._rbuf.truncate(bufsize * 2)
                self._rbuf.seek(0)
                self._rbuf.write(e.args[0])

        # NOTE(review): deserialization happens after releasing `_rlock`.
        # This only matters if several threads of one process share this
        # queue object.
        self._rbuf.seek(0)
        return _unpickle(self, self._rbuf, bufsize)

    def put(self, obj):
        """Serialize `obj` and send it as a single message."""
        self._wbuf.seek(0)
        _pickle(self, obj, self._wbuf)
        size = self._wbuf.tell()

        self._wbuf.seek(0)
        with self._wlock:
            # Send only the bytes just written; the buffer may hold stale
            # bytes from an earlier, larger message past `size`.
            with self._wbuf.getbuffer() as b:
                self._writer.send_bytes(b, 0, size)


class _BufferWrapper:
    def __init__(self, buf, read_size=0):
        self._buf = buf
        self._read_size = read_size

    def write(self, data):
        self._buf.write(data.encode('utf8'))

    def read(self):
        return self._buf.read(self._read_size).decode('utf8')


# Pick the serializer backend used by `FastQueue` from the module switches.
# Each backend defines:
# - `_pickle(queue, obj, buf)`: write `obj` into the binary buffer `buf`.
# - `_unpickle(queue, buf, bufsize)`: read one object back from `buf`, which
#   is positioned at the start of a message `bufsize` bytes long.
if use_fastpickle:
    from piecrust import fastpickle

    def _pickle_fast(queue, obj, buf):
        fastpickle.pickle_intob(obj, buf)

    def _unpickle_fast(queue, buf, bufsize):
        return fastpickle.unpickle_fromb(buf, bufsize)

    _pickle = _pickle_fast
    _unpickle = _unpickle_fast

elif use_msgpack:
    import msgpack

    def _pickle_msgpack(queue, obj, buf):
        buf.write(queue._packer.pack(obj))

    def _unpickle_msgpack(queue, buf, bufsize):
        # Feed only the current message. The read buffer is reused between
        # messages and can still hold stale bytes past `bufsize`, which the
        # streaming unpacker would otherwise decode as bogus extra objects.
        queue._unpacker.feed(buf.read(bufsize))
        for o in queue._unpacker:
            return o

    def _init_msgpack(queue):
        # Per-queue packer/unpacker; the unpacker is stateful.
        queue._packer = msgpack.Packer()
        queue._unpacker = msgpack.Unpacker()

    _pickle = _pickle_msgpack
    _unpickle = _unpickle_msgpack
    FastQueue._initSerializer = _init_msgpack

elif use_marshall:
    import marshal

    def _pickle_marshal(queue, obj, buf):
        marshal.dump(obj, buf)

    def _unpickle_marshal(queue, buf, bufsize):
        return marshal.load(buf)

    _pickle = _pickle_marshal
    _unpickle = _unpickle_marshal

elif use_json:
    import json

    def _pickle_json(queue, obj, buf):
        buf = _BufferWrapper(buf)
        json.dump(obj, buf, indent=None, separators=(',', ':'))

    def _unpickle_json(queue, buf, bufsize):
        # Only decode this message's bytes.
        buf = _BufferWrapper(buf, bufsize)
        return json.load(buf)

    _pickle = _pickle_json
    _unpickle = _unpickle_json

else:
    import pickle

    # NOTE: `pickle.load` must only see trusted data. The queues in this
    # module carry messages between the pool and its own worker processes.
    def _pickle_default(queue, obj, buf):
        pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)

    def _unpickle_default(queue, buf, bufsize):
        return pickle.load(buf)

    _pickle = _pickle_default
    _unpickle = _unpickle_default