comparison piecrust/workerpool.py @ 447:aefe70229fdd

bake: Commonize worker pool code between HTML and asset baking. The `workerpool` module now defines a generic-ish worker pool. It's similar to the standard `multiprocessing` pool, but with a simpler use case (only one way to queue jobs) and with support for workers to send a final "report" back to the master process, which we use here to collect timing information. The rest of the changes mostly remove duplicated code that's no longer needed.
author Ludovic Chabant <ludovic@chabant.com>
date Sun, 05 Jul 2015 00:09:41 -0700
parents
children 838f3964f400
comparing 446:4cdf6c2157a0 to 447:aefe70229fdd
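As a rough illustration of the new API (not part of this changeset), here is a minimal sketch of how a pool might be driven. `SleepWorker` is a made-up example class; `IWorker`, `WorkerPool.queueJobs()` and `WorkerPool.close()` are the interfaces defined in the file below.

    import time
    from piecrust.workerpool import IWorker, WorkerPool

    class SleepWorker(IWorker):
        # Hypothetical worker, for illustration only.
        def initialize(self):
            self.processed = 0

        def process(self, job):
            time.sleep(job)
            self.processed += 1
            return job

        def getReport(self):
            # Sent back to the master process when the pool is closed.
            return {'processed': self.processed}

    if __name__ == '__main__':
        pool = WorkerPool(SleepWorker, worker_count=2)
        res = pool.queueJobs([0.1, 0.2, 0.3],
                             handler=lambda data: print("done:", data))
        res.wait()
        reports = pool.close()   # one report per worker, indexed by wid
        print(reports)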
import os
import sys
import logging
import threading
import multiprocessing
# Imported explicitly for the (private) ExceptionWithTraceback helper used
# below; it is not re-exported by the top-level multiprocessing package.
import multiprocessing.pool


logger = logging.getLogger(__name__)

class IWorker(object):
    def initialize(self):
        raise NotImplementedError()

    def process(self, job):
        raise NotImplementedError()

    def getReport(self):
        return None

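# Wire protocol between the master process and the workers: tasks are
# (task_type, data) tuples pushed on the task queue, and results come back
# as (task_type, success, wid, data) tuples on the result queue.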
TASK_JOB = 0
TASK_END = 1

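# Entry point of each worker process. When the master script is being
# profiled, each worker profiles itself too and writes its data to a
# '<WorkerClass>-<wid>.prof' file in the current directory.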
def worker_func(params):
    if params.is_profiling:
        try:
            import cProfile as profile
        except ImportError:
            import profile

        params.is_profiling = False
        name = params.worker_class.__name__
        profile.runctx('_real_worker_func(params)',
                       globals(), locals(),
                       filename='%s-%d.prof' % (name, params.wid))
    else:
        _real_worker_func(params)

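# Main loop of a worker process: close the queue ends this process doesn't
# use (same trick as the standard library pool), build and initialize the
# worker, then process jobs until a TASK_END task arrives, at which point
# the worker sends back its final report.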
def _real_worker_func(params):
    if hasattr(params.inqueue, '_writer'):
        params.inqueue._writer.close()
        params.outqueue._reader.close()

    wid = params.wid
    logger.debug("Worker %d initializing..." % wid)

    w = params.worker_class(*params.initargs)
    w.wid = wid
    w.initialize()

    get = params.inqueue.get
    put = params.outqueue.put

    completed = 0
    while True:
        try:
            task = get()
        except (EOFError, OSError):
            logger.debug("Worker %d encountered connection problem." % wid)
            break

        task_type, task_data = task
        if task_type == TASK_END:
            logger.debug("Worker %d got end task, exiting." % wid)
            try:
                rep = (task_type, True, wid, (wid, w.getReport()))
            except Exception as e:
                if params.wrap_exception:
                    # ExceptionWithTraceback is a private helper that lives in
                    # multiprocessing.pool, not in the top-level package.
                    e = multiprocessing.pool.ExceptionWithTraceback(
                            e, e.__traceback__)
                rep = (task_type, False, wid, (wid, e))
            put(rep)
            break

        try:
            res = (task_type, True, wid, w.process(task_data))
        except Exception as e:
            if params.wrap_exception:
                e = multiprocessing.pool.ExceptionWithTraceback(
                        e, e.__traceback__)
            res = (task_type, False, wid, e)
        put(res)

        completed += 1

    logger.debug("Worker %d completed %d tasks." % (wid, completed))

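# Simple bag of values that gets pickled and passed to each worker process.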
class _WorkerParams(object):
    def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(),
                 wrap_exception=False, is_profiling=False):
        self.wid = wid
        self.inqueue = inqueue
        self.outqueue = outqueue
        self.worker_class = worker_class
        self.initargs = initargs
        self.wrap_exception = wrap_exception
        self.is_profiling = is_profiling

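# The master-side pool. It spawns one daemon process per worker, feeds them
# through a task queue, and collects results on a background thread.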
class WorkerPool(object):
    def __init__(self, worker_class, worker_count=None, initargs=()):
        worker_count = worker_count or os.cpu_count() or 1

        self._task_queue = multiprocessing.SimpleQueue()
        self._result_queue = multiprocessing.SimpleQueue()
        # Bypass the queues' locks by talking to the underlying pipes
        # directly; only the master writes tasks, and only the result
        # handler thread reads results.
        self._quick_put = self._task_queue._writer.send
        self._quick_get = self._result_queue._reader.recv

        self._callback = None
        self._error_callback = None
        self._listener = None

        # If the main script is running under the profiler, make the workers
        # profile themselves too (see worker_func).
        main_module = sys.modules['__main__']
        is_profiling = os.path.basename(main_module.__file__) in [
                'profile.py', 'cProfile.py']

        self._pool = []
        for i in range(worker_count):
            worker_params = _WorkerParams(
                    i, self._task_queue, self._result_queue,
                    worker_class, initargs,
                    is_profiling=is_profiling)
            w = multiprocessing.Process(target=worker_func,
                                        args=(worker_params,))
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            self._pool.append(w)

        self._result_handler = threading.Thread(
                target=WorkerPool._handleResults,
                args=(self,))
        self._result_handler.daemon = True
        self._result_handler.start()

        self._closed = False

    def setHandler(self, callback=None, error_callback=None):
        self._callback = callback
        self._error_callback = error_callback

    def queueJobs(self, jobs, handler=None):
        if self._closed:
            raise Exception("This worker pool has been closed.")
        if self._listener is not None:
            raise Exception("A previous job queue has not finished yet.")

        if handler is not None:
            self.setHandler(handler)

        if not hasattr(jobs, '__len__'):
            jobs = list(jobs)

        res = AsyncResult(self, len(jobs))
        if res._count == 0:
            res._event.set()
            return res

        self._listener = res
        for job in jobs:
            self._quick_put((TASK_JOB, job))

        return res

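    # Shuts the pool down: sends one end-task per worker, waits for the
    # processes to exit, and returns the list of worker reports, indexed
    # by worker id.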
    def close(self):
        if self._listener is not None:
            raise Exception("A previous job queue has not finished yet.")

        logger.debug("Closing worker pool...")
        handler = _ReportHandler(len(self._pool))
        self._callback = handler._handle
        for w in self._pool:
            self._quick_put((TASK_END, None))
        for w in self._pool:
            w.join()

        logger.debug("Waiting for reports...")
        if not handler.wait(2):
            missing = handler.reports.index(None)
            logger.warning(
                    "Didn't receive all worker reports before timeout. "
                    "Missing report from worker %d." % missing)

        logger.debug("Exiting result handler thread...")
        self._result_queue.put(None)
        self._result_handler.join()
        self._closed = True

        return handler.reports

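    # Runs on a background thread in the master process, dispatching each
    # result to the success or error callback and notifying the current
    # AsyncResult when a job is done.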
    @staticmethod
    def _handleResults(pool):
        while True:
            try:
                res = pool._quick_get()
            except (EOFError, OSError):
                logger.debug("Result handler thread encountered connection "
                             "problem, exiting.")
                return

            if res is None:
                logger.debug("Result handler exiting.")
                break

            task_type, success, wid, data = res
            try:
                if success and pool._callback:
                    pool._callback(data)
                elif not success and pool._error_callback:
                    pool._error_callback(data)
            except Exception as ex:
                logger.exception(ex)

            if task_type == TASK_JOB:
                pool._listener._onTaskDone()

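# Returned by WorkerPool.queueJobs: lets the caller wait for all the jobs
# in the current batch to be processed.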
class AsyncResult(object):
    def __init__(self, pool, count):
        self._pool = pool
        self._count = count
        self._event = threading.Event()

    def ready(self):
        return self._event.is_set()

    def wait(self, timeout=None):
        return self._event.wait(timeout)

    def _onTaskDone(self):
        self._count -= 1
        if self._count == 0:
            self._pool.setHandler(None)
            self._pool._listener = None
            self._event.set()

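# Collects the final reports sent by the workers while the pool is closing.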
class _ReportHandler(object):
    def __init__(self, worker_count):
        self.reports = [None] * worker_count
        self._count = worker_count
        self._received = 0
        self._event = threading.Event()

    def wait(self, timeout=None):
        return self._event.wait(timeout)

    def _handle(self, res):
        wid, data = res
        if wid < 0 or wid >= self._count:
            logger.error("Ignoring report from unknown worker %d." % wid)
            return

        self._received += 1
        self.reports[wid] = data

        if self._received == self._count:
            self._event.set()

    def _handleError(self, res):
        wid, data = res
        logger.error("Worker %d failed to send its report." % wid)
        logger.exception(data)