Mercurial > piecrust2
comparison piecrust/processing/pipeline.py @ 447:aefe70229fdd
bake: Share the worker pool code between html and asset baking.
The `workerpool` package now defines a generic-ish worker pool. It's similar
to the Python framework pool but with a simpler use-case (only one way to
queue jobs) and support for workers to send a final "report" to the master
process, which we use to get timing information here.
The rest of the changes remove the duplicated code that is no longer
needed now that both bakers use the shared pool.
author | Ludovic Chabant <ludovic@chabant.com> |
---|---|
date | Sun, 05 Jul 2015 00:09:41 -0700 |
parents | 171dde4f61dc |
children | d90ccdf18156 |
comparison
equal
deleted
inserted
replaced
446:4cdf6c2157a0 | 447:aefe70229fdd |
---|---|
1 import os | 1 import os |
2 import os.path | 2 import os.path |
3 import re | 3 import re |
4 import time | 4 import time |
5 import queue | |
6 import hashlib | 5 import hashlib |
7 import logging | 6 import logging |
8 import multiprocessing | 7 import multiprocessing |
9 from piecrust.chefutil import format_timed, format_timed_scope | 8 from piecrust.chefutil import format_timed, format_timed_scope |
10 from piecrust.processing.base import PipelineContext | 9 from piecrust.processing.base import PipelineContext |
11 from piecrust.processing.records import ( | 10 from piecrust.processing.records import ( |
12 ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord, | 11 ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord, |
13 FLAG_PROCESSED) | 12 FLAG_PROCESSED) |
14 from piecrust.processing.worker import ( | 13 from piecrust.processing.worker import ( |
15 ProcessingWorkerContext, ProcessingWorkerJob, | 14 ProcessingWorkerJob, |
16 worker_func, get_filtered_processors) | 15 get_filtered_processors) |
17 | 16 |
18 | 17 |
19 logger = logging.getLogger(__name__) | 18 logger = logging.getLogger(__name__) |
20 | 19 |
21 | 20 |
22 class _ProcessingContext(object): | 21 class _ProcessingContext(object): |
23 def __init__(self, pool, record, base_dir, mount_info): | 22 def __init__(self, jobs, record, base_dir, mount_info): |
24 self.pool = pool | 23 self.jobs = jobs |
25 self.record = record | 24 self.record = record |
26 self.base_dir = base_dir | 25 self.base_dir = base_dir |
27 self.mount_info = mount_info | 26 self.mount_info = mount_info |
28 | 27 |
29 | 28 |
91 proc.onPipelineStart(pipeline_ctx) | 90 proc.onPipelineStart(pipeline_ctx) |
92 | 91 |
93 # Pre-processors can define additional ignore patterns. | 92 # Pre-processors can define additional ignore patterns. |
94 self.ignore_patterns += make_re( | 93 self.ignore_patterns += make_re( |
95 pipeline_ctx._additional_ignore_patterns) | 94 pipeline_ctx._additional_ignore_patterns) |
96 | |
97 # Create the worker pool. | |
98 pool = _WorkerPool() | |
99 | 95 |
100 # Create the pipeline record. | 96 # Create the pipeline record. |
101 record = TransitionalProcessorPipelineRecord() | 97 record = TransitionalProcessorPipelineRecord() |
102 record_cache = self.app.cache.getCache('proc') | 98 record_cache = self.app.cache.getCache('proc') |
103 record_name = ( | 99 record_name = ( |
130 rel_path = os.path.relpath(res.path, self.app.root_dir) | 126 rel_path = os.path.relpath(res.path, self.app.root_dir) |
131 logger.error("Errors found in %s:" % rel_path) | 127 logger.error("Errors found in %s:" % rel_path) |
132 for e in entry.errors: | 128 for e in entry.errors: |
133 logger.error(" " + e) | 129 logger.error(" " + e) |
134 | 130 |
131 jobs = [] | |
132 self._process(src_dir_or_file, record, jobs) | |
135 pool = self._createWorkerPool() | 133 pool = self._createWorkerPool() |
136 expected_result_count = self._process(src_dir_or_file, pool, record) | 134 ar = pool.queueJobs(jobs, handler=_handler) |
137 self._waitOnWorkerPool(pool, expected_result_count, _handler) | 135 ar.wait() |
138 self._terminateWorkerPool(pool) | 136 |
139 | 137 # Shutdown the workers and get timing information from them. |
140 # Get timing information from the workers. | 138 reports = pool.close() |
141 record.current.timers = {} | 139 record.current.timers = {} |
142 for i in range(len(pool.workers)): | 140 for i in range(len(reports)): |
143 try: | 141 timers = reports[i] |
144 timers = pool.results.get(True, 0.1) | 142 if timers is None: |
145 except queue.Empty: | 143 continue |
146 logger.error("Didn't get timing information from all workers.") | |
147 break | |
148 | 144 |
149 worker_name = 'PipelineWorker_%d' % i | 145 worker_name = 'PipelineWorker_%d' % i |
150 record.current.timers[worker_name] = {} | 146 record.current.timers[worker_name] = {} |
151 for name, val in timers['data'].items(): | 147 for name, val in timers['data'].items(): |
152 main_val = record.current.timers.setdefault(name, 0) | 148 main_val = record.current.timers.setdefault(name, 0) |
183 start_time, | 179 start_time, |
184 "processed %d assets." % record.current.processed_count)) | 180 "processed %d assets." % record.current.processed_count)) |
185 | 181 |
186 return record.detach() | 182 return record.detach() |
187 | 183 |
188 def _process(self, src_dir_or_file, pool, record): | 184 def _process(self, src_dir_or_file, record, jobs): |
189 expected_result_count = 0 | |
190 | |
191 if src_dir_or_file is not None: | 185 if src_dir_or_file is not None: |
192 # Process only the given path. | 186 # Process only the given path. |
193 # Find out what mount point this is in. | 187 # Find out what mount point this is in. |
194 for name, info in self.mounts.items(): | 188 for name, info in self.mounts.items(): |
195 path = info['path'] | 189 path = info['path'] |
201 known_roots = [i['path'] for i in self.mounts.values()] | 195 known_roots = [i['path'] for i in self.mounts.values()] |
202 raise Exception("Input path '%s' is not part of any known " | 196 raise Exception("Input path '%s' is not part of any known " |
203 "mount point: %s" % | 197 "mount point: %s" % |
204 (src_dir_or_file, known_roots)) | 198 (src_dir_or_file, known_roots)) |
205 | 199 |
206 ctx = _ProcessingContext(pool, record, base_dir, mount_info) | 200 ctx = _ProcessingContext(jobs, record, base_dir, mount_info) |
207 logger.debug("Initiating processing pipeline on: %s" % | 201 logger.debug("Initiating processing pipeline on: %s" % |
208 src_dir_or_file) | 202 src_dir_or_file) |
209 if os.path.isdir(src_dir_or_file): | 203 if os.path.isdir(src_dir_or_file): |
210 expected_result_count = self._processDirectory( | 204 self._processDirectory(ctx, src_dir_or_file) |
211 ctx, src_dir_or_file) | |
212 elif os.path.isfile(src_dir_or_file): | 205 elif os.path.isfile(src_dir_or_file): |
213 self._processFile(ctx, src_dir_or_file) | 206 self._processFile(ctx, src_dir_or_file) |
214 expected_result_count = 1 | |
215 | 207 |
216 else: | 208 else: |
217 # Process everything. | 209 # Process everything. |
218 for name, info in self.mounts.items(): | 210 for name, info in self.mounts.items(): |
219 path = info['path'] | 211 path = info['path'] |
220 ctx = _ProcessingContext(pool, record, path, info) | 212 ctx = _ProcessingContext(jobs, record, path, info) |
221 logger.debug("Initiating processing pipeline on: %s" % path) | 213 logger.debug("Initiating processing pipeline on: %s" % path) |
222 expected_result_count = self._processDirectory(ctx, path) | 214 self._processDirectory(ctx, path) |
223 | |
224 return expected_result_count | |
225 | 215 |
226 def _processDirectory(self, ctx, start_dir): | 216 def _processDirectory(self, ctx, start_dir): |
227 queued_count = 0 | |
228 for dirpath, dirnames, filenames in os.walk(start_dir): | 217 for dirpath, dirnames, filenames in os.walk(start_dir): |
229 rel_dirpath = os.path.relpath(dirpath, start_dir) | 218 rel_dirpath = os.path.relpath(dirpath, start_dir) |
230 dirnames[:] = [d for d in dirnames | 219 dirnames[:] = [d for d in dirnames |
231 if not re_matchany( | 220 if not re_matchany( |
232 d, self.ignore_patterns, rel_dirpath)] | 221 d, self.ignore_patterns, rel_dirpath)] |
233 | 222 |
234 for filename in filenames: | 223 for filename in filenames: |
235 if re_matchany(filename, self.ignore_patterns, rel_dirpath): | 224 if re_matchany(filename, self.ignore_patterns, rel_dirpath): |
236 continue | 225 continue |
237 self._processFile(ctx, os.path.join(dirpath, filename)) | 226 self._processFile(ctx, os.path.join(dirpath, filename)) |
238 queued_count += 1 | |
239 return queued_count | |
240 | 227 |
241 def _processFile(self, ctx, path): | 228 def _processFile(self, ctx, path): |
242 # TODO: handle overrides between mount-points. | 229 # TODO: handle overrides between mount-points. |
243 | 230 |
244 entry = ProcessorPipelineRecordEntry(path) | 231 entry = ProcessorPipelineRecordEntry(path) |
248 force_this = (self.force or previous_entry is None or | 235 force_this = (self.force or previous_entry is None or |
249 not previous_entry.was_processed_successfully) | 236 not previous_entry.was_processed_successfully) |
250 | 237 |
251 job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path, | 238 job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path, |
252 force=force_this) | 239 force=force_this) |
253 | 240 ctx.jobs.append(job) |
254 logger.debug("Queuing: %s" % path) | |
255 ctx.pool.queue.put_nowait(job) | |
256 | 241 |
257 def _createWorkerPool(self): | 242 def _createWorkerPool(self): |
258 import sys | 243 from piecrust.workerpool import WorkerPool |
259 | 244 from piecrust.processing.worker import ( |
260 main_module = sys.modules['__main__'] | 245 ProcessingWorkerContext, ProcessingWorker) |
261 is_profiling = os.path.basename(main_module.__file__) in [ | 246 |
262 'profile.py', 'cProfile.py'] | 247 ctx = ProcessingWorkerContext( |
263 | 248 self.app.root_dir, self.out_dir, self.tmp_dir, |
264 pool = _WorkerPool() | 249 self.force, self.app.debug) |
265 for i in range(self.num_workers): | 250 ctx.enabled_processors = self.enabled_processors |
266 ctx = ProcessingWorkerContext( | 251 ctx.additional_processors = self.additional_processors |
267 self.app.root_dir, self.out_dir, self.tmp_dir, | 252 |
268 pool.queue, pool.results, pool.abort_event, | 253 pool = WorkerPool( |
269 self.force, self.app.debug) | 254 worker_class=ProcessingWorker, |
270 ctx.is_profiling = is_profiling | 255 initargs=(ctx,)) |
271 ctx.enabled_processors = self.enabled_processors | |
272 ctx.additional_processors = self.additional_processors | |
273 w = multiprocessing.Process( | |
274 name='PipelineWorker_%d' % i, | |
275 target=worker_func, args=(i, ctx)) | |
276 w.start() | |
277 pool.workers.append(w) | |
278 return pool | 256 return pool |
279 | |
280 def _waitOnWorkerPool(self, pool, expected_result_count, result_handler): | |
281 abort_with_exception = None | |
282 try: | |
283 got_count = 0 | |
284 while got_count < expected_result_count: | |
285 try: | |
286 res = pool.results.get(True, 10) | |
287 except queue.Empty: | |
288 logger.error( | |
289 "Got %d results, expected %d, and timed-out " | |
290 "for 10 seconds. A worker might be stuck?" % | |
291 (got_count, expected_result_count)) | |
292 abort_with_exception = Exception("Worker time-out.") | |
293 break | |
294 | |
295 if isinstance(res, dict) and res.get('type') == 'error': | |
296 abort_with_exception = Exception( | |
297 'Worker critical error:\n' + | |
298 '\n'.join(res['messages'])) | |
299 break | |
300 | |
301 got_count += 1 | |
302 result_handler(res) | |
303 except KeyboardInterrupt as kiex: | |
304 logger.warning("Bake aborted by user... " | |
305 "waiting for workers to stop.") | |
306 abort_with_exception = kiex | |
307 | |
308 if abort_with_exception: | |
309 pool.abort_event.set() | |
310 for w in pool.workers: | |
311 w.join(2) | |
312 raise abort_with_exception | |
313 | |
314 def _terminateWorkerPool(self, pool): | |
315 pool.abort_event.set() | |
316 for w in pool.workers: | |
317 w.join() | |
318 | |
319 | |
320 class _WorkerPool(object): | |
321 def __init__(self): | |
322 self.queue = multiprocessing.JoinableQueue() | |
323 self.results = multiprocessing.Queue() | |
324 self.abort_event = multiprocessing.Event() | |
325 self.workers = [] | |
326 | 257 |
327 | 258 |
328 def make_mount_infos(mounts, root_dir): | 259 def make_mount_infos(mounts, root_dir): |
329 if isinstance(mounts, list): | 260 if isinstance(mounts, list): |
330 mounts = {m: {} for m in mounts} | 261 mounts = {m: {} for m in mounts} |