comparison piecrust/workerpool.py @ 971:5485a11591ec
internal: Worker pool improvements
- Slightly more correct multi-processing code.
- Handle critical worker failures.
- Correctly setup workers when run from `py.test`.
- Add support for using `msgpack` or `marshall` instead of `pickle`.
author | Ludovic Chabant <ludovic@chabant.com> |
---|---|
date | Tue, 17 Oct 2017 01:00:55 -0700 |
parents | f2b75e4be981 |
children | 45ad976712ec |
970:660250c95246 | 971:5485a11591ec |
---|---|
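
For orientation before the line-by-line comparison: a minimal usage sketch of the pool as it looks after this changeset. It assumes `piecrust` is importable, that `worker_count` is a keyword argument of `WorkerPool` (it is used in the constructor body shown below), and that a worker only needs the `initialize`/`process`/`getStats`/`shutdown` methods the worker loop calls; `SquareWorker` and its integer jobs are invented for illustration.

```python
# Hypothetical usage sketch; SquareWorker and its jobs are illustrative only.
from piecrust.environment import ExecutionStats
from piecrust.workerpool import IWorker, WorkerPool


class SquareWorker(IWorker):
    def initialize(self):
        pass

    def process(self, job):
        # `job` is whatever queueJobs() was given; the return value travels
        # back to the main process and is handed to the pool's callback.
        return job * job

    def getStats(self):
        return ExecutionStats()

    def shutdown(self):
        pass


def on_result(job, result, userdata):
    print("%d -> %d" % (job, result))


def on_error(job, exc_data, userdata):
    # exc_data describes the failure that happened in the worker.
    print("job %r failed:\n%s" % (job, exc_data))


if __name__ == '__main__':   # required under the `spawn` start method
    pool = WorkerPool(SquareWorker, worker_count=2,
                      callback=on_result, error_callback=on_error)
    pool.queueJobs(range(10))
    pool.wait()              # now re-raises if every worker has died
    reports = pool.close()   # one ExecutionStats report per worker
```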
1 import io | 1 import io |
2 import os | 2 import os |
3 import sys | 3 import sys |
4 import time | 4 import time |
5 import pickle | |
6 import logging | 5 import logging |
7 import threading | 6 import threading |
8 import traceback | 7 import traceback |
9 import multiprocessing | 8 import multiprocessing |
10 from piecrust import fastpickle | |
11 from piecrust.environment import ExecutionStats | 9 from piecrust.environment import ExecutionStats |
12 | 10 |
13 | 11 |
14 logger = logging.getLogger(__name__) | 12 logger = logging.getLogger(__name__) |
15 | 13 |
16 use_fastqueue = False | 14 use_fastqueue = False |
15 | |
17 use_fastpickle = False | 16 use_fastpickle = False |
17 use_msgpack = False | |
18 use_marshall = False | |
18 | 19 |
19 | 20 |
20 class IWorker(object): | 21 class IWorker(object): |
21 """ Interface for a pool worker. | 22 """ Interface for a pool worker. |
22 """ | 23 """ |
55 return errors | 56 return errors |
56 | 57 |
57 | 58 |
58 TASK_JOB = 0 | 59 TASK_JOB = 0 |
59 TASK_END = 1 | 60 TASK_END = 1 |
61 _TASK_ABORT_WORKER = 10 | |
62 _CRITICAL_WORKER_ERROR = 11 | |
60 | 63 |
61 | 64 |
62 def worker_func(params): | 65 def worker_func(params): |
63 if params.is_profiling: | 66 if params.is_profiling: |
64 try: | 67 try: |
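
The two new private constants extend a small tuple protocol: tasks go to workers as `(task_type, task_data)` pairs, and everything coming back on the result queue is a `(task_type, task_data, success, wid, data)` tuple, with `_CRITICAL_WORKER_ERROR` reserved for the crash report sent by the wrapper introduced below. A schematic sketch (the helper functions and the sample payload are illustrative, not part of the module):

```python
# Schematic only: the tuple shapes exchanged over the two queues.
TASK_JOB = 0
TASK_END = 1
_TASK_ABORT_WORKER = 10
_CRITICAL_WORKER_ERROR = 11


def make_task(job):
    # What queueJobs() puts on the task queue.
    return (TASK_JOB, job)


def make_result(task_type, task_data, wid, value):
    # A successful result, as put() by the worker loop.
    return (task_type, task_data, True, wid, value)


def make_crash_report(wid, message):
    # What _real_worker_func() sends when the whole worker blows up.
    return (_CRITICAL_WORKER_ERROR, None, False, wid, message)


task_type, payload = make_task({'path': 'pages/foo.md'})
_, _, success, wid, data = make_result(task_type, payload, wid=0, value=42)
print(success, wid, data)   # True 0 42
```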
74 else: | 77 else: |
75 _real_worker_func(params) | 78 _real_worker_func(params) |
76 | 79 |
77 | 80 |
78 def _real_worker_func(params): | 81 def _real_worker_func(params): |
82 try: | |
83 _real_worker_func_unsafe(params) | |
84 except Exception as ex: | |
85 logger.exception(ex) | |
86 msg = ("CRITICAL ERROR IN WORKER %d\n%s" % (params.wid, str(ex))) | |
87 params.outqueue.put(( | |
88 _CRITICAL_WORKER_ERROR, None, False, params.wid, msg)) | |
89 | |
90 | |
91 def _real_worker_func_unsafe(params): | |
79 wid = params.wid | 92 wid = params.wid |
80 | 93 |
81 stats = ExecutionStats() | 94 stats = ExecutionStats() |
82 stats.registerTimer('WorkerInit') | 95 stats.registerTimer('WorkerInit') |
83 init_start_time = time.perf_counter() | 96 init_start_time = time.perf_counter() |
84 | 97 |
98 # If we are unit-testing, we didn't setup all the logging environment | |
99 # yet, since the executable is `py.test`. We need to translate our | |
100 # test logging arguments into something Chef can understand. | |
101 if params.is_unit_testing: | |
102 import argparse | |
103 parser = argparse.ArgumentParser() | |
104 # This is adapted from our `conftest.py`. | |
105 parser.add_argument('--log-debug', action='store_true') | |
106 parser.add_argument('--log-file') | |
107 res, _ = parser.parse_known_args(sys.argv[1:]) | |
108 | |
109 chef_args = [] | |
110 if res.log_debug: | |
111 chef_args.append('--debug') | |
112 if res.log_file: | |
113 chef_args += ['--log', res.log_file] | |
114 | |
115 from piecrust.main import _pre_parse_chef_args | |
116 _pre_parse_chef_args(chef_args) | |
117 | |
85 # In a context where `multiprocessing` is using the `spawn` forking model, | 118 # In a context where `multiprocessing` is using the `spawn` forking model, |
86 # the new process doesn't inherit anything, so we lost all our logging | 119 # the new process doesn't inherit anything, so we lost all our logging |
87 # configuration here. Let's set it up again. | 120 # configuration here. Let's set it up again. |
88 if (hasattr(multiprocessing, 'get_start_method') and | 121 elif (hasattr(multiprocessing, 'get_start_method') and |
89 multiprocessing.get_start_method() == 'spawn'): | 122 multiprocessing.get_start_method() == 'spawn'): |
90 from piecrust.main import _pre_parse_chef_args | 123 from piecrust.main import _pre_parse_chef_args |
91 _pre_parse_chef_args(sys.argv[1:]) | 124 _pre_parse_chef_args(sys.argv[1:]) |
92 | 125 |
93 from piecrust.main import ColoredFormatter | 126 from piecrust.main import ColoredFormatter |
105 w = params.worker_class(*params.initargs) | 138 w = params.worker_class(*params.initargs) |
106 w.wid = wid | 139 w.wid = wid |
107 try: | 140 try: |
108 w.initialize() | 141 w.initialize() |
109 except Exception as ex: | 142 except Exception as ex: |
110 logger.error("Working failed to initialize:") | 143 logger.error("Worker %d failed to initialize." % wid) |
111 logger.exception(ex) | 144 logger.exception(ex) |
112 params.outqueue.put(None) | 145 raise |
113 return | |
114 | 146 |
115 stats.stepTimerSince('WorkerInit', init_start_time) | 147 stats.stepTimerSince('WorkerInit', init_start_time) |
116 | 148 |
117 # Start pumping! | 149 # Start pumping! |
118 completed = 0 | 150 completed = 0 |
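
The `parse_known_args` call above is what lets a worker spawned by `py.test` pick out just the logging options without choking on the test runner's own flags: unknown arguments are returned instead of raising. A standalone illustration of the same translation, with an invented argv:

```python
import argparse

# Mirrors the translation done in _real_worker_func_unsafe() when running
# under py.test; the argv below is invented for illustration.
parser = argparse.ArgumentParser()
parser.add_argument('--log-debug', action='store_true')
parser.add_argument('--log-file')

fake_argv = ['-x', 'tests/test_baking.py',
             '--log-debug', '--log-file', '/tmp/chef.log']
res, unknown = parser.parse_known_args(fake_argv)

chef_args = []
if res.log_debug:
    chef_args.append('--debug')
if res.log_file:
    chef_args += ['--log', res.log_file]

print(chef_args)   # ['--debug', '--log', '/tmp/chef.log']
print(unknown)     # py.test's own options are left alone: ['-x', 'tests/test_baking.py']
```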
126 task = get() | 158 task = get() |
127 time_in_get += (time.perf_counter() - get_start_time) | 159 time_in_get += (time.perf_counter() - get_start_time) |
128 | 160 |
129 task_type, task_data = task | 161 task_type, task_data = task |
130 | 162 |
163 # Job task... just do it. | |
164 if task_type == TASK_JOB: | |
165 try: | |
166 res = (task_type, task_data, True, wid, w.process(task_data)) | |
167 except Exception as e: | |
168 logger.debug( | |
169 "Error processing job, sending exception to main process:") | |
170 logger.debug(traceback.format_exc()) | |
171 we = WorkerExceptionData(wid) | |
172 res = (task_type, task_data, False, wid, we) | |
173 | |
174 put_start_time = time.perf_counter() | |
175 put(res) | |
176 time_in_put += (time.perf_counter() - put_start_time) | |
177 | |
178 completed += 1 | |
179 | |
131 # End task... gather stats to send back to the main process. | 180 # End task... gather stats to send back to the main process. |
132 if task_type == TASK_END: | 181 elif task_type == TASK_END: |
133 logger.debug("Worker %d got end task, exiting." % wid) | 182 logger.debug("Worker %d got end task, exiting." % wid) |
134 stats.registerTimer('WorkerTaskGet', time=time_in_get) | 183 stats.registerTimer('WorkerTaskGet', time=time_in_get) |
135 stats.registerTimer('WorkerResultPut', time=time_in_put) | 184 stats.registerTimer('WorkerResultPut', time=time_in_put) |
136 try: | 185 try: |
137 stats.mergeStats(w.getStats()) | 186 stats.mergeStats(w.getStats()) |
143 we = WorkerExceptionData(wid) | 192 we = WorkerExceptionData(wid) |
144 rep = (task_type, task_data, False, wid, (wid, we)) | 193 rep = (task_type, task_data, False, wid, (wid, we)) |
145 put(rep) | 194 put(rep) |
146 break | 195 break |
147 | 196 |
148 # Job task... just do it. | 197 # Emergency abort. |
149 elif task_type == TASK_JOB: | 198 elif task_type == _TASK_ABORT_WORKER: |
150 try: | 199 logger.debug("Worker %d got abort signal." % wid) |
151 res = (task_type, task_data, True, wid, w.process(task_data)) | 200 break |
152 except Exception as e: | |
153 logger.debug( | |
154 "Error processing job, sending exception to main process:") | |
155 logger.debug(traceback.format_exc()) | |
156 we = WorkerExceptionData(wid) | |
157 res = (task_type, task_data, False, wid, we) | |
158 | |
159 put_start_time = time.perf_counter() | |
160 put(res) | |
161 time_in_put += (time.perf_counter() - put_start_time) | |
162 | |
163 completed += 1 | |
164 | 201 |
165 else: | 202 else: |
166 raise Exception("Unknown task type: %s" % task_type) | 203 raise Exception("Unknown task type: %s" % task_type) |
167 | 204 |
168 w.shutdown() | 205 try: |
206 w.shutdown() | |
207 except Exception as e: | |
208 logger.error("Worker %s failed to shutdown.") | |
209 logger.exception(e) | |
210 raise | |
211 | |
169 logger.debug("Worker %d completed %d tasks." % (wid, completed)) | 212 logger.debug("Worker %d completed %d tasks." % (wid, completed)) |
170 | 213 |
171 | 214 |
172 class _WorkerParams: | 215 class _WorkerParams: |
173 def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(), | 216 def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(), |
174 is_profiling=False): | 217 is_profiling=False, is_unit_testing=False): |
175 self.wid = wid | 218 self.wid = wid |
176 self.inqueue = inqueue | 219 self.inqueue = inqueue |
177 self.outqueue = outqueue | 220 self.outqueue = outqueue |
178 self.worker_class = worker_class | 221 self.worker_class = worker_class |
179 self.initargs = initargs | 222 self.initargs = initargs |
180 self.is_profiling = is_profiling | 223 self.is_profiling = is_profiling |
224 self.is_unit_testing = is_unit_testing | |
181 | 225 |
182 | 226 |
183 class WorkerPool: | 227 class WorkerPool: |
184 def __init__(self, worker_class, initargs=(), *, | 228 def __init__(self, worker_class, initargs=(), *, |
185 callback=None, error_callback=None, | 229 callback=None, error_callback=None, |
202 | 246 |
203 self._callback = callback | 247 self._callback = callback |
204 self._error_callback = error_callback | 248 self._error_callback = error_callback |
205 self._batch_size = batch_size | 249 self._batch_size = batch_size |
206 self._jobs_left = 0 | 250 self._jobs_left = 0 |
251 self._lock_jobs_left = threading.Lock() | |
252 self._lock_workers = threading.Lock() | |
207 self._event = threading.Event() | 253 self._event = threading.Event() |
254 self._error_on_join = None | |
255 self._closed = False | |
208 | 256 |
209 main_module = sys.modules['__main__'] | 257 main_module = sys.modules['__main__'] |
210 is_profiling = os.path.basename(main_module.__file__) in [ | 258 is_profiling = os.path.basename(main_module.__file__) in [ |
211 'profile.py', 'cProfile.py'] | 259 'profile.py', 'cProfile.py'] |
260 is_unit_testing = os.path.basename(main_module.__file__) in [ | |
261 'py.test'] | |
212 | 262 |
213 self._pool = [] | 263 self._pool = [] |
214 for i in range(worker_count): | 264 for i in range(worker_count): |
215 worker_params = _WorkerParams( | 265 worker_params = _WorkerParams( |
216 i, self._task_queue, self._result_queue, | 266 i, self._task_queue, self._result_queue, |
217 worker_class, initargs, | 267 worker_class, initargs, |
218 is_profiling=is_profiling) | 268 is_profiling=is_profiling, |
269 is_unit_testing=is_unit_testing) | |
219 w = multiprocessing.Process(target=worker_func, | 270 w = multiprocessing.Process(target=worker_func, |
220 args=(worker_params,)) | 271 args=(worker_params,)) |
221 w.name = w.name.replace('Process', 'PoolWorker') | 272 w.name = w.name.replace('Process', 'PoolWorker') |
222 w.daemon = True | 273 w.daemon = True |
223 w.start() | 274 w.start() |
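
Condensed, the reworked worker loop has this shape: the job branch now comes first because it is the hot path, and both the end task and the new abort task break out of the loop (the real code also reports stats back on `TASK_END`, omitted here). A runnable skeleton with a plain queue and a lambda standing in for the real worker:

```python
import queue

TASK_JOB, TASK_END, _TASK_ABORT_WORKER = 0, 1, 10


def worker_loop(inqueue, process, put, wid=0):
    completed = 0
    while True:
        task_type, task_data = inqueue.get()
        if task_type == TASK_JOB:
            # Hot path: run the job and ship back the result (or a
            # simplified stand-in for the exception data on failure).
            try:
                put((task_type, task_data, True, wid, process(task_data)))
            except Exception as ex:
                put((task_type, task_data, False, wid, str(ex)))
            completed += 1
        elif task_type == TASK_END:
            # Graceful exit requested by WorkerPool.close().
            break
        elif task_type == _TASK_ABORT_WORKER:
            # Emergency exit: don't bother reporting anything.
            break
        else:
            raise Exception("Unknown task type: %s" % task_type)
    return completed


q, results = queue.Queue(), []
q.put((TASK_JOB, 3))
q.put((TASK_END, None))
print(worker_loop(q, lambda x: x * x, results.append), results)
# 1 [(0, 3, True, 0, 9)]
```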
227 target=WorkerPool._handleResults, | 278 target=WorkerPool._handleResults, |
228 args=(self,)) | 279 args=(self,)) |
229 self._result_handler.daemon = True | 280 self._result_handler.daemon = True |
230 self._result_handler.start() | 281 self._result_handler.start() |
231 | 282 |
232 self._closed = False | |
233 | |
234 def queueJobs(self, jobs): | 283 def queueJobs(self, jobs): |
235 if self._closed: | 284 if self._closed: |
285 if self._error_on_join: | |
286 raise self._error_on_join | |
236 raise Exception("This worker pool has been closed.") | 287 raise Exception("This worker pool has been closed.") |
237 | 288 |
238 for job in jobs: | 289 jobs = list(jobs) |
239 self._jobs_left += 1 | 290 new_job_count = len(jobs) |
240 self._quick_put((TASK_JOB, job)) | 291 if new_job_count > 0: |
241 | 292 with self._lock_jobs_left: |
242 if self._jobs_left > 0: | 293 self._jobs_left += new_job_count |
294 | |
243 self._event.clear() | 295 self._event.clear() |
296 for job in jobs: | |
297 self._quick_put((TASK_JOB, job)) | |
244 | 298 |
245 def wait(self, timeout=None): | 299 def wait(self, timeout=None): |
246 return self._event.wait(timeout) | 300 if self._closed: |
301 raise Exception("This worker pool has been closed.") | |
302 | |
303 ret = self._event.wait(timeout) | |
304 if self._error_on_join: | |
305 raise self._error_on_join | |
306 return ret | |
247 | 307 |
248 def close(self): | 308 def close(self): |
309 if self._closed: | |
310 raise Exception("This worker pool has been closed.") | |
249 if self._jobs_left > 0 or not self._event.is_set(): | 311 if self._jobs_left > 0 or not self._event.is_set(): |
250 raise Exception("A previous job queue has not finished yet.") | 312 raise Exception("A previous job queue has not finished yet.") |
251 | 313 |
252 logger.debug("Closing worker pool...") | 314 logger.debug("Closing worker pool...") |
253 handler = _ReportHandler(len(self._pool)) | 315 live_workers = list(filter(lambda w: w is not None, self._pool)) |
316 handler = _ReportHandler(len(live_workers)) | |
254 self._callback = handler._handle | 317 self._callback = handler._handle |
255 self._error_callback = handler._handleError | 318 self._error_callback = handler._handleError |
256 for w in self._pool: | 319 for w in live_workers: |
257 self._quick_put((TASK_END, None)) | 320 self._quick_put((TASK_END, None)) |
258 for w in self._pool: | 321 for w in live_workers: |
259 w.join() | 322 w.join() |
260 | 323 |
261 logger.debug("Waiting for reports...") | 324 logger.debug("Waiting for reports...") |
262 if not handler.wait(2): | 325 if not handler.wait(2): |
263 missing = handler.reports.index(None) | 326 missing = handler.reports.index(None) |
270 self._result_handler.join() | 333 self._result_handler.join() |
271 self._closed = True | 334 self._closed = True |
272 | 335 |
273 return handler.reports | 336 return handler.reports |
274 | 337 |
338 def _onResultHandlerCriticalError(self, wid): | |
339 logger.error("Result handler received a critical error from " | |
340 "worker %d." % wid) | |
341 with self._lock_workers: | |
342 self._pool[wid] = None | |
343 if all(map(lambda w: w is None, self._pool)): | |
344 logger.error("All workers have died!") | |
345 self._closed = True | |
346 self._error_on_join = Exception("All workers have died!") | |
347 self._event.set() | |
348 return False | |
349 | |
350 return True | |
351 | |
275 def _onTaskDone(self): | 352 def _onTaskDone(self): |
276 self._jobs_left -= 1 | 353 with self._lock_jobs_left: |
277 if self._jobs_left == 0: | 354 left = self._jobs_left - 1 |
355 self._jobs_left = left | |
356 | |
357 if left == 0: | |
278 self._event.set() | 358 self._event.set() |
279 | 359 |
280 @staticmethod | 360 @staticmethod |
281 def _handleResults(pool): | 361 def _handleResults(pool): |
282 userdata = pool.userdata | 362 userdata = pool.userdata |
288 "problem, exiting.") | 368 "problem, exiting.") |
289 return | 369 return |
290 | 370 |
291 if res is None: | 371 if res is None: |
292 logger.debug("Result handler exiting.") | 372 logger.debug("Result handler exiting.") |
293 break | 373 return |
294 | 374 |
295 task_type, task_data, success, wid, data = res | 375 task_type, task_data, success, wid, data = res |
296 try: | 376 try: |
297 if success: | 377 if success: |
298 if pool._callback: | 378 if pool._callback: |
299 pool._callback(task_data, data, userdata) | 379 pool._callback(task_data, data, userdata) |
300 else: | 380 else: |
301 if pool._error_callback: | 381 if task_type == _CRITICAL_WORKER_ERROR: |
302 pool._error_callback(task_data, data, userdata) | 382 logger.error(data) |
383 do_continue = pool._onResultHandlerCriticalError(wid) | |
384 if not do_continue: | |
385 logger.debug("Aborting result handling thread.") | |
386 return | |
303 else: | 387 else: |
304 logger.error( | 388 if pool._error_callback: |
305 "Worker %d failed to process a job:" % wid) | 389 pool._error_callback(task_data, data, userdata) |
306 logger.error(data) | 390 else: |
391 logger.error( | |
392 "Worker %d failed to process a job:" % wid) | |
393 logger.error(data) | |
307 except Exception as ex: | 394 except Exception as ex: |
308 logger.exception(ex) | 395 logger.exception(ex) |
309 | 396 |
310 if task_type == TASK_JOB: | 397 if task_type == TASK_JOB: |
311 pool._onTaskDone() | 398 pool._onTaskDone() |
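
`_onResultHandlerCriticalError` simply blanks out the dead worker's slot and checks whether anyone is left; once the pool is empty it closes itself and arranges for `wait()` to raise. The bookkeeping in isolation (detached from the pool, the worker names are placeholders):

```python
# Bookkeeping sketch: mark crashed workers as None and detect total failure.
pool = ['worker-0', 'worker-1', 'worker-2']


def on_critical_error(wid):
    pool[wid] = None
    if all(w is None for w in pool):
        # Mirrors WorkerPool: mark the pool closed and surface the error
        # to whoever is blocked in wait().
        print("All workers have died!")
        return False   # tell the result-handler thread to stop
    return True        # keep handling results from surviving workers


print(on_critical_error(1))   # True: two workers still alive
print(on_critical_error(0))   # True: one left
print(on_critical_error(2))   # False: that was the last one
```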
381 with self._wlock: | 468 with self._wlock: |
382 with self._wbuf.getbuffer() as b: | 469 with self._wbuf.getbuffer() as b: |
383 self._writer.send_bytes(b, 0, size) | 470 self._writer.send_bytes(b, 0, size) |
384 | 471 |
385 | 472 |
386 def _pickle_fast(obj, buf): | |
387 fastpickle.pickle_intob(obj, buf) | |
388 | |
389 | |
390 def _unpickle_fast(buf, bufsize): | |
391 return fastpickle.unpickle_fromb(buf, bufsize) | |
392 | |
393 | |
394 def _pickle_default(obj, buf): | |
395 pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL) | |
396 | |
397 | |
398 def _unpickle_default(buf, bufsize): | |
399 return pickle.load(buf) | |
400 | |
401 | |
402 if use_fastpickle: | 473 if use_fastpickle: |
474 from piecrust import fastpickle | |
475 | |
476 def _pickle_fast(obj, buf): | |
477 fastpickle.pickle_intob(obj, buf) | |
478 | |
479 def _unpickle_fast(buf, bufsize): | |
480 return fastpickle.unpickle_fromb(buf, bufsize) | |
481 | |
403 _pickle = _pickle_fast | 482 _pickle = _pickle_fast |
404 _unpickle = _unpickle_fast | 483 _unpickle = _unpickle_fast |
484 | |
485 elif use_msgpack: | |
486 import msgpack | |
487 | |
488 def _pickle_msgpack(obj, buf): | |
489 msgpack.pack(obj, buf) | |
490 | |
491 def _unpickle_msgpack(buf, bufsize): | |
492 return msgpack.unpack(buf) | |
493 | |
494 _pickle = _pickle_msgpack | |
495 _unpickle = _unpickle_msgpack | |
496 | |
497 elif use_marshall: | |
498 import marshal | |
499 | |
500 def _pickle_marshal(obj, buf): | |
501 marshal.dump(obj, buf) | |
502 | |
503 def _unpickle_marshal(buf, bufsize): | |
504 return marshal.load(buf) | |
505 | |
506 _pickle = _pickle_marshal | |
507 _unpickle = _unpickle_marshal | |
508 | |
405 else: | 509 else: |
510 import pickle | |
511 | |
512 def _pickle_default(obj, buf): | |
513 pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL) | |
514 | |
515 def _unpickle_default(buf, bufsize): | |
516 return pickle.load(buf) | |
517 | |
406 _pickle = _pickle_default | 518 _pickle = _pickle_default |
407 _unpickle = _unpickle_default | 519 _unpickle = _unpickle_default |
408 | 520 |
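
All three backends are reached through the same `_pickle(obj, buf)` / `_unpickle(buf, bufsize)` pair and write to a binary file-like buffer; `msgpack` and `marshal` only handle a restricted set of types, which is presumably why `pickle` remains the default. A round-trip sketch using the default backend (the payload is made up):

```python
import io
import pickle


def _pickle_default(obj, buf):
    pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)


def _unpickle_default(buf, bufsize):
    return pickle.load(buf)


# Round-trip through an in-memory buffer, the way the fast queue drives it.
job = {'task_type': 0, 'path': 'pages/foo.md'}   # illustrative payload
buf = io.BytesIO()
_pickle_default(job, buf)
size = buf.tell()
buf.seek(0)
assert _unpickle_default(buf, size) == job
```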