piecrust2: diff of piecrust/processing/pipeline.py @ 447:aefe70229fdd
bake: Commonize worker pool code between html and asset baking.
The `workerpool` package now defines a generic-ish worker pool. It's similar
to the pool in Python's `multiprocessing` framework, but with a simpler use
case (only one way to queue jobs) and support for workers sending a final
"report" back to the master process, which we use here to gather timing
information.
The rest of the changes mostly remove duplicated code that is no longer
needed.
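In practice, the change swaps the hand-rolled queue and process management for the new pool's much smaller surface. The following is a minimal sketch of the new calling pattern, based on the hunks below; names such as `jobs` and `_handler` are the pipeline's own variables, while `root_dir`, `out_dir`, `tmp_dir`, `force`, and `debug` stand in for the values the pipeline pulls off `self`:

```python
from piecrust.workerpool import WorkerPool
from piecrust.processing.worker import (
    ProcessingWorker, ProcessingWorkerContext)

# Build the per-worker context once, then let the pool spawn the workers.
ctx = ProcessingWorkerContext(root_dir, out_dir, tmp_dir, force, debug)
pool = WorkerPool(worker_class=ProcessingWorker, initargs=(ctx,))

# Queue all jobs in one go; `_handler` is called for each result.
ar = pool.queueJobs(jobs, handler=_handler)
ar.wait()

# Shut the workers down; each one sends back a final "report",
# which the pipeline uses to collect per-worker timing information.
reports = pool.close()
```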
author   | Ludovic Chabant <ludovic@chabant.com>
date     | Sun, 05 Jul 2015 00:09:41 -0700
parents  | 171dde4f61dc
children | d90ccdf18156
--- a/piecrust/processing/pipeline.py	Thu Jul 02 23:28:24 2015 -0700
+++ b/piecrust/processing/pipeline.py	Sun Jul 05 00:09:41 2015 -0700
@@ -2,7 +2,6 @@
 import os.path
 import re
 import time
-import queue
 import hashlib
 import logging
 import multiprocessing
@@ -12,16 +11,16 @@
         ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord,
         FLAG_PROCESSED)
 from piecrust.processing.worker import (
-        ProcessingWorkerContext, ProcessingWorkerJob,
-        worker_func, get_filtered_processors)
+        ProcessingWorkerJob,
+        get_filtered_processors)
 
 
 logger = logging.getLogger(__name__)
 
 
 class _ProcessingContext(object):
-    def __init__(self, pool, record, base_dir, mount_info):
-        self.pool = pool
+    def __init__(self, jobs, record, base_dir, mount_info):
+        self.jobs = jobs
         self.record = record
         self.base_dir = base_dir
         self.mount_info = mount_info
@@ -94,9 +93,6 @@
             self.ignore_patterns += make_re(
                     pipeline_ctx._additional_ignore_patterns)
 
-        # Create the worker pool.
-        pool = _WorkerPool()
-
         # Create the pipeline record.
         record = TransitionalProcessorPipelineRecord()
         record_cache = self.app.cache.getCache('proc')
@@ -132,19 +128,19 @@
                 for e in entry.errors:
                     logger.error(" " + e)
 
+        jobs = []
+        self._process(src_dir_or_file, record, jobs)
         pool = self._createWorkerPool()
-        expected_result_count = self._process(src_dir_or_file, pool, record)
-        self._waitOnWorkerPool(pool, expected_result_count, _handler)
-        self._terminateWorkerPool(pool)
+        ar = pool.queueJobs(jobs, handler=_handler)
+        ar.wait()
 
-        # Get timing information from the workers.
+        # Shutdown the workers and get timing information from them.
+        reports = pool.close()
         record.current.timers = {}
-        for i in range(len(pool.workers)):
-            try:
-                timers = pool.results.get(True, 0.1)
-            except queue.Empty:
-                logger.error("Didn't get timing information from all workers.")
-                break
+        for i in range(len(reports)):
+            timers = reports[i]
+            if timers is None:
+                continue
             worker_name = 'PipelineWorker_%d' % i
             record.current.timers[worker_name] = {}
@@ -185,9 +181,7 @@
 
         return record.detach()
 
-    def _process(self, src_dir_or_file, pool, record):
-        expected_result_count = 0
-
+    def _process(self, src_dir_or_file, record, jobs):
         if src_dir_or_file is not None:
             # Process only the given path.
             # Find out what mount point this is in.
@@ -203,28 +197,23 @@
                         "mount point: %s" %
                         (src_dir_or_file, known_roots))
 
-            ctx = _ProcessingContext(pool, record, base_dir, mount_info)
+            ctx = _ProcessingContext(jobs, record, base_dir, mount_info)
             logger.debug("Initiating processing pipeline on: %s" %
                          src_dir_or_file)
             if os.path.isdir(src_dir_or_file):
-                expected_result_count = self._processDirectory(
-                        ctx, src_dir_or_file)
+                self._processDirectory(ctx, src_dir_or_file)
             elif os.path.isfile(src_dir_or_file):
                 self._processFile(ctx, src_dir_or_file)
-                expected_result_count = 1
 
         else:
             # Process everything.
             for name, info in self.mounts.items():
                 path = info['path']
-                ctx = _ProcessingContext(pool, record, path, info)
+                ctx = _ProcessingContext(jobs, record, path, info)
                 logger.debug("Initiating processing pipeline on: %s" % path)
-                expected_result_count = self._processDirectory(ctx, path)
-
-        return expected_result_count
+                self._processDirectory(ctx, path)
 
     def _processDirectory(self, ctx, start_dir):
-        queued_count = 0
         for dirpath, dirnames, filenames in os.walk(start_dir):
             rel_dirpath = os.path.relpath(dirpath, start_dir)
             dirnames[:] = [d for d in dirnames
@@ -235,8 +224,6 @@
             if re_matchany(filename, self.ignore_patterns, rel_dirpath):
                 continue
             self._processFile(ctx, os.path.join(dirpath, filename))
-            queued_count += 1
-        return queued_count
 
     def _processFile(self, ctx, path):
         # TODO: handle overrides between mount-points.
@@ -250,79 +237,23 @@
 
         job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path,
                                   force=force_this)
-
-        logger.debug("Queuing: %s" % path)
-        ctx.pool.queue.put_nowait(job)
+        ctx.jobs.append(job)
 
     def _createWorkerPool(self):
-        import sys
-
-        main_module = sys.modules['__main__']
-        is_profiling = os.path.basename(main_module.__file__) in [
-                'profile.py', 'cProfile.py']
-
-        pool = _WorkerPool()
-        for i in range(self.num_workers):
-            ctx = ProcessingWorkerContext(
-                    self.app.root_dir, self.out_dir, self.tmp_dir,
-                    pool.queue, pool.results, pool.abort_event,
-                    self.force, self.app.debug)
-            ctx.is_profiling = is_profiling
-            ctx.enabled_processors = self.enabled_processors
-            ctx.additional_processors = self.additional_processors
-            w = multiprocessing.Process(
-                    name='PipelineWorker_%d' % i,
-                    target=worker_func, args=(i, ctx))
-            w.start()
-            pool.workers.append(w)
-        return pool
+        from piecrust.workerpool import WorkerPool
+        from piecrust.processing.worker import (
+                ProcessingWorkerContext, ProcessingWorker)
 
-    def _waitOnWorkerPool(self, pool, expected_result_count, result_handler):
-        abort_with_exception = None
-        try:
-            got_count = 0
-            while got_count < expected_result_count:
-                try:
-                    res = pool.results.get(True, 10)
-                except queue.Empty:
-                    logger.error(
-                            "Got %d results, expected %d, and timed-out "
-                            "for 10 seconds. A worker might be stuck?" %
-                            (got_count, expected_result_count))
-                    abort_with_exception = Exception("Worker time-out.")
-                    break
-
-                if isinstance(res, dict) and res.get('type') == 'error':
-                    abort_with_exception = Exception(
-                            'Worker critical error:\n' +
-                            '\n'.join(res['messages']))
-                    break
+        ctx = ProcessingWorkerContext(
+                self.app.root_dir, self.out_dir, self.tmp_dir,
+                self.force, self.app.debug)
+        ctx.enabled_processors = self.enabled_processors
+        ctx.additional_processors = self.additional_processors
 
-                got_count += 1
-                result_handler(res)
-        except KeyboardInterrupt as kiex:
-            logger.warning("Bake aborted by user... "
-                           "waiting for workers to stop.")
-            abort_with_exception = kiex
-
-        if abort_with_exception:
-            pool.abort_event.set()
-            for w in pool.workers:
-                w.join(2)
-            raise abort_with_exception
-
-    def _terminateWorkerPool(self, pool):
-        pool.abort_event.set()
-        for w in pool.workers:
-            w.join()
-
-
-class _WorkerPool(object):
-    def __init__(self):
-        self.queue = multiprocessing.JoinableQueue()
-        self.results = multiprocessing.Queue()
-        self.abort_event = multiprocessing.Event()
-        self.workers = []
+        pool = WorkerPool(
+                worker_class=ProcessingWorker,
+                initargs=(ctx,))
+        return pool
 
 
 def make_mount_infos(mounts, root_dir):
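
For reference, the report handling that replaces the old result-queue draining loop reduces to something along these lines. This is a simplified sketch of the loop in the `@@ -132,19 +128,19 @@` hunk above; the real code goes on to fill `record.current.timers[worker_name]` from each worker's timer data:

```python
# Each entry in `reports` is a worker's final report (its timers, here),
# or None if that worker had nothing to send back.
reports = pool.close()
record.current.timers = {}
for i, timers in enumerate(reports):
    if timers is None:
        continue
    worker_name = 'PipelineWorker_%d' % i
    record.current.timers[worker_name] = {}
    # ...merge `timers` into record.current.timers[worker_name]...
```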