view piecrust/workerpool.py @ 550:6f216c1ab6b1

bake: Add a flag to know which record entries got collapsed from last run. This makes it possible to find entries for things that were actually baked during the current run, as opposed to skipped because they were "clean".
author Ludovic Chabant <ludovic@chabant.com>
date Tue, 04 Aug 2015 21:22:30 -0700
parents 22a230d99621
children 8073ae8cb164

import os
import sys
import zlib
import logging
import itertools
import threading
import multiprocessing
import multiprocessing.pool
from piecrust.fastpickle import pickle, unpickle


logger = logging.getLogger(__name__)


class IWorker(object):
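    """ Interface for objects that process jobs inside a worker process.
        A fresh instance is created in each child process from the class
        and `initargs` given to the `WorkerPool`.
    """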
    def initialize(self):
        raise NotImplementedError()

    def process(self, job):
        raise NotImplementedError()

    def getReport(self):
        return None


# Message types exchanged over the task queue.
TASK_JOB = 0    # a single job
TASK_BATCH = 1  # a tuple of several jobs sent as one message
TASK_END = 2    # sentinel telling a worker to report back and exit


def worker_func(params):
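    """ Entry point of a worker process. If the main process is running
        under a profiler, also run the real worker function under the
        profiler and dump its stats to a per-worker `.prof` file.
    """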
    if params.is_profiling:
        try:
            import cProfile as profile
        except ImportError:
            import profile

        params.is_profiling = False
        name = params.worker_class.__name__
        profile.runctx('_real_worker_func(params)',
                       globals(), locals(),
                       filename='%s-%d.prof' % (name, params.wid))
    else:
        _real_worker_func(params)


def _real_worker_func(params):
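    """ Main loop of a worker process: pull tasks from the input queue,
        process them with the user-provided worker instance, and push the
        results onto the output queue until a `TASK_END` task arrives.
    """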
    # In a context where `multiprocessing` is using the `spawn` start
    # method, the new process doesn't inherit anything, so we lose all our
    # logging configuration here. Let's set it up again.
    from piecrust.main import _pre_parse_chef_args
    _pre_parse_chef_args(sys.argv[1:])

    wid = params.wid
    logger.debug("Worker %d initializing..." % wid)

    # Close the pipe ends this child process doesn't use, so that EOF is
    # detected properly when the other side goes away.
    params.inqueue._writer.close()
    params.outqueue._reader.close()

    w = params.worker_class(*params.initargs)
    w.wid = wid
    try:
        w.initialize()
    except Exception as ex:
        logger.error("Working failed to initialize:")
        logger.exception(ex)
        params.outqueue.put(None)
        return

    get = params.inqueue.get
    put = params.outqueue.put

    completed = 0
    while True:
        try:
            task = get()
        except (EOFError, OSError):
            logger.debug("Worker %d encountered connection problem." % wid)
            break

        task_type, task_data = task
        if task_type == TASK_END:
            logger.debug("Worker %d got end task, exiting." % wid)
            try:
                rep = (task_type, True, wid, (wid, w.getReport()))
            except Exception as e:
                if params.wrap_exception:
                    # `ExceptionWithTraceback` lives in `multiprocessing.pool`.
                    e = multiprocessing.pool.ExceptionWithTraceback(
                            e, e.__traceback__)
                rep = (task_type, False, wid, (wid, e))
            put(rep)
            break

        if task_type == TASK_JOB:
            # Wrap a single job in a one-element tuple so the loop below
            # treats single jobs and batches the same way.
            task_data = (task_data,)

        for t in task_data:
            try:
                res = (TASK_JOB, True, wid, w.process(t))
            except Exception as e:
                if params.wrap_exception:
                    e = multiprocessing.pool.ExceptionWithTraceback(
                            e, e.__traceback__)
                res = (TASK_JOB, False, wid, e)
            put(res)

            completed += 1

    logger.debug("Worker %d completed %d tasks." % (wid, completed))


class _WorkerParams(object):
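    """ Plain data holder for everything a worker process needs; it is
        passed to the child process as the sole argument of `worker_func`.
    """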
    def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
                 wrap_exception=False, is_profiling=False):
        self.wid = wid
        self.inqueue = inqueue
        self.outqueue = outqueue
        self.worker_class = worker_class
        self.initargs = initargs
        self.wrap_exception = wrap_exception
        self.is_profiling = is_profiling


class WorkerPool(object):
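    """ A simple process pool that runs a given `IWorker` subclass in
        several child processes, pushes pickled jobs to them, and
        dispatches their results from a background handler thread.

        Rough usage sketch (`MyWorker`, `on_result`, `on_error`, `config`
        and `jobs` are hypothetical placeholders, not part of this module):

            pool = WorkerPool(MyWorker, initargs=(config,))
            pool.setHandler(callback=on_result, error_callback=on_error)
            res = pool.queueJobs(jobs)
            res.wait()
            reports = pool.close()
    """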
    def __init__(self, worker_class, initargs=(),
                 worker_count=None, batch_size=None,
                 wrap_exception=False):
        worker_count = worker_count or os.cpu_count() or 1

        # Use the pipe-based `FastQueue` by default; the standard
        # `multiprocessing.SimpleQueue` code path is kept as an alternative.
        use_fastqueue = True
        if use_fastqueue:
            self._task_queue = FastQueue()
            self._result_queue = FastQueue()
            self._quick_put = self._task_queue.put
            self._quick_get = self._result_queue.get
        else:
            self._task_queue = multiprocessing.SimpleQueue()
            self._result_queue = multiprocessing.SimpleQueue()
            self._quick_put = self._task_queue._writer.send
            self._quick_get = self._result_queue._reader.recv

        self._batch_size = batch_size
        self._callback = None
        self._error_callback = None
        self._listener = None

        # Detect whether the main script is running under the profiler, so
        # that the worker processes can profile themselves as well.
        main_module = sys.modules['__main__']
        is_profiling = os.path.basename(main_module.__file__) in [
                'profile.py', 'cProfile.py']

        self._pool = []
        for i in range(worker_count):
            worker_params = _WorkerParams(
                    i, self._task_queue, self._result_queue,
                    worker_class, initargs,
                    wrap_exception=wrap_exception,
                    is_profiling=is_profiling)
            w = multiprocessing.Process(target=worker_func,
                                        args=(worker_params,))
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            self._pool.append(w)

        self._result_handler = threading.Thread(
                target=WorkerPool._handleResults,
                args=(self,))
        self._result_handler.daemon = True
        self._result_handler.start()

        self._closed = False

    def setHandler(self, callback=None, error_callback=None):
        self._callback = callback
        self._error_callback = error_callback

    def queueJobs(self, jobs, handler=None, chunk_size=None):
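        """ Send the given jobs to the workers, one by one or in batches
            of `chunk_size`, and return an `AsyncResult` that can be
            waited on until all of them have been processed.
        """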
        if self._closed:
            raise Exception("This worker pool has been closed.")
        if self._listener is not None:
            raise Exception("A previous job queue has not finished yet.")

        if any(not p.is_alive() for p in self._pool):
            raise Exception("Some workers have prematurely exited.")

        if handler is not None:
            self.setHandler(handler)

        if not hasattr(jobs, '__len__'):
            jobs = list(jobs)
        job_count = len(jobs)

        res = AsyncResult(self, job_count)
        if res._count == 0:
            res._event.set()
            return res

        self._listener = res

        if chunk_size is None:
            chunk_size = self._batch_size
        if chunk_size is None:
            chunk_size = max(1, job_count // 50)
            logger.debug("Using chunk size of %d" % chunk_size)

        if chunk_size == 1:
            for job in jobs:
                self._quick_put((TASK_JOB, job))
        else:
            it = iter(jobs)
            while True:
                batch = tuple(itertools.islice(it, chunk_size))
                if not batch:
                    break
                self._quick_put((TASK_BATCH, batch))

        return res

    def close(self):
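        """ Shut the pool down: send an end task to each worker, wait for
            them to exit, and return the list of their reports.
        """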
        if self._listener is not None:
            raise Exception("A previous job queue has not finished yet.")

        logger.debug("Closing worker pool...")
        handler = _ReportHandler(len(self._pool))
        self._callback = handler._handle
        for w in self._pool:
            self._quick_put((TASK_END, None))
        for w in self._pool:
            w.join()

        logger.debug("Waiting for reports...")
        if not handler.wait(2):
            missing = handler.reports.index(None)
            logger.warning(
                    "Didn't receive all worker reports before timeout. "
                    "Missing report from worker %d." % missing)

        logger.debug("Exiting result handler thread...")
        self._result_queue.put(None)
        self._result_handler.join()
        self._closed = True

        return handler.reports

    @staticmethod
    def _handleResults(pool):
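        """ Body of the result handler thread: pull results from the
            result queue and dispatch them to the registered callbacks
            until the `None` sentinel is received.
        """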
        while True:
            try:
                res = pool._quick_get()
            except (EOFError, OSError):
                logger.debug("Result handler thread encountered connection "
                             "problem, exiting.")
                return

            if res is None:
                logger.debug("Result handler exiting.")
                break

            task_type, success, wid, data = res
            try:
                if success and pool._callback:
                    pool._callback(data)
                elif not success:
                    if pool._error_callback:
                        pool._error_callback(data)
                    else:
                        logger.error(data)
            except Exception as ex:
                logger.exception(ex)

            if task_type == TASK_JOB:
                pool._listener._onTaskDone()


class AsyncResult(object):
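    """ Handle on a batch of queued jobs, letting callers wait until
        every job in the batch has been processed.
    """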
    def __init__(self, pool, count):
        self._pool = pool
        self._count = count
        self._event = threading.Event()

    def ready(self):
        return self._event.is_set()

    def wait(self, timeout=None):
        return self._event.wait(timeout)

    def _onTaskDone(self):
        self._count -= 1
        if self._count == 0:
            self._pool.setHandler(None)
            self._pool._listener = None
            self._event.set()


class _ReportHandler(object):
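    """ Collects the per-worker reports that come back when the pool is
        being closed.
    """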
    def __init__(self, worker_count):
        self.reports = [None] * worker_count
        self._count = worker_count
        self._received = 0
        self._event = threading.Event()

    def wait(self, timeout=None):
        return self._event.wait(timeout)

    def _handle(self, res):
        wid, data = res
        if wid < 0 or wid >= self._count:
            logger.error("Ignoring report from unknown worker %d." % wid)
            return

        self._received += 1
        self.reports[wid] = data

        if self._received == self._count:
            self._event.set()

    def _handleError(self, res):
        wid, data = res
        logger.error("Worker %d failed to send its report." % wid)
        logger.exception(data)


class FastQueue(object):
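    """ A lightweight queue built on a raw `multiprocessing.Pipe`, using
        piecrust's `fastpickle` for serialization and optional zlib
        compression of the payload.
    """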
    def __init__(self, compress=False):
        self._reader, self._writer = multiprocessing.Pipe(duplex=False)
        self._rlock = multiprocessing.Lock()
        self._wlock = multiprocessing.Lock()
        self._compress = compress

    def __getstate__(self):
        return (self._reader, self._writer, self._rlock, self._wlock,
                self._compress)

    def __setstate__(self, state):
        (self._reader, self._writer, self._rlock, self._wlock,
            self._compress) = state

    def get(self):
        with self._rlock:
            raw = self._reader.recv_bytes()
        if self._compress:
            data = zlib.decompress(raw)
        else:
            data = raw
        obj = unpickle(data)
        return obj

    def put(self, obj):
        data = pickle(obj)
        if self._compress:
            raw = zlib.compress(data)
        else:
            raw = data
        with self._wlock:
            self._writer.send_bytes(raw)