piecrust2
changeset 447:aefe70229fdd
bake: Commonize worker pool code between html and asset baking.
The `workerpool` package now defines a generic-ish worker pool. It's similar
to the pool in Python's `multiprocessing` module, but with a simpler use case
(only one way to queue jobs) and support for workers to send a final "report"
back to the master process, which we use here to gather timing information.
The rest of the changes mostly remove duplicated code that is no longer
needed.
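For reference, the contract the new pool expects is small: a worker subclasses
piecrust.workerpool.IWorker and implements initialize() and process(), plus an
optional getReport(); the master side builds a WorkerPool, queues jobs with an
optional result handler, waits on the returned AsyncResult, and finally calls
close(), which returns one report per worker. A minimal sketch, assuming only
the API introduced in this changeset (WordCountWorker and its jobs are made up
for illustration):

    from piecrust.workerpool import IWorker, WorkerPool

    class WordCountWorker(IWorker):
        # Hypothetical worker: counts words in the strings it receives.
        def __init__(self, greeting):
            self.greeting = greeting

        def initialize(self):
            # Runs once in the child process; `wid` is set by the pool
            # before this is called.
            print("%s from worker %d" % (self.greeting, self.wid))
            self.processed = 0

        def process(self, job):
            # Called once per job; the return value travels back to the
            # master process and is passed to the result handler.
            self.processed += 1
            return (job, len(job.split()))

        def getReport(self):
            # Collected by WorkerPool.close(); the bake and processing
            # workers use this hook to send their timers back.
            return {'processed': self.processed}

    def print_result(res):
        text, count = res
        print("%r -> %d words" % (text, count))

    if __name__ == '__main__':
        pool = WorkerPool(worker_class=WordCountWorker, initargs=('Hello',))
        ar = pool.queueJobs(['a b c', 'd e'], handler=print_result)
        ar.wait()
        reports = pool.close()  # list of reports, indexed by worker id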
author | Ludovic Chabant <ludovic@chabant.com> |
date | Sun, 05 Jul 2015 00:09:41 -0700 |
parents | 4cdf6c2157a0 |
children | a17774094db8 |
files | piecrust/baking/baker.py piecrust/baking/worker.py piecrust/processing/pipeline.py piecrust/processing/worker.py piecrust/workerpool.py |
diffstat | 5 files changed, 448 insertions(+), 425 deletions(-) |
--- a/piecrust/baking/baker.py Thu Jul 02 23:28:24 2015 -0700 +++ b/piecrust/baking/baker.py Sun Jul 05 00:09:41 2015 -0700 @@ -103,17 +103,13 @@ # Bake taxonomies. self._bakeTaxonomies(record, pool) - # All done with the workers. - self._terminateWorkerPool(pool) - - # Get the timing information from the workers. + # All done with the workers. Close the pool and get timing reports. + reports = pool.close() record.current.timers = {} - for i in range(len(pool.workers)): - try: - timers = pool.results.get(True, 0.1) - except queue.Empty: - logger.error("Didn't get timing information from all workers.") - break + for i in range(len(reports)): + timers = reports[i] + if timers is None: + continue worker_name = 'BakeWorker_%d' % i record.current.timers[worker_name] = {} @@ -214,41 +210,43 @@ (page_count, REALM_NAMES[realm].lower()))) def _loadRealmPages(self, record, pool, factories): + def _handler(res): + # Create the record entry for this page. + record_entry = BakeRecordEntry(res.source_name, res.path) + record_entry.config = res.config + if res.errors: + record_entry.errors += res.errors + record.current.success = False + self._logErrors(res.path, res.errors) + record.addEntry(record_entry) + logger.debug("Loading %d realm pages..." % len(factories)) with format_timed_scope(logger, "loaded %d pages" % len(factories), level=logging.DEBUG, colored=False, timer_env=self.app.env, timer_category='LoadJob'): - for fac in factories: - job = BakeWorkerJob( - JOB_LOAD, - LoadJobPayload(fac)) - pool.queue.put_nowait(job) - - def _handler(res): - # Create the record entry for this page. - record_entry = BakeRecordEntry(res.source_name, res.path) - record_entry.config = res.config - if res.errors: - record_entry.errors += res.errors - record.current.success = False - self._logErrors(res.path, res.errors) - record.addEntry(record_entry) - - self._waitOnWorkerPool( - pool, - expected_result_count=len(factories), - result_handler=_handler) + jobs = [ + BakeWorkerJob(JOB_LOAD, LoadJobPayload(fac)) + for fac in factories] + ar = pool.queueJobs(jobs, handler=_handler) + ar.wait() def _renderRealmPages(self, record, pool, factories): + def _handler(res): + entry = record.getCurrentEntry(res.path) + if res.errors: + entry.errors += res.errors + record.current.success = False + self._logErrors(res.path, res.errors) + logger.debug("Rendering %d realm pages..." 
% len(factories)) with format_timed_scope(logger, "prepared %d pages" % len(factories), level=logging.DEBUG, colored=False, timer_env=self.app.env, timer_category='RenderFirstSubJob'): - expected_result_count = 0 + jobs = [] for fac in factories: record_entry = record.getCurrentEntry(fac.path) if record_entry.errors: @@ -278,49 +276,38 @@ job = BakeWorkerJob( JOB_RENDER_FIRST, RenderFirstSubJobPayload(fac)) - pool.queue.put_nowait(job) - expected_result_count += 1 + jobs.append(job) - def _handler(res): - entry = record.getCurrentEntry(res.path) - if res.errors: - entry.errors += res.errors - record.current.success = False - self._logErrors(res.path, res.errors) - - self._waitOnWorkerPool( - pool, - expected_result_count=expected_result_count, - result_handler=_handler) + ar = pool.queueJobs(jobs, handler=_handler) + ar.wait() def _bakeRealmPages(self, record, pool, realm, factories): + def _handler(res): + entry = record.getCurrentEntry(res.path, res.taxonomy_info) + entry.subs = res.sub_entries + if res.errors: + entry.errors += res.errors + self._logErrors(res.path, res.errors) + if entry.has_any_error: + record.current.success = False + if entry.was_any_sub_baked: + record.current.baked_count[realm] += 1 + record.dirty_source_names.add(entry.source_name) + logger.debug("Baking %d realm pages..." % len(factories)) with format_timed_scope(logger, "baked %d pages" % len(factories), level=logging.DEBUG, colored=False, timer_env=self.app.env, timer_category='BakeJob'): - expected_result_count = 0 + jobs = [] for fac in factories: - if self._queueBakeJob(record, pool, fac): - expected_result_count += 1 + job = self._makeBakeJob(record, fac) + if job is not None: + jobs.append(job) - def _handler(res): - entry = record.getCurrentEntry(res.path, res.taxonomy_info) - entry.subs = res.sub_entries - if res.errors: - entry.errors += res.errors - self._logErrors(res.path, res.errors) - if entry.has_any_error: - record.current.success = False - if entry.was_any_sub_baked: - record.current.baked_count[realm] += 1 - record.dirty_source_names.add(entry.source_name) - - self._waitOnWorkerPool( - pool, - expected_result_count=expected_result_count, - result_handler=_handler) + ar = pool.queueJobs(jobs, handler=_handler) + ar.wait() def _bakeTaxonomies(self, record, pool): logger.debug("Baking taxonomy pages...") @@ -400,8 +387,16 @@ return buckets def _bakeTaxonomyBuckets(self, record, pool, buckets): + def _handler(res): + entry = record.getCurrentEntry(res.path, res.taxonomy_info) + entry.subs = res.sub_entries + if res.errors: + entry.errors += res.errors + if entry.has_any_error: + record.current.success = False + # Start baking those terms. 
- expected_result_count = 0 + jobs = [] for source_name, source_taxonomies in buckets.items(): for tax_name, tt_info in source_taxonomies.items(): terms = tt_info.dirty_terms @@ -435,21 +430,12 @@ fac.source.name, fac.path, tax_info) record.addEntry(cur_entry) - if self._queueBakeJob(record, pool, fac, tax_info): - expected_result_count += 1 + job = self._makeBakeJob(record, fac, tax_info) + if job is not None: + jobs.append(job) - def _handler(res): - entry = record.getCurrentEntry(res.path, res.taxonomy_info) - entry.subs = res.sub_entries - if res.errors: - entry.errors += res.errors - if entry.has_any_error: - record.current.success = False - - self._waitOnWorkerPool( - pool, - expected_result_count=expected_result_count, - result_handler=_handler) + ar = pool.queueJobs(jobs, handler=_handler) + ar.wait() # Now we create bake entries for all the terms that were *not* dirty. # This is because otherwise, on the next incremental bake, we wouldn't @@ -470,9 +456,9 @@ logger.debug("Taxonomy term '%s:%s' isn't used anymore." % (ti.taxonomy_name, ti.term)) - return expected_result_count + return len(jobs) - def _queueBakeJob(self, record, pool, fac, tax_info=None): + def _makeBakeJob(self, record, fac, tax_info=None): # Get the previous (if any) and current entry for this page. pair = record.getPreviousAndCurrentEntries(fac.path, tax_info) assert pair is not None @@ -483,7 +469,7 @@ if cur_entry.errors: logger.debug("Ignoring %s because it had previous " "errors." % fac.ref_spec) - return False + return None # Build the route metadata and find the appropriate route. page = fac.buildPage() @@ -515,15 +501,14 @@ (fac.ref_spec, uri, override_entry.path)) logger.error(cur_entry.errors[-1]) cur_entry.flags |= BakeRecordEntry.FLAG_OVERRIDEN - return False + return None job = BakeWorkerJob( JOB_BAKE, BakeJobPayload(fac, route_metadata, prev_entry, record.dirty_source_names, tax_info)) - pool.queue.put_nowait(job) - return True + return job def _handleDeletetions(self, record): logger.debug("Handling deletions...") @@ -544,78 +529,16 @@ logger.error(" " + e) def _createWorkerPool(self): - import sys - from piecrust.baking.worker import BakeWorkerContext, worker_func - - main_module = sys.modules['__main__'] - is_profiling = os.path.basename(main_module.__file__) in [ - 'profile.py', 'cProfile.py'] - - pool = _WorkerPool() - for i in range(self.num_workers): - ctx = BakeWorkerContext( - self.app.root_dir, self.app.cache.base_dir, self.out_dir, - pool.queue, pool.results, pool.abort_event, - force=self.force, debug=self.app.debug, - is_profiling=is_profiling) - w = multiprocessing.Process( - name='BakeWorker_%d' % i, - target=worker_func, args=(i, ctx)) - w.start() - pool.workers.append(w) - return pool - - def _terminateWorkerPool(self, pool): - pool.abort_event.set() - for w in pool.workers: - w.join() + from piecrust.workerpool import WorkerPool + from piecrust.baking.worker import BakeWorkerContext, BakeWorker - def _waitOnWorkerPool(self, pool, - expected_result_count=-1, result_handler=None): - assert result_handler is None or expected_result_count >= 0 - abort_with_exception = None - try: - if result_handler is None: - pool.queue.join() - else: - got_count = 0 - while got_count < expected_result_count: - try: - res = pool.results.get(True, 10) - except queue.Empty: - logger.error( - "Got %d results, expected %d, and timed-out " - "for 10 seconds. A worker might be stuck?" 
% - (got_count, expected_result_count)) - abort_with_exception = Exception("Worker time-out.") - break - - if isinstance(res, dict) and res.get('type') == 'error': - abort_with_exception = Exception( - 'Worker critical error:\n' + - '\n'.join(res['messages'])) - break - - got_count += 1 - result_handler(res) - except KeyboardInterrupt as kiex: - logger.warning("Bake aborted by user... " - "waiting for workers to stop.") - abort_with_exception = kiex - - if abort_with_exception: - pool.abort_event.set() - for w in pool.workers: - w.join(2) - raise abort_with_exception - - -class _WorkerPool(object): - def __init__(self): - self.queue = multiprocessing.JoinableQueue() - self.results = multiprocessing.Queue() - self.abort_event = multiprocessing.Event() - self.workers = [] + ctx = BakeWorkerContext( + self.app.root_dir, self.app.cache.base_dir, self.out_dir, + force=self.force, debug=self.app.debug) + pool = WorkerPool( + worker_class=BakeWorker, + initargs=(ctx,)) + return pool class _TaxonomyTermsInfo(object):
--- a/piecrust/baking/worker.py Thu Jul 02 23:28:24 2015 -0700 +++ b/piecrust/baking/worker.py Sun Jul 05 00:09:41 2015 -0700 @@ -1,5 +1,4 @@ import time -import queue import logging from piecrust.app import PieCrust from piecrust.baking.single import PageBaker, BakingError @@ -7,45 +6,59 @@ QualifiedPage, PageRenderingContext, render_page_segments) from piecrust.routing import create_route_metadata from piecrust.sources.base import PageFactory +from piecrust.workerpool import IWorker logger = logging.getLogger(__name__) -def worker_func(wid, ctx): - if ctx.is_profiling: - try: - import cProfile as profile - except ImportError: - import profile - - ctx.is_profiling = False - profile.runctx('_real_worker_func(wid, ctx)', - globals(), locals(), - filename='BakeWorker-%d.prof' % wid) - else: - _real_worker_func(wid, ctx) - - -def _real_worker_func(wid, ctx): - logger.debug("Worker %d booting up..." % wid) - w = BakeWorker(wid, ctx) - w.run() - - class BakeWorkerContext(object): def __init__(self, root_dir, sub_cache_dir, out_dir, - work_queue, results, abort_event, - force=False, debug=False, is_profiling=False): + force=False, debug=False): self.root_dir = root_dir self.sub_cache_dir = sub_cache_dir self.out_dir = out_dir - self.work_queue = work_queue - self.results = results - self.abort_event = abort_event self.force = force self.debug = debug - self.is_profiling = is_profiling + + +class BakeWorker(IWorker): + def __init__(self, ctx): + self.ctx = ctx + self.work_start_time = time.perf_counter() + + def initialize(self): + # Create the app local to this worker. + app = PieCrust(self.ctx.root_dir, debug=self.ctx.debug) + app._useSubCacheDir(self.ctx.sub_cache_dir) + app.env.fs_cache_only_for_main_page = True + app.env.registerTimer("BakeWorker_%d_Total" % self.wid) + app.env.registerTimer("BakeWorkerInit") + app.env.registerTimer("JobReceive") + self.app = app + + # Create the job handlers. + job_handlers = { + JOB_LOAD: LoadJobHandler(app, self.ctx), + JOB_RENDER_FIRST: RenderFirstSubJobHandler(app, self.ctx), + JOB_BAKE: BakeJobHandler(app, self.ctx)} + for jt, jh in job_handlers.items(): + app.env.registerTimer(type(jh).__name__) + self.job_handlers = job_handlers + + app.env.stepTimerSince("BakeWorkerInit", self.work_start_time) + + def process(self, job): + handler = self.job_handlers[job.job_type] + with self.app.env.timerScope(type(handler).__name__): + return handler.handleJob(job) + + def getReport(self): + self.app.env.stepTimerSince("BakeWorker_%d_Total" % self.wid, + self.work_start_time) + return { + 'type': 'timers', + 'data': self.app.env._timers} JOB_LOAD, JOB_RENDER_FIRST, JOB_BAKE = range(0, 3) @@ -57,67 +70,6 @@ self.payload = payload -class BakeWorker(object): - def __init__(self, wid, ctx): - self.wid = wid - self.ctx = ctx - - def run(self): - logger.debug("Working %d initializing..." % self.wid) - work_start_time = time.perf_counter() - - # Create the app local to this worker. - app = PieCrust(self.ctx.root_dir, debug=self.ctx.debug) - app._useSubCacheDir(self.ctx.sub_cache_dir) - app.env.fs_cache_only_for_main_page = True - app.env.registerTimer("BakeWorker_%d_Total" % self.wid) - app.env.registerTimer("BakeWorkerInit") - app.env.registerTimer("JobReceive") - - # Create the job handlers. 
- job_handlers = { - JOB_LOAD: LoadJobHandler(app, self.ctx), - JOB_RENDER_FIRST: RenderFirstSubJobHandler(app, self.ctx), - JOB_BAKE: BakeJobHandler(app, self.ctx)} - for jt, jh in job_handlers.items(): - app.env.registerTimer(type(jh).__name__) - - app.env.stepTimerSince("BakeWorkerInit", work_start_time) - - # Start working! - aborted_with_exception = None - while not self.ctx.abort_event.is_set(): - try: - with app.env.timerScope('JobReceive'): - job = self.ctx.work_queue.get(True, 0.01) - except queue.Empty: - continue - - try: - handler = job_handlers[job.job_type] - with app.env.timerScope(type(handler).__name__): - handler.handleJob(job) - except Exception as ex: - self.ctx.abort_event.set() - aborted_with_exception = ex - logger.debug("[%d] Critical error, aborting." % self.wid) - if self.ctx.debug: - logger.exception(ex) - break - finally: - self.ctx.work_queue.task_done() - - if aborted_with_exception is not None: - msgs = _get_errors(aborted_with_exception) - self.ctx.results.put_nowait({'type': 'error', 'messages': msgs}) - - # Send our timers to the main process before exiting. - app.env.stepTimerSince("BakeWorker_%d_Total" % self.wid, - work_start_time) - self.ctx.results.put_nowait({ - 'type': 'timers', 'data': app.env._timers}) - - class JobHandler(object): def __init__(self, app, ctx): self.app = app @@ -203,8 +155,7 @@ result.errors = _get_errors(ex) if self.ctx.debug: logger.exception(ex) - - self.ctx.results.put_nowait(result) + return result class RenderFirstSubJobHandler(JobHandler): @@ -231,8 +182,7 @@ result.errors = _get_errors(ex) if self.ctx.debug: logger.exception(ex) - - self.ctx.results.put_nowait(result) + return result class BakeJobHandler(JobHandler): @@ -272,5 +222,5 @@ if self.ctx.debug: logger.exception(ex) - self.ctx.results.put_nowait(result) + return result
--- a/piecrust/processing/pipeline.py Thu Jul 02 23:28:24 2015 -0700 +++ b/piecrust/processing/pipeline.py Sun Jul 05 00:09:41 2015 -0700 @@ -2,7 +2,6 @@ import os.path import re import time -import queue import hashlib import logging import multiprocessing @@ -12,16 +11,16 @@ ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord, FLAG_PROCESSED) from piecrust.processing.worker import ( - ProcessingWorkerContext, ProcessingWorkerJob, - worker_func, get_filtered_processors) + ProcessingWorkerJob, + get_filtered_processors) logger = logging.getLogger(__name__) class _ProcessingContext(object): - def __init__(self, pool, record, base_dir, mount_info): - self.pool = pool + def __init__(self, jobs, record, base_dir, mount_info): + self.jobs = jobs self.record = record self.base_dir = base_dir self.mount_info = mount_info @@ -94,9 +93,6 @@ self.ignore_patterns += make_re( pipeline_ctx._additional_ignore_patterns) - # Create the worker pool. - pool = _WorkerPool() - # Create the pipeline record. record = TransitionalProcessorPipelineRecord() record_cache = self.app.cache.getCache('proc') @@ -132,19 +128,19 @@ for e in entry.errors: logger.error(" " + e) + jobs = [] + self._process(src_dir_or_file, record, jobs) pool = self._createWorkerPool() - expected_result_count = self._process(src_dir_or_file, pool, record) - self._waitOnWorkerPool(pool, expected_result_count, _handler) - self._terminateWorkerPool(pool) + ar = pool.queueJobs(jobs, handler=_handler) + ar.wait() - # Get timing information from the workers. + # Shutdown the workers and get timing information from them. + reports = pool.close() record.current.timers = {} - for i in range(len(pool.workers)): - try: - timers = pool.results.get(True, 0.1) - except queue.Empty: - logger.error("Didn't get timing information from all workers.") - break + for i in range(len(reports)): + timers = reports[i] + if timers is None: + continue worker_name = 'PipelineWorker_%d' % i record.current.timers[worker_name] = {} @@ -185,9 +181,7 @@ return record.detach() - def _process(self, src_dir_or_file, pool, record): - expected_result_count = 0 - + def _process(self, src_dir_or_file, record, jobs): if src_dir_or_file is not None: # Process only the given path. # Find out what mount point this is in. @@ -203,28 +197,23 @@ "mount point: %s" % (src_dir_or_file, known_roots)) - ctx = _ProcessingContext(pool, record, base_dir, mount_info) + ctx = _ProcessingContext(jobs, record, base_dir, mount_info) logger.debug("Initiating processing pipeline on: %s" % src_dir_or_file) if os.path.isdir(src_dir_or_file): - expected_result_count = self._processDirectory( - ctx, src_dir_or_file) + self._processDirectory(ctx, src_dir_or_file) elif os.path.isfile(src_dir_or_file): self._processFile(ctx, src_dir_or_file) - expected_result_count = 1 else: # Process everything. 
for name, info in self.mounts.items(): path = info['path'] - ctx = _ProcessingContext(pool, record, path, info) + ctx = _ProcessingContext(jobs, record, path, info) logger.debug("Initiating processing pipeline on: %s" % path) - expected_result_count = self._processDirectory(ctx, path) - - return expected_result_count + self._processDirectory(ctx, path) def _processDirectory(self, ctx, start_dir): - queued_count = 0 for dirpath, dirnames, filenames in os.walk(start_dir): rel_dirpath = os.path.relpath(dirpath, start_dir) dirnames[:] = [d for d in dirnames @@ -235,8 +224,6 @@ if re_matchany(filename, self.ignore_patterns, rel_dirpath): continue self._processFile(ctx, os.path.join(dirpath, filename)) - queued_count += 1 - return queued_count def _processFile(self, ctx, path): # TODO: handle overrides between mount-points. @@ -250,79 +237,23 @@ job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path, force=force_this) - - logger.debug("Queuing: %s" % path) - ctx.pool.queue.put_nowait(job) + ctx.jobs.append(job) def _createWorkerPool(self): - import sys - - main_module = sys.modules['__main__'] - is_profiling = os.path.basename(main_module.__file__) in [ - 'profile.py', 'cProfile.py'] - - pool = _WorkerPool() - for i in range(self.num_workers): - ctx = ProcessingWorkerContext( - self.app.root_dir, self.out_dir, self.tmp_dir, - pool.queue, pool.results, pool.abort_event, - self.force, self.app.debug) - ctx.is_profiling = is_profiling - ctx.enabled_processors = self.enabled_processors - ctx.additional_processors = self.additional_processors - w = multiprocessing.Process( - name='PipelineWorker_%d' % i, - target=worker_func, args=(i, ctx)) - w.start() - pool.workers.append(w) - return pool + from piecrust.workerpool import WorkerPool + from piecrust.processing.worker import ( + ProcessingWorkerContext, ProcessingWorker) - def _waitOnWorkerPool(self, pool, expected_result_count, result_handler): - abort_with_exception = None - try: - got_count = 0 - while got_count < expected_result_count: - try: - res = pool.results.get(True, 10) - except queue.Empty: - logger.error( - "Got %d results, expected %d, and timed-out " - "for 10 seconds. A worker might be stuck?" % - (got_count, expected_result_count)) - abort_with_exception = Exception("Worker time-out.") - break - - if isinstance(res, dict) and res.get('type') == 'error': - abort_with_exception = Exception( - 'Worker critical error:\n' + - '\n'.join(res['messages'])) - break + ctx = ProcessingWorkerContext( + self.app.root_dir, self.out_dir, self.tmp_dir, + self.force, self.app.debug) + ctx.enabled_processors = self.enabled_processors + ctx.additional_processors = self.additional_processors - got_count += 1 - result_handler(res) - except KeyboardInterrupt as kiex: - logger.warning("Bake aborted by user... " - "waiting for workers to stop.") - abort_with_exception = kiex - - if abort_with_exception: - pool.abort_event.set() - for w in pool.workers: - w.join(2) - raise abort_with_exception - - def _terminateWorkerPool(self, pool): - pool.abort_event.set() - for w in pool.workers: - w.join() - - -class _WorkerPool(object): - def __init__(self): - self.queue = multiprocessing.JoinableQueue() - self.results = multiprocessing.Queue() - self.abort_event = multiprocessing.Event() - self.workers = [] + pool = WorkerPool( + worker_class=ProcessingWorker, + initargs=(ctx,)) + return pool def make_mount_infos(mounts, root_dir):
--- a/piecrust/processing/worker.py Thu Jul 02 23:28:24 2015 -0700 +++ b/piecrust/processing/worker.py Sun Jul 05 00:09:41 2015 -0700 @@ -1,7 +1,6 @@ +import re import os.path -import re import time -import queue import logging from piecrust.app import PieCrust from piecrust.processing.base import PipelineContext @@ -13,6 +12,7 @@ ProcessingTreeError, ProcessorError, get_node_name_tree, print_node, STATE_DIRTY) +from piecrust.workerpool import IWorker logger = logging.getLogger(__name__) @@ -22,37 +22,12 @@ re_ansicolors = re.compile('\033\\[\d+m') -def worker_func(wid, ctx): - if ctx.is_profiling: - try: - import cProfile as profile - except ImportError: - import profile - - ctx.is_profiling = False - profile.runctx('_real_worker_func(wid, ctx)', - globals(), locals(), - filename='PipelineWorker-%d.prof' % wid) - else: - _real_worker_func(wid, ctx) - - -def _real_worker_func(wid, ctx): - logger.debug("Worker %d booting up..." % wid) - w = ProcessingWorker(wid, ctx) - w.run() - - class ProcessingWorkerContext(object): def __init__(self, root_dir, out_dir, tmp_dir, - work_queue, results, abort_event, force=False, debug=False): self.root_dir = root_dir self.out_dir = out_dir self.tmp_dir = tmp_dir - self.work_queue = work_queue - self.results = results - self.abort_event = abort_event self.force = force self.debug = debug self.is_profiling = False @@ -77,23 +52,20 @@ self.errors = None -class ProcessingWorker(object): - def __init__(self, wid, ctx): - self.wid = wid +class ProcessingWorker(IWorker): + def __init__(self, ctx): self.ctx = ctx + self.work_start_time = time.perf_counter() - def run(self): - logger.debug("Worker %d initializing..." % self.wid) - work_start_time = time.perf_counter() - + def initialize(self): # Create the app local to this worker. app = PieCrust(self.ctx.root_dir, debug=self.ctx.debug) - app.env.fs_cache_only_for_main_page = True app.env.registerTimer("PipelineWorker_%d_Total" % self.wid) app.env.registerTimer("PipelineWorkerInit") app.env.registerTimer("JobReceive") app.env.registerTimer('BuildProcessingTree') app.env.registerTimer('RunProcessingTree') + self.app = app processors = app.plugin_loader.getProcessors() if self.ctx.enabled_processors: @@ -108,9 +80,10 @@ app.env.registerTimer(proc.__class__.__name__) proc.initialize(app) processors.append(proc) + self.processors = processors # Invoke pre-processors. - pipeline_ctx = PipelineContext(self.wid, app, self.ctx.out_dir, + pipeline_ctx = PipelineContext(self.wid, self.app, self.ctx.out_dir, self.ctx.tmp_dir, self.ctx.force) for proc in processors: proc.onPipelineStart(pipeline_ctx) @@ -119,52 +92,18 @@ # patching the processors with some new ones. processors.sort(key=lambda p: p.priority) - app.env.stepTimerSince("PipelineWorkerInit", work_start_time) - - aborted_with_exception = None - while not self.ctx.abort_event.is_set(): - try: - with app.env.timerScope('JobReceive'): - job = self.ctx.work_queue.get(True, 0.01) - except queue.Empty: - continue + app.env.stepTimerSince("PipelineWorkerInit", self.work_start_time) - try: - result = self._unsafeRun(app, processors, job) - self.ctx.results.put_nowait(result) - except Exception as ex: - self.ctx.abort_event.set() - aborted_with_exception = ex - logger.error("[%d] Critical error, aborting." 
% self.wid) - if self.ctx.debug: - logger.exception(ex) - break - finally: - self.ctx.work_queue.task_done() - - if aborted_with_exception is not None: - msgs = _get_errors(aborted_with_exception) - self.ctx.results.put_nowait({'type': 'error', 'messages': msgs}) - - # Invoke post-processors. - for proc in processors: - proc.onPipelineEnd(pipeline_ctx) - - app.env.stepTimerSince("PipelineWorker_%d_Total" % self.wid, - work_start_time) - self.ctx.results.put_nowait({ - 'type': 'timers', 'data': app.env._timers}) - - def _unsafeRun(self, app, processors, job): + def process(self, job): result = ProcessingWorkerResult(job.path) processors = get_filtered_processors( - processors, job.mount_info['processors']) + self.processors, job.mount_info['processors']) # Build the processing tree for this job. rel_path = os.path.relpath(job.path, job.base_dir) try: - with app.env.timerScope('BuildProcessingTree'): + with self.app.env.timerScope('BuildProcessingTree'): builder = ProcessingTreeBuilder(processors) tree_root = builder.build(rel_path) result.flags |= FLAG_PREPARED @@ -184,7 +123,7 @@ tree_root.setState(STATE_DIRTY, True) try: - with app.env.timerScope('RunProcessingTree'): + with self.app.env.timerScope('RunProcessingTree'): runner = ProcessingTreeRunner( job.base_dir, self.ctx.tmp_dir, self.ctx.out_dir) if runner.processSubTree(tree_root): @@ -197,6 +136,19 @@ return result + def getReport(self): + # Invoke post-processors. + pipeline_ctx = PipelineContext(self.wid, self.app, self.ctx.out_dir, + self.ctx.tmp_dir, self.ctx.force) + for proc in self.processors: + proc.onPipelineEnd(pipeline_ctx) + + self.app.env.stepTimerSince("PipelineWorker_%d_Total" % self.wid, + self.work_start_time) + return { + 'type': 'timers', + 'data': self.app.env._timers} + def get_filtered_processors(processors, authorized_names): if not authorized_names or authorized_names == 'all':
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/piecrust/workerpool.py Sun Jul 05 00:09:41 2015 -0700 @@ -0,0 +1,267 @@ +import os +import sys +import logging +import threading +import multiprocessing + + +logger = logging.getLogger(__name__) + + +class IWorker(object): + def initialize(self): + raise NotImplementedError() + + def process(self, job): + raise NotImplementedError() + + def getReport(self): + return None + + +TASK_JOB = 0 +TASK_END = 1 + + +def worker_func(params): + if params.is_profiling: + try: + import cProfile as profile + except ImportError: + import profile + + params.is_profiling = False + name = params.worker_class.__name__ + profile.runctx('_real_worker_func(params)', + globals(), locals(), + filename='%s-%d.prof' % (name, params.wid)) + else: + _real_worker_func(params) + + +def _real_worker_func(params): + if hasattr(params.inqueue, '_writer'): + params.inqueue._writer.close() + params.outqueue._reader.close() + + wid = params.wid + logger.debug("Worker %d initializing..." % wid) + + w = params.worker_class(*params.initargs) + w.wid = wid + w.initialize() + + get = params.inqueue.get + put = params.outqueue.put + + completed = 0 + while True: + try: + task = get() + except (EOFError, OSError): + logger.debug("Worker %d encountered connection problem." % wid) + break + + task_type, task_data = task + if task_type == TASK_END: + logger.debug("Worker %d got end task, exiting." % wid) + try: + rep = (task_type, True, wid, (wid, w.getReport())) + except Exception as e: + if params.wrap_exception: + e = multiprocessing.ExceptionWithTraceback( + e, e.__traceback__) + rep = (task_type, False, wid, (wid, e)) + put(rep) + break + + try: + res = (task_type, True, wid, w.process(task_data)) + except Exception as e: + if params.wrap_exception: + e = multiprocessing.ExceptionWithTraceback(e, e.__traceback__) + res = (task_type, False, wid, e) + put(res) + + completed += 1 + + logger.debug("Worker %d completed %d tasks." 
% (wid, completed)) + + +class _WorkerParams(object): + def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(), + wrap_exception=False, is_profiling=False): + self.wid = wid + self.inqueue = inqueue + self.outqueue = outqueue + self.worker_class = worker_class + self.initargs = initargs + self.wrap_exception = wrap_exception + self.is_profiling = is_profiling + + +class WorkerPool(object): + def __init__(self, worker_class, worker_count=None, initargs=()): + worker_count = worker_count or os.cpu_count() or 1 + + self._task_queue = multiprocessing.SimpleQueue() + self._result_queue = multiprocessing.SimpleQueue() + self._quick_put = self._task_queue._writer.send + self._quick_get = self._result_queue._reader.recv + + self._callback = None + self._error_callback = None + self._listener = None + + main_module = sys.modules['__main__'] + is_profiling = os.path.basename(main_module.__file__) in [ + 'profile.py', 'cProfile.py'] + + self._pool = [] + for i in range(worker_count): + worker_params = _WorkerParams( + i, self._task_queue, self._result_queue, + worker_class, initargs, + is_profiling=is_profiling) + w = multiprocessing.Process(target=worker_func, + args=(worker_params,)) + w.name = w.name.replace('Process', 'PoolWorker') + w.daemon = True + w.start() + self._pool.append(w) + + self._result_handler = threading.Thread( + target=WorkerPool._handleResults, + args=(self,)) + self._result_handler.daemon = True + self._result_handler.start() + + self._closed = False + + def setHandler(self, callback=None, error_callback=None): + self._callback = callback + self._error_callback = error_callback + + def queueJobs(self, jobs, handler=None): + if self._closed: + raise Exception("This worker pool has been closed.") + if self._listener is not None: + raise Exception("A previous job queue has not finished yet.") + + if handler is not None: + self.setHandler(handler) + + if not hasattr(jobs, '__len__'): + jobs = list(jobs) + + res = AsyncResult(self, len(jobs)) + if res._count == 0: + res._event.set() + return res + + self._listener = res + for job in jobs: + self._quick_put((TASK_JOB, job)) + + return res + + def close(self): + if self._listener is not None: + raise Exception("A previous job queue has not finished yet.") + + logger.debug("Closing worker pool...") + handler = _ReportHandler(len(self._pool)) + self._callback = handler._handle + for w in self._pool: + self._quick_put((TASK_END, None)) + for w in self._pool: + w.join() + + logger.debug("Waiting for reports...") + if not handler.wait(2): + missing = handler.reports.index(None) + logger.warning( + "Didn't receive all worker reports before timeout. " + "Missing report from worker %d." 
% missing) + + logger.debug("Exiting result handler thread...") + self._result_queue.put(None) + self._result_handler.join() + self._closed = True + + return handler.reports + + @staticmethod + def _handleResults(pool): + while True: + try: + res = pool._quick_get() + except (EOFError, OSError): + logger.debug("Result handler thread encountered connection " + "problem, exiting.") + return + + if res is None: + logger.debug("Result handler exiting.") + break + + task_type, success, wid, data = res + try: + if success and pool._callback: + pool._callback(data) + elif not success and pool._error_callback: + pool._error_callback(data) + except Exception as ex: + logger.exception(ex) + + if task_type == TASK_JOB: + pool._listener._onTaskDone() + + +class AsyncResult(object): + def __init__(self, pool, count): + self._pool = pool + self._count = count + self._event = threading.Event() + + def ready(self): + return self._event.is_set() + + def wait(self, timeout=None): + return self._event.wait(timeout) + + def _onTaskDone(self): + self._count -= 1 + if self._count == 0: + self._pool.setHandler(None) + self._pool._listener = None + self._event.set() + + +class _ReportHandler(object): + def __init__(self, worker_count): + self.reports = [None] * worker_count + self._count = worker_count + self._received = 0 + self._event = threading.Event() + + def wait(self, timeout=None): + return self._event.wait(timeout) + + def _handle(self, res): + wid, data = res + if wid < 0 or wid > self._count: + logger.error("Ignoring report from unknown worker %d." % wid) + return + + self._received += 1 + self.reports[wid] = data + + if self._received == self._count: + self._event.set() + + def _handleError(self, res): + wid, data = res + logger.error("Worker %d failed to send its report." % wid) + logger.exception(data) +
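One behavior of the new pool worth calling out when driving it directly: an
exception raised in a worker's process() does not abort the run. The worker
sends the exception back flagged as a failure, the result-handler thread routes
it to the error callback (if one was set via setHandler()), and the job still
counts towards the AsyncResult, so wait() returns normally. Callers that want
the whole operation to fail have to track errors themselves, which the bake and
processing job handlers do by catching exceptions and returning them inside the
result object. A minimal sketch of that wiring, assuming the API from this
changeset (FailingWorker and the callback lists are made up for illustration):

    from piecrust.workerpool import IWorker, WorkerPool

    class FailingWorker(IWorker):
        # Made-up worker: fails on odd numbers, doubles even ones.
        def initialize(self):
            pass

        def process(self, job):
            if job % 2:
                raise ValueError("odd job: %d" % job)
            return job * 2

    results = []
    errors = []

    if __name__ == '__main__':
        pool = WorkerPool(worker_class=FailingWorker, worker_count=2)
        # Set both callbacks before queuing; passing `handler=` to
        # queueJobs() would reset the error callback to None.
        pool.setHandler(callback=results.append, error_callback=errors.append)
        ar = pool.queueJobs(list(range(6)))
        ar.wait()   # returns even though some jobs failed
        pool.close()
        print("%d results, %d errors" % (len(results), len(errors)))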