comparison piecrust/processing/pipeline.py @ 414:c4b3a7fd2f87

bake: Make pipeline processing multi-process. Not many changes here, as it's pretty straightforward, but there is an API change for processors so they know whether they're being initialized/disposed from the main process or from one of the workers. This makes it possible to separate global work that has side effects (e.g. creating a directory) from purely in-memory, per-process work.
author Ludovic Chabant <ludovic@chabant.com>
date Sat, 20 Jun 2015 19:20:30 -0700
parents 413:eacf0a3afd0c
children 4a43d7015b75
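To illustrate the processor-side API change described above, here is a minimal, hypothetical sketch (not part of this changeset). It leans on what this file shows: the pipeline constructs its PipelineContext with a worker id of -1 when it runs the onPipelineStart/onPipelineEnd hooks from the main process, so a processor can restrict side-effectful setup to that case. The ExampleProcessor class and the `worker_id`/`tmp_dir` attribute names are assumptions made for illustration, not something this diff confirms.

import os


class ExampleProcessor(object):
    # Hypothetical processor, for illustration only.
    PROCESSOR_NAME = 'example'

    def onPipelineStart(self, ctx):
        # `ctx` is the PipelineContext built in run() below; a negative
        # worker id (the pipeline passes -1) is assumed to mean "main
        # process".
        if ctx.worker_id < 0:
            # Global, side-effectful setup: safe to do exactly once.
            os.makedirs(os.path.join(ctx.tmp_dir, 'example'), exist_ok=True)
        else:
            # Per-worker, in-memory setup only.
            self._conversion_cache = {}

    def onPipelineEnd(self, ctx):
        if ctx.worker_id < 0:
            # Global teardown would go here.
            pass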
import os
import os.path
import re
import time
import queue
import hashlib
import logging
import multiprocessing
from piecrust.chefutil import format_timed, format_timed_scope
from piecrust.processing.base import PipelineContext
from piecrust.processing.records import (
        ProcessorPipelineRecordEntry, TransitionalProcessorPipelineRecord,
        FLAG_PROCESSED)
from piecrust.processing.worker import (
        ProcessingWorkerContext, ProcessingWorkerJob,
        worker_func, get_filtered_processors)


logger = logging.getLogger(__name__)


class _ProcessingContext(object):
    def __init__(self, pool, record, base_dir, mount_info):
        self.pool = pool
        self.record = record
        self.base_dir = base_dir
        self.mount_info = mount_info


class ProcessorPipeline(object):
    def __init__(self, app, out_dir, force=False):
        assert app and out_dir
        self.app = app
        self.out_dir = out_dir
        self.force = force

        tmp_dir = app.sub_cache_dir
        if not tmp_dir:
            import tempfile
            tmp_dir = os.path.join(tempfile.gettempdir(), 'piecrust')
        self.tmp_dir = os.path.join(tmp_dir, 'proc')

        baker_params = app.config.get('baker') or {}

        assets_dirs = baker_params.get('assets_dirs', app.assets_dirs)
        self.mounts = make_mount_infos(assets_dirs, self.app.root_dir)

        self.num_workers = baker_params.get(
                'workers', multiprocessing.cpu_count())

        ignores = baker_params.get('ignore', [])
        ignores += [
                '_cache', '_counter',
                'theme_info.yml',
                '.DS_Store', 'Thumbs.db',
                '.git*', '.hg*', '.svn']
        self.ignore_patterns = make_re(ignores)
        self.force_patterns = make_re(baker_params.get('force', []))

        # Those things are mostly for unit-testing.
        self.enabled_processors = None
        self.additional_processors = None

    def addIgnorePatterns(self, patterns):
        self.ignore_patterns += make_re(patterns)

    def run(self, src_dir_or_file=None, *,
            delete=True, previous_record=None, save_record=True):
        start_time = time.perf_counter()

        # Get the list of processors for this run.
        processors = self.app.plugin_loader.getProcessors()
        if self.enabled_processors is not None:
            logger.debug("Filtering processors to: %s" %
                         self.enabled_processors)
            processors = get_filtered_processors(processors,
                                                 self.enabled_processors)
        if self.additional_processors is not None:
            logger.debug("Adding %s additional processors." %
                         len(self.additional_processors))
            for proc in self.additional_processors:
                self.app.env.registerTimer(proc.__class__.__name__,
                                           raise_if_registered=False)
                proc.initialize(self.app)
                processors.append(proc)

        # Invoke pre-processors.
        pipeline_ctx = PipelineContext(-1, self.app, self.out_dir,
                                       self.tmp_dir, self.force)
        for proc in processors:
            proc.onPipelineStart(pipeline_ctx)

        # Pre-processors can define additional ignore patterns.
        self.ignore_patterns += make_re(
                pipeline_ctx._additional_ignore_patterns)

        # Create the pipeline record.
        record = TransitionalProcessorPipelineRecord()
        record_cache = self.app.cache.getCache('proc')
        record_name = (
                hashlib.md5(self.out_dir.encode('utf8')).hexdigest() +
                '.record')
        if previous_record:
            record.setPrevious(previous_record)
        elif not self.force and record_cache.has(record_name):
            with format_timed_scope(logger, 'loaded previous bake record',
                                    level=logging.DEBUG, colored=False):
                record.loadPrevious(record_cache.getCachePath(record_name))
                logger.debug("Got %d entries in process record." %
                             len(record.previous.entries))
        record.current.success = True
        record.current.processed_count = 0

        # Work!
        def _handler(res):
            entry = record.getCurrentEntry(res.path)
            assert entry is not None
            entry.flags |= res.flags
            entry.proc_tree = res.proc_tree
            entry.rel_outputs = res.rel_outputs
            if res.errors:
                entry.errors += res.errors
                record.current.success = False
            if entry.flags & FLAG_PROCESSED:
                record.current.processed_count += 1

        # Create the worker pool and queue up all the jobs.
        pool = self._createWorkerPool()
        expected_result_count = self._process(src_dir_or_file, pool, record)
        self._waitOnWorkerPool(pool, expected_result_count, _handler)
        self._terminateWorkerPool(pool)

        # Get timing information from the workers.
        record.current.timers = {}
        for _ in range(len(pool.workers)):
            try:
                timers = pool.results.get(True, 0.1)
            except queue.Empty:
                logger.error(
                        "Didn't get timing information from all workers.")
                break

            for name, val in timers['data'].items():
                main_val = record.current.timers.setdefault(name, 0)
                record.current.timers[name] = main_val + val

        # Invoke post-processors.
        pipeline_ctx.record = record.current
        for proc in processors:
            proc.onPipelineEnd(pipeline_ctx)

        # Handle deletions.
        if delete:
            for path, reason in record.getDeletions():
                logger.debug("Removing '%s': %s" % (path, reason))
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass
                logger.info('[delete] %s' % path)

        # Finalize the process record.
        record.current.process_time = time.time()
        record.current.out_dir = self.out_dir
        record.collapseRecords()

        # Save the process record.
        if save_record:
            with format_timed_scope(logger, 'saved bake record',
                                    level=logging.DEBUG, colored=False):
                record.saveCurrent(record_cache.getCachePath(record_name))

        logger.info(format_timed(
                start_time,
                "processed %d assets." % record.current.processed_count))

        return record.detach()

    def _process(self, src_dir_or_file, pool, record):
        expected_result_count = 0

        if src_dir_or_file is not None:
            # Process only the given path.
            # Find out what mount point it belongs to.
            for name, info in self.mounts.items():
                path = info['path']
                if src_dir_or_file.startswith(path):
                    base_dir = path
                    mount_info = info
                    break
            else:
                known_roots = [i['path'] for i in self.mounts.values()]
                raise Exception("Input path '%s' is not part of any known "
                                "mount point: %s" %
                                (src_dir_or_file, known_roots))

            ctx = _ProcessingContext(pool, record, base_dir, mount_info)
            logger.debug("Initiating processing pipeline on: %s" %
                         src_dir_or_file)
            if os.path.isdir(src_dir_or_file):
                expected_result_count = self._processDirectory(
                        ctx, src_dir_or_file)
            elif os.path.isfile(src_dir_or_file):
                self._processFile(ctx, src_dir_or_file)
                expected_result_count = 1

        else:
            # Process everything, accumulating the job count across mounts.
            for name, info in self.mounts.items():
                path = info['path']
                ctx = _ProcessingContext(pool, record, path, info)
                logger.debug("Initiating processing pipeline on: %s" % path)
                expected_result_count += self._processDirectory(ctx, path)

        return expected_result_count

    def _processDirectory(self, ctx, start_dir):
        queued_count = 0
        for dirpath, dirnames, filenames in os.walk(start_dir):
            rel_dirpath = os.path.relpath(dirpath, start_dir)
            dirnames[:] = [d for d in dirnames
                           if not re_matchany(
                               d, self.ignore_patterns, rel_dirpath)]

            for filename in filenames:
                if re_matchany(filename, self.ignore_patterns, rel_dirpath):
                    continue
                self._processFile(ctx, os.path.join(dirpath, filename))
                queued_count += 1
        return queued_count

    def _processFile(self, ctx, path):
        # TODO: handle overrides between mount-points.

        entry = ProcessorPipelineRecordEntry(path)
        ctx.record.addEntry(entry)

        previous_entry = ctx.record.getPreviousEntry(path)
        force_this = (self.force or previous_entry is None or
                      not previous_entry.was_processed_successfully)

        job = ProcessingWorkerJob(ctx.base_dir, ctx.mount_info, path,
                                  force=force_this)

        logger.debug("Queuing: %s" % path)
        ctx.pool.queue.put_nowait(job)

    def _createWorkerPool(self):
        pool = _WorkerPool()
        for i in range(self.num_workers):
            ctx = ProcessingWorkerContext(
                    self.app.root_dir, self.out_dir, self.tmp_dir,
                    pool.queue, pool.results, pool.abort_event,
                    self.force, self.app.debug)
            ctx.enabled_processors = self.enabled_processors
            ctx.additional_processors = self.additional_processors
            w = multiprocessing.Process(
                    name='Worker_%d' % i,
                    target=worker_func, args=(i, ctx))
            w.start()
            pool.workers.append(w)
        return pool

    def _waitOnWorkerPool(self, pool, expected_result_count, result_handler):
        abort_with_exception = None
        try:
            got_count = 0
            while got_count < expected_result_count:
                try:
                    res = pool.results.get(True, 10)
                except queue.Empty:
                    logger.error(
                            "Got %d results, expected %d, and timed out "
                            "after 10 seconds. A worker might be stuck?" %
                            (got_count, expected_result_count))
                    abort_with_exception = Exception("Worker time-out.")
                    break

                if isinstance(res, dict) and res.get('type') == 'error':
                    abort_with_exception = Exception(
                            'Worker critical error:\n' +
                            '\n'.join(res['messages']))
                    break

                got_count += 1
                result_handler(res)
        except KeyboardInterrupt as kiex:
            logger.warning("Bake aborted by user... "
                           "waiting for workers to stop.")
            abort_with_exception = kiex

        if abort_with_exception:
            pool.abort_event.set()
            for w in pool.workers:
                w.join(2)
            raise abort_with_exception

    def _terminateWorkerPool(self, pool):
        pool.abort_event.set()
        for w in pool.workers:
            w.join()


class _WorkerPool(object):
    def __init__(self):
        self.queue = multiprocessing.JoinableQueue()
        self.results = multiprocessing.Queue()
        self.abort_event = multiprocessing.Event()
        self.workers = []


def make_mount_infos(mounts, root_dir):
    if isinstance(mounts, list):
        mounts = {m: {} for m in mounts}

    for name, info in mounts.items():
        if not isinstance(info, dict):
            raise Exception("Asset directory info for '%s' is not a "
                            "dictionary." % name)
        info.setdefault('processors', 'all -uglifyjs -cleancss')
        info['path'] = os.path.join(root_dir, name)

    return mounts


def make_re(patterns):
    # Patterns wrapped in slashes (e.g. `/foo\d+/`) are used as raw regexes;
    # anything else is treated as a glob-like pattern where `*` and `?` don't
    # match across path separators.
    re_patterns = []
    for pat in patterns:
        if pat[0] == '/' and pat[-1] == '/' and len(pat) > 2:
            re_patterns.append(pat[1:-1])
        else:
            escaped_pat = (
                    re.escape(pat)
                    .replace(r'\*', r'[^/\\]*')
                    .replace(r'\?', r'[^/\\]'))
            re_patterns.append(escaped_pat)
    return [re.compile(p) for p in re_patterns]


def re_matchany(filename, patterns, dirname=None):
    if dirname and dirname != '.':
        filename = os.path.join(dirname, filename)

    # Skip patterns use a forward slash regardless of the platform.
    filename = filename.replace('\\', '/')
    for pattern in patterns:
        if pattern.search(filename):
            return True
    return False
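For context, here is a sketch of how this pipeline might be driven by calling code. This is not part of the changeset: `app` stands in for a real PieCrust application object exposing the attributes used above (plugin_loader, config, cache, env, assets_dirs, sub_cache_dir, root_dir, debug), and the `success` attribute on the returned record is assumed to mirror what run() sets on record.current.

from piecrust.processing.pipeline import ProcessorPipeline


def process_assets(app, out_dir, only_path=None):
    # Hypothetical driver: build the pipeline and process either one path
    # or every mounted assets directory.
    pipeline = ProcessorPipeline(app, out_dir, force=False)
    record = pipeline.run(only_path)
    if not record.success:
        raise Exception("There were errors while processing assets.")
    return record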