Mercurial > piecrust2
comparison piecrust/workerpool.py @ 447:aefe70229fdd
bake: Commonize worker pool code between html and asset baking.
The `workerpool` package now defines a generic worker pool. It's similar
to the standard Python multiprocessing pool, but with a simpler use case (only
one way to queue jobs) and support for workers sending a final "report" to the
master process, which we use here to get timing information.
The rest of the changes basically remove a whole bunch of duplicated code
that's not needed anymore.
author | Ludovic Chabant <ludovic@chabant.com> |
---|---|
date | Sun, 05 Jul 2015 00:09:41 -0700 |
parents | |
children | 838f3964f400 |
comparison
equal
deleted
inserted
replaced
446:4cdf6c2157a0 | 447:aefe70229fdd |
---|---|
1 import os | |
2 import sys | |
3 import logging | |
4 import threading | |
5 import multiprocessing | |
6 | |
7 | |
# Module-level logger shared by the pool, the result-handler thread and
# the worker processes.
logger = logging.getLogger(__name__)
9 | |
10 | |
class IWorker(object):
    """Interface that pool workers must implement.

    Subclasses must override `initialize` and `process`; `getReport` is
    optional and defaults to returning nothing.
    """

    def initialize(self):
        # One-time setup, run inside the child process before any job.
        raise NotImplementedError()

    def process(self, job):
        # Handle a single queued job; the return value is sent back to
        # the master process.
        raise NotImplementedError()

    def getReport(self):
        # Optional final data sent back to the master when the pool closes.
        return None
20 | |
21 | |
# Task-type tags sent through the task queue.
TASK_JOB = 0  # a normal job payload follows
TASK_END = 1  # sentinel: worker should send its report and exit
24 | |
25 | |
def worker_func(params):
    """Child-process entry point: run the worker loop, optionally profiled."""
    if not params.is_profiling:
        _real_worker_func(params)
        return

    try:
        import cProfile as profile
    except ImportError:
        import profile

    # Clear the flag before running under the profiler.
    params.is_profiling = False
    prof_file = '%s-%d.prof' % (params.worker_class.__name__, params.wid)
    profile.runctx('_real_worker_func(params)',
                   globals(), locals(),
                   filename=prof_file)
40 | |
41 | |
def _wrap_exception(e):
    # BUG FIX: `ExceptionWithTraceback` lives in `multiprocessing.pool`,
    # not on the top-level `multiprocessing` module, so the original
    # `multiprocessing.ExceptionWithTraceback(...)` would itself raise
    # `AttributeError` whenever `wrap_exception` was enabled. It is a
    # private helper (Python 3.4+), so fall back to the bare exception
    # if it's unavailable.
    try:
        from multiprocessing.pool import ExceptionWithTraceback
        return ExceptionWithTraceback(e, e.__traceback__)
    except ImportError:
        return e


def _real_worker_func(params):
    """Worker main loop: pull tasks, process them, push results back.

    Runs until a `TASK_END` task is received or the queue connection
    breaks. Each result put on the out-queue is a tuple
    `(task_type, success, wid, data)`.
    """
    # Close the queue ends we don't use in this process so EOF can be
    # detected properly (mirrors what `multiprocessing.pool` does).
    if hasattr(params.inqueue, '_writer'):
        params.inqueue._writer.close()
        params.outqueue._reader.close()

    wid = params.wid
    logger.debug("Worker %d initializing..." % wid)

    w = params.worker_class(*params.initargs)
    w.wid = wid
    w.initialize()

    # Bind queue methods to locals for speed in the hot loop.
    get = params.inqueue.get
    put = params.outqueue.put

    completed = 0
    while True:
        try:
            task = get()
        except (EOFError, OSError):
            logger.debug("Worker %d encountered connection problem." % wid)
            break

        task_type, task_data = task
        if task_type == TASK_END:
            # Final task: send this worker's report back, then exit.
            logger.debug("Worker %d got end task, exiting." % wid)
            try:
                rep = (task_type, True, wid, (wid, w.getReport()))
            except Exception as e:
                if params.wrap_exception:
                    e = _wrap_exception(e)
                rep = (task_type, False, wid, (wid, e))
            put(rep)
            break

        try:
            res = (task_type, True, wid, w.process(task_data))
        except Exception as e:
            if params.wrap_exception:
                e = _wrap_exception(e)
            res = (task_type, False, wid, e)
        put(res)

        completed += 1

    logger.debug("Worker %d completed %d tasks." % (wid, completed))
89 | |
90 | |
91 class _WorkerParams(object): | |
92 def __init__(self, wid, inqueue, outqueue, worker_class, initargs=(), | |
93 wrap_exception=False, is_profiling=False): | |
94 self.wid = wid | |
95 self.inqueue = inqueue | |
96 self.outqueue = outqueue | |
97 self.worker_class = worker_class | |
98 self.initargs = initargs | |
99 self.wrap_exception = wrap_exception | |
100 self.is_profiling = is_profiling | |
101 | |
102 | |
class WorkerPool(object):
    """A simple multiprocessing worker pool.

    Unlike `multiprocessing.Pool`, jobs are queued one way only
    (`queueJobs`) and workers can send a final "report" back to the
    master process when the pool is closed.
    """

    def __init__(self, worker_class, worker_count=None, initargs=()):
        worker_count = worker_count or os.cpu_count() or 1

        # Use `SimpleQueue` and bind the raw pipe ends for fast access.
        self._task_queue = multiprocessing.SimpleQueue()
        self._result_queue = multiprocessing.SimpleQueue()
        self._quick_put = self._task_queue._writer.send
        self._quick_get = self._result_queue._reader.recv

        self._callback = None
        self._error_callback = None
        self._listener = None

        # If the main script runs under the profiler, profile workers too.
        main_module = sys.modules['__main__']
        is_profiling = os.path.basename(main_module.__file__) in [
                'profile.py', 'cProfile.py']

        self._pool = []
        for i in range(worker_count):
            worker_params = _WorkerParams(
                    i, self._task_queue, self._result_queue,
                    worker_class, initargs,
                    is_profiling=is_profiling)
            w = multiprocessing.Process(target=worker_func,
                                        args=(worker_params,))
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            self._pool.append(w)

        # Background thread that dispatches worker results to callbacks.
        self._result_handler = threading.Thread(
                target=WorkerPool._handleResults,
                args=(self,))
        self._result_handler.daemon = True
        self._result_handler.start()

        self._closed = False

    def setHandler(self, callback=None, error_callback=None):
        """Set the callbacks invoked for successful/failed job results."""
        self._callback = callback
        self._error_callback = error_callback

    def queueJobs(self, jobs, handler=None):
        """Queue a batch of jobs; returns an `AsyncResult` to wait on."""
        if self._closed:
            raise Exception("This worker pool has been closed.")
        if self._listener is not None:
            raise Exception("A previous job queue has not finished yet.")

        if handler is not None:
            self.setHandler(handler)

        # We need the job count up-front to know when the batch is done.
        if not hasattr(jobs, '__len__'):
            jobs = list(jobs)

        res = AsyncResult(self, len(jobs))
        if res._count == 0:
            # Empty batch: nothing to queue, signal completion right away.
            res._event.set()
            return res

        self._listener = res
        for job in jobs:
            self._quick_put((TASK_JOB, job))

        return res

    def close(self):
        """Shut down the pool and return the workers' final reports."""
        if self._listener is not None:
            raise Exception("A previous job queue has not finished yet.")

        logger.debug("Closing worker pool...")
        handler = _ReportHandler(len(self._pool))
        self._callback = handler._handle
        # BUG FIX: failed reports must also be routed to the report
        # handler -- the original never wired `_handleError`, so a worker
        # whose `getReport()` raised would hit a stale (or absent) job
        # error callback instead.
        self._error_callback = handler._handleError
        # One end-task per worker; each worker consumes exactly one.
        for _ in self._pool:
            self._quick_put((TASK_END, None))
        for w in self._pool:
            w.join()

        logger.debug("Waiting for reports...")
        if not handler.wait(2):
            missing = handler.reports.index(None)
            logger.warning(
                    "Didn't receive all worker reports before timeout. "
                    "Missing report from worker %d." % missing)

        logger.debug("Exiting result handler thread...")
        # `None` is the sentinel that stops `_handleResults`.
        self._result_queue.put(None)
        self._result_handler.join()
        self._closed = True

        return handler.reports

    @staticmethod
    def _handleResults(pool):
        # Runs on a background thread in the master process, dispatching
        # worker results to the current callbacks.
        while True:
            try:
                res = pool._quick_get()
            except (EOFError, OSError):
                logger.debug("Result handler thread encountered connection "
                             "problem, exiting.")
                return

            if res is None:
                # Sentinel pushed by `close()` once all workers exited.
                logger.debug("Result handler exiting.")
                break

            task_type, success, wid, data = res
            try:
                if success and pool._callback:
                    pool._callback(data)
                elif not success and pool._error_callback:
                    pool._error_callback(data)
            except Exception as ex:
                logger.exception(ex)

            if task_type == TASK_JOB:
                pool._listener._onTaskDone()
219 | |
220 | |
class AsyncResult(object):
    """Tracks completion of a batch of jobs queued on a `WorkerPool`."""

    def __init__(self, pool, count):
        self._pool = pool
        self._count = count  # jobs still outstanding
        self._event = threading.Event()

    def ready(self):
        """Return `True` if every job in the batch has been processed."""
        return self._event.is_set()

    def wait(self, timeout=None):
        """Block until the batch is done, or `timeout` seconds elapse."""
        return self._event.wait(timeout)

    def _onTaskDone(self):
        # Called by the pool's result-handler thread after each job.
        self._count -= 1
        if self._count != 0:
            return
        # Last job finished: detach from the pool and wake any waiters.
        self._pool.setHandler(None)
        self._pool._listener = None
        self._event.set()
239 | |
240 | |
241 class _ReportHandler(object): | |
242 def __init__(self, worker_count): | |
243 self.reports = [None] * worker_count | |
244 self._count = worker_count | |
245 self._received = 0 | |
246 self._event = threading.Event() | |
247 | |
248 def wait(self, timeout=None): | |
249 return self._event.wait(timeout) | |
250 | |
251 def _handle(self, res): | |
252 wid, data = res | |
253 if wid < 0 or wid > self._count: | |
254 logger.error("Ignoring report from unknown worker %d." % wid) | |
255 return | |
256 | |
257 self._received += 1 | |
258 self.reports[wid] = data | |
259 | |
260 if self._received == self._count: | |
261 self._event.set() | |
262 | |
263 def _handleError(self, res): | |
264 wid, data = res | |
265 logger.error("Worker %d failed to send its report." % wid) | |
266 logger.exception(data) | |
267 |