comparison piecrust/workerpool.py @ 971:5485a11591ec

internal: Worker pool improvements

- Slightly more correct multi-processing code.
- Handle critical worker failures.
- Correctly set up workers when run from `py.test`.
- Add support for using `msgpack` or `marshal` instead of `pickle`.
author Ludovic Chabant <ludovic@chabant.com>
date Tue, 17 Oct 2017 01:00:55 -0700
parents f2b75e4be981
children 45ad976712ec
comparison of 970:660250c95246 and 971:5485a11591ec
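The changeset below touches the whole lifecycle of the pool, so for orientation, here is how calling code drives it: queue jobs, wait for the queue to drain, then close and collect per-worker reports. A minimal usage sketch; the EchoWorker class, the on_result callback, and the job dicts are hypothetical, not part of piecrust:

    from piecrust.environment import ExecutionStats
    from piecrust.workerpool import IWorker, WorkerPool

    class EchoWorker(IWorker):
        # Hypothetical worker: returns each job unchanged.
        def initialize(self):
            pass

        def process(self, job):
            return job

        def getStats(self):
            # Merged into the worker's report when it receives TASK_END.
            return ExecutionStats()

        def shutdown(self):
            pass

    def on_result(job, result, userdata):
        print('done:', job, '->', result)

    pool = WorkerPool(EchoWorker, callback=on_result)
    pool.queueJobs([{'path': 'foo.md'}, {'path': 'bar.md'}])
    pool.wait()             # blocks until _jobs_left drops to zero
    reports = pool.close()  # sends TASK_END to each worker, joins them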
@@ -1,22 +1,23 @@
 import io
 import os
 import sys
 import time
-import pickle
 import logging
 import threading
 import traceback
 import multiprocessing
-from piecrust import fastpickle
 from piecrust.environment import ExecutionStats


 logger = logging.getLogger(__name__)

 use_fastqueue = False
+
 use_fastpickle = False
+use_msgpack = False
+use_marshall = False


 class IWorker(object):
     """ Interface for a pool worker.
     """
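A note on the two new switches: like use_fastpickle, the use_msgpack and use_marshall flags are read exactly once, by the if/elif chain at the bottom of this module, when piecrust.workerpool is first imported. use_fastqueue, by contrast, is presumably consulted later, when a pool's queues are created, so the asymmetry sketched below is worth knowing about (an assumption about usage, not code from this changeset):

    import piecrust.workerpool as wp

    wp.use_fastqueue = True   # takes effect: checked when a WorkerPool is built
    wp.use_msgpack = True     # too late: _pickle/_unpickle were already bound
                              # at import time by the module-level if/elif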
@@ -55,10 +56,12 @@
     return errors


 TASK_JOB = 0
 TASK_END = 1
+_TASK_ABORT_WORKER = 10
+_CRITICAL_WORKER_ERROR = 11


 def worker_func(params):
     if params.is_profiling:
         try:
@@ -74,20 +77,50 @@
     else:
         _real_worker_func(params)


 def _real_worker_func(params):
+    try:
+        _real_worker_func_unsafe(params)
+    except Exception as ex:
+        logger.exception(ex)
+        msg = ("CRITICAL ERROR IN WORKER %d\n%s" % (params.wid, str(ex)))
+        params.outqueue.put((
+            _CRITICAL_WORKER_ERROR, None, False, params.wid, msg))
+
+
+def _real_worker_func_unsafe(params):
     wid = params.wid

     stats = ExecutionStats()
     stats.registerTimer('WorkerInit')
     init_start_time = time.perf_counter()

+    # If we are unit-testing, we haven't set up the logging environment
+    # yet, since the executable is `py.test`. We need to translate our
+    # test logging arguments into something Chef can understand.
+    if params.is_unit_testing:
+        import argparse
+        parser = argparse.ArgumentParser()
+        # This is adapted from our `conftest.py`.
+        parser.add_argument('--log-debug', action='store_true')
+        parser.add_argument('--log-file')
+        res, _ = parser.parse_known_args(sys.argv[1:])
+
+        chef_args = []
+        if res.log_debug:
+            chef_args.append('--debug')
+        if res.log_file:
+            chef_args += ['--log', res.log_file]
+
+        from piecrust.main import _pre_parse_chef_args
+        _pre_parse_chef_args(chef_args)
+
     # In a context where `multiprocessing` is using the `spawn` forking model,
     # the new process doesn't inherit anything, so we lost all our logging
     # configuration here. Let's set it up again.
-    if (hasattr(multiprocessing, 'get_start_method') and
+    elif (hasattr(multiprocessing, 'get_start_method') and
             multiprocessing.get_start_method() == 'spawn'):
         from piecrust.main import _pre_parse_chef_args
         _pre_parse_chef_args(sys.argv[1:])

     from piecrust.main import ColoredFormatter
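The `spawn` branch matters on Windows, and on macOS where `spawn` became the default start method in Python 3.8: the child process re-imports modules instead of inheriting the parent's state, so any logging configured in the parent is gone. A standalone illustration of the underlying problem, independent of piecrust:

    import logging
    import multiprocessing

    def child():
        # Under 'spawn' this runs in a fresh interpreter: the parent's
        # basicConfig() never executed here, so nothing is printed.
        logging.getLogger(__name__).debug('invisible under spawn')

    if __name__ == '__main__':
        logging.basicConfig(level=logging.DEBUG)
        multiprocessing.set_start_method('spawn', force=True)
        p = multiprocessing.Process(target=child)
        p.start()
        p.join()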
@@ -105,14 +138,13 @@
     w = params.worker_class(*params.initargs)
     w.wid = wid
     try:
         w.initialize()
     except Exception as ex:
-        logger.error("Working failed to initialize:")
+        logger.error("Worker %d failed to initialize." % wid)
         logger.exception(ex)
-        params.outqueue.put(None)
-        return
+        raise

     stats.stepTimerSince('WorkerInit', init_start_time)

     # Start pumping!
     completed = 0
@@ -126,12 +158,29 @@
         task = get()
         time_in_get += (time.perf_counter() - get_start_time)

         task_type, task_data = task

+        # Job task... just do it.
+        if task_type == TASK_JOB:
+            try:
+                res = (task_type, task_data, True, wid, w.process(task_data))
+            except Exception as e:
+                logger.debug(
+                    "Error processing job, sending exception to main process:")
+                logger.debug(traceback.format_exc())
+                we = WorkerExceptionData(wid)
+                res = (task_type, task_data, False, wid, we)
+
+            put_start_time = time.perf_counter()
+            put(res)
+            time_in_put += (time.perf_counter() - put_start_time)
+
+            completed += 1
+
         # End task... gather stats to send back to the main process.
-        if task_type == TASK_END:
+        elif task_type == TASK_END:
             logger.debug("Worker %d got end task, exiting." % wid)
             stats.registerTimer('WorkerTaskGet', time=time_in_get)
             stats.registerTimer('WorkerResultPut', time=time_in_put)
             try:
                 stats.mergeStats(w.getStats())
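Every message a worker puts on the result queue, in both branches above, has the same shape; spelled out for reference, with field meanings as used in this diff:

    # res = (task_type, task_data, success, wid, data)
    #
    #   task_type  TASK_JOB, TASK_END, or _CRITICAL_WORKER_ERROR
    #   task_data  the original job (None for critical errors)
    #   success    True  -> data is w.process()'s return value, or presumably
    #                       a (wid, stats) pair for TASK_END, mirroring the
    #                       failure case shown below
    #              False -> data is a WorkerExceptionData for TASK_JOB,
    #                       a (wid, WorkerExceptionData) pair for TASK_END,
    #                       or a plain message string for critical errors
    #   wid        the worker's index in the pool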
@@ -143,43 +192,38 @@
                 we = WorkerExceptionData(wid)
                 rep = (task_type, task_data, False, wid, (wid, we))
             put(rep)
             break

-        # Job task... just do it.
-        elif task_type == TASK_JOB:
-            try:
-                res = (task_type, task_data, True, wid, w.process(task_data))
-            except Exception as e:
-                logger.debug(
-                    "Error processing job, sending exception to main process:")
-                logger.debug(traceback.format_exc())
-                we = WorkerExceptionData(wid)
-                res = (task_type, task_data, False, wid, we)
-
-            put_start_time = time.perf_counter()
-            put(res)
-            time_in_put += (time.perf_counter() - put_start_time)
-
-            completed += 1
+        # Emergency abort.
+        elif task_type == _TASK_ABORT_WORKER:
+            logger.debug("Worker %d got abort signal." % wid)
+            break

         else:
             raise Exception("Unknown task type: %s" % task_type)

-    w.shutdown()
+    try:
+        w.shutdown()
+    except Exception as e:
+        logger.error("Worker %d failed to shut down." % wid)
+        logger.exception(e)
+        raise
+
     logger.debug("Worker %d completed %d tasks." % (wid, completed))


 class _WorkerParams:
     def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
-                 is_profiling=False):
+                 is_profiling=False, is_unit_testing=False):
         self.wid = wid
         self.inqueue = inqueue
         self.outqueue = outqueue
         self.worker_class = worker_class
         self.initargs = initargs
         self.is_profiling = is_profiling
+        self.is_unit_testing = is_unit_testing


 class WorkerPool:
     def __init__(self, worker_class, initargs=(), *,
                  callback=None, error_callback=None,
@@ -202,22 +246,29 @@

         self._callback = callback
         self._error_callback = error_callback
         self._batch_size = batch_size
         self._jobs_left = 0
+        self._lock_jobs_left = threading.Lock()
+        self._lock_workers = threading.Lock()
         self._event = threading.Event()
+        self._error_on_join = None
+        self._closed = False

         main_module = sys.modules['__main__']
         is_profiling = os.path.basename(main_module.__file__) in [
             'profile.py', 'cProfile.py']
+        is_unit_testing = os.path.basename(main_module.__file__) in [
+            'py.test']

         self._pool = []
         for i in range(worker_count):
             worker_params = _WorkerParams(
                 i, self._task_queue, self._result_queue,
                 worker_class, initargs,
-                is_profiling=is_profiling)
+                is_profiling=is_profiling,
+                is_unit_testing=is_unit_testing)
             w = multiprocessing.Process(target=worker_func,
                                         args=(worker_params,))
             w.name = w.name.replace('Process', 'PoolWorker')
             w.daemon = True
             w.start()
@@ -227,37 +278,49 @@
             target=WorkerPool._handleResults,
             args=(self,))
         self._result_handler.daemon = True
         self._result_handler.start()

-        self._closed = False
-
     def queueJobs(self, jobs):
         if self._closed:
+            if self._error_on_join:
+                raise self._error_on_join
             raise Exception("This worker pool has been closed.")

-        for job in jobs:
-            self._jobs_left += 1
-            self._quick_put((TASK_JOB, job))
-
-        if self._jobs_left > 0:
+        jobs = list(jobs)
+        new_job_count = len(jobs)
+        if new_job_count > 0:
+            with self._lock_jobs_left:
+                self._jobs_left += new_job_count
+
             self._event.clear()
+            for job in jobs:
+                self._quick_put((TASK_JOB, job))

     def wait(self, timeout=None):
-        return self._event.wait(timeout)
+        if self._closed:
+            raise Exception("This worker pool has been closed.")
+
+        ret = self._event.wait(timeout)
+        if self._error_on_join:
+            raise self._error_on_join
+        return ret

     def close(self):
+        if self._closed:
+            raise Exception("This worker pool has been closed.")
         if self._jobs_left > 0 or not self._event.is_set():
             raise Exception("A previous job queue has not finished yet.")

         logger.debug("Closing worker pool...")
-        handler = _ReportHandler(len(self._pool))
+        live_workers = list(filter(lambda w: w is not None, self._pool))
+        handler = _ReportHandler(len(live_workers))
         self._callback = handler._handle
         self._error_callback = handler._handleError
-        for w in self._pool:
+        for w in live_workers:
             self._quick_put((TASK_END, None))
-        for w in self._pool:
+        for w in live_workers:
             w.join()

         logger.debug("Waiting for reports...")
         if not handler.wait(2):
             missing = handler.reports.index(None)
@@ -270,13 +333,30 @@
         self._result_handler.join()
         self._closed = True

         return handler.reports

+    def _onResultHandlerCriticalError(self, wid):
+        logger.error("Result handler received a critical error from "
+                     "worker %d." % wid)
+        with self._lock_workers:
+            self._pool[wid] = None
+            if all(map(lambda w: w is None, self._pool)):
+                logger.error("All workers have died!")
+                self._closed = True
+                self._error_on_join = Exception("All workers have died!")
+                self._event.set()
+                return False
+
+        return True
+
     def _onTaskDone(self):
-        self._jobs_left -= 1
-        if self._jobs_left == 0:
+        with self._lock_jobs_left:
+            left = self._jobs_left - 1
+            self._jobs_left = left
+
+        if left == 0:
             self._event.set()

     @staticmethod
     def _handleResults(pool):
         userdata = pool.userdata
@@ -288,24 +368,31 @@
                              "problem, exiting.")
                 return

             if res is None:
                 logger.debug("Result handler exiting.")
-                break
+                return

             task_type, task_data, success, wid, data = res
             try:
                 if success:
                     if pool._callback:
                         pool._callback(task_data, data, userdata)
                 else:
-                    if pool._error_callback:
-                        pool._error_callback(task_data, data, userdata)
+                    if task_type == _CRITICAL_WORKER_ERROR:
+                        logger.error(data)
+                        do_continue = pool._onResultHandlerCriticalError(wid)
+                        if not do_continue:
+                            logger.debug("Aborting result handling thread.")
+                            return
                     else:
-                        logger.error(
-                            "Worker %d failed to process a job:" % wid)
-                        logger.error(data)
+                        if pool._error_callback:
+                            pool._error_callback(task_data, data, userdata)
+                        else:
+                            logger.error(
+                                "Worker %d failed to process a job:" % wid)
+                            logger.error(data)
             except Exception as ex:
                 logger.exception(ex)

             if task_type == TASK_JOB:
                 pool._onTaskDone()
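With the critical-error path above, a pool whose workers have all died no longer hangs the build: _handleResults stores _error_on_join, sets the event so waiters wake up, and wait() (and the next queueJobs()) re-raise it on the calling thread. Defensive calling code can look like this; MyWorker and jobs are placeholders:

    pool = WorkerPool(MyWorker)
    pool.queueJobs(jobs)
    try:
        pool.wait()
    except Exception:
        # Raised if every worker hit a critical error; the event was set
        # by _onResultHandlerCriticalError, so this wait() didn't block.
        logger.error('All pool workers crashed, aborting.')
        raise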
@@ -381,28 +468,53 @@
         with self._wlock:
             with self._wbuf.getbuffer() as b:
                 self._writer.send_bytes(b, 0, size)


-def _pickle_fast(obj, buf):
-    fastpickle.pickle_intob(obj, buf)
-
-
-def _unpickle_fast(buf, bufsize):
-    return fastpickle.unpickle_fromb(buf, bufsize)
-
-
-def _pickle_default(obj, buf):
-    pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)
-
-
-def _unpickle_default(buf, bufsize):
-    return pickle.load(buf)
-
-
 if use_fastpickle:
+    from piecrust import fastpickle
+
+    def _pickle_fast(obj, buf):
+        fastpickle.pickle_intob(obj, buf)
+
+    def _unpickle_fast(buf, bufsize):
+        return fastpickle.unpickle_fromb(buf, bufsize)
+
     _pickle = _pickle_fast
     _unpickle = _unpickle_fast
+
+elif use_msgpack:
+    import msgpack
+
+    def _pickle_msgpack(obj, buf):
+        msgpack.pack(obj, buf)
+
+    def _unpickle_msgpack(buf, bufsize):
+        return msgpack.unpack(buf)
+
+    _pickle = _pickle_msgpack
+    _unpickle = _unpickle_msgpack
+
+elif use_marshall:
+    import marshal
+
+    def _pickle_marshal(obj, buf):
+        marshal.dump(obj, buf)
+
+    def _unpickle_marshal(buf, bufsize):
+        return marshal.load(buf)
+
+    _pickle = _pickle_marshal
+    _unpickle = _unpickle_marshal
+
 else:
+    import pickle
+
+    def _pickle_default(obj, buf):
+        pickle.dump(obj, buf, pickle.HIGHEST_PROTOCOL)
+
+    def _unpickle_default(buf, bufsize):
+        return pickle.load(buf)
+
     _pickle = _pickle_default
     _unpickle = _unpickle_default

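One caveat on the new backends: pickle round-trips arbitrary picklable Python objects, while marshal only accepts Python's primitive types and containers, and msgpack (a third-party package) roughly JSON-shaped data, so those two are only safe when every job and result is plain data. A stdlib-only round-trip through the same buffer-oriented interface used by _pickle/_unpickle:

    import io
    import marshal
    import pickle

    job = {'path': 'foo.md', 'flags': [1, 2, 3]}

    buf = io.BytesIO()
    pickle.dump(job, buf, pickle.HIGHEST_PROTOCOL)
    buf.seek(0)
    assert pickle.load(buf) == job

    buf = io.BytesIO()
    marshal.dump(job, buf)
    buf.seek(0)
    assert marshal.load(buf) == job   # OK: dict/list/str/int only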