# blob: f1ce7324785ba983b8ba12f3283a70757ad0cc86
"""Synchronization primitives.

Provides Lock, Event, Condition, Semaphore and BoundedSemaphore for use
by coroutines.
"""

__all__ = (
    'Lock',
    'Event',
    'Condition',
    'Semaphore',
    'BoundedSemaphore',
)

import collections
import warnings

from . import events
from . import exceptions
10
11
12class _ContextManagerMixin:
13 async def __aenter__(self):
14 await self.acquire()
15 # We have no use for the "as ..." clause in the with
16 # statement for locks.
17 return None
18
19 async def __aexit__(self, exc_type, exc, tb):
20 self.release()
21
22
class Lock(_ContextManagerMixin):
    """Primitive lock objects.

    A primitive lock is a synchronization primitive that is not owned
    by a particular coroutine when locked. A primitive lock is in one
    of two states, 'locked' or 'unlocked'.

    It is created in the unlocked state. It has two basic methods,
    acquire() and release(). When the state is unlocked, acquire()
    changes the state to locked and returns immediately. When the
    state is locked, acquire() blocks until a call to release() in
    another coroutine changes it to unlocked, then the acquire() call
    resets it to locked and returns. The release() method should only
    be called in the locked state; it changes the state to unlocked
    and returns immediately. If an attempt is made to release an
    unlocked lock, a RuntimeError will be raised.

    When more than one coroutine is blocked in acquire() waiting for
    the state to turn to unlocked, only one coroutine proceeds when a
    release() call resets the state to unlocked: waiters are served in
    the order in which they started waiting.

    acquire() is a coroutine and should be called with 'await'.

    Locks also support the asynchronous context management protocol;
    the 'async with lock' statement is the recommended form.

    Usage:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    Context manager usage:

        lock = Lock()
        ...
        async with lock:
            ...

    Lock objects can be tested for locking state:

        if not lock.locked():
            await lock.acquire()
        else:
            # lock is acquired
            ...

    """

    def __init__(self, *, loop=None):
        # The deque of waiter futures is only created once somebody blocks.
        self._waiters = None
        self._locked = False
        self._loop = events.get_event_loop() if loop is None else loop
        if loop is not None:
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)

    def __repr__(self):
        base = super().__repr__()
        parts = ['locked' if self._locked else 'unlocked']
        if self._waiters:
            parts.append(f'waiters:{len(self._waiters)}')
        return f'<{base[1:-1]} [{", ".join(parts)}]>'

    def locked(self):
        """Return True if the lock is currently held."""
        return self._locked

    async def acquire(self):
        """Acquire the lock.

        Wait until the lock is unlocked, then lock it and return True.
        """
        # Fast path: take the lock right away when it is free and nobody
        # (other than already-cancelled waiters) is queued ahead of us.
        nobody_queued = all(w.cancelled() for w in (self._waiters or ()))
        if not self._locked and nobody_queued:
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._loop.create_future()
        self._waiters.append(waiter)

        # The inner finally must run before the CancelledError handler, so
        # that _wake_up_first() can never pick this very waiter again.
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            # We may have been chosen by release() just before being
            # cancelled; hand the lock on to the next waiter in that case.
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """Release the lock.

        Mark the lock as unlocked and let exactly one blocked coroutine,
        if any, proceed. Raises RuntimeError if the lock is not held.
        Returns None.
        """
        if not self._locked:
            raise RuntimeError('Lock is not acquired.')
        self._locked = False
        self._wake_up_first()

    def _wake_up_first(self):
        """Resolve the oldest waiter's future unless it is already done."""
        if not self._waiters:
            return
        head = self._waiters[0]
        # A finished future belongs to a waiter that will run soon anyway:
        # it either takes the lock or, if cancelled while the lock is still
        # free, calls this method again to wake the next one.
        if not head.done():
            head.set_result(True)
162
163
class Event:
    """Asynchronous equivalent to threading.Event.

    An event wraps a boolean flag, initially false. set() turns it on,
    clear() turns it off again, and wait() blocks until it is on.
    """

    def __init__(self, *, loop=None):
        # Futures of coroutines currently blocked in wait().
        self._waiters = collections.deque()
        self._value = False
        self._loop = events.get_event_loop() if loop is None else loop
        if loop is not None:
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)

    def __repr__(self):
        base = super().__repr__()
        parts = ['set' if self._value else 'unset']
        if self._waiters:
            parts.append(f'waiters:{len(self._waiters)}')
        return f'<{base[1:-1]} [{", ".join(parts)}]>'

    def is_set(self):
        """Return True if and only if the internal flag is true."""
        return self._value

    def set(self):
        """Turn the flag on and wake every coroutine blocked in wait().

        Coroutines that call wait() while the flag is on return at once.
        Calling set() on an already-set event does nothing.
        """
        if self._value:
            return
        self._value = True
        for waiter in self._waiters:
            # Skip futures that are already finished (e.g. cancelled).
            if not waiter.done():
                waiter.set_result(True)

    def clear(self):
        """Turn the flag off; later wait() calls block until set() again."""
        self._value = False

    async def wait(self):
        """Block until the flag is on, then return True.

        Returns immediately if the flag is already on.
        """
        if not self._value:
            waiter = self._loop.create_future()
            self._waiters.append(waiter)
            try:
                await waiter
            finally:
                # Runs on wake-up and on cancellation alike.
                self._waiters.remove(waiter)
        return True
230
231
class Condition(_ContextManagerMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more coroutines to wait until they are notified by another
    coroutine.

    If no lock is given, a new Lock object is created and used as the
    underlying lock. A lock that is passed in must be bound to the same
    loop as the condition, otherwise ValueError is raised.
    """

    def __init__(self, lock=None, *, loop=None):
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)

        if lock is None:
            lock = Lock(loop=loop)
        elif lock._loop is not self._loop:
            raise ValueError("loop argument must agree with lock")

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        # Futures of coroutines blocked in wait(), oldest first.
        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling coroutine has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another coroutine. Once
        awakened, it re-acquires the lock and returns True.

        The lock is re-acquired even if the wait is cancelled. The
        original CancelledError is re-raised. If this coroutine had
        already been notified but leaves with an exception anyway (for
        example because it was cancelled before it could resume), the
        notification is passed on to another waiter so it is not lost.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        fut = self._loop.create_future()
        self.release()
        try:
            try:
                self._waiters.append(fut)
                try:
                    await fut
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must reacquire lock even if wait is cancelled. Only
                # CancelledError is retried; the most recent instance is
                # re-raised so its message and traceback survive.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as exc:
                        err = exc

                if err is not None:
                    try:
                        raise err
                    finally:
                        # Break the exception -> frame -> err reference cycle.
                        err = None
        except BaseException:
            # _notify() resolves a waiter's future with a result. If that
            # happened to us but we are not returning normally, the
            # wake-up would be lost: hand it to the next waiter instead.
            if fut.done() and not fut.cancelled():
                self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value. The final predicate value is
        the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one coroutine waiting on this condition, if any.
        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up at most n of the coroutines waiting for the
        condition variable; it is a no-op if no coroutines are waiting.

        Note: an awakened coroutine does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Resolve the futures of up to *n* pending waiters (no lock check).

        Futures that are already done (notified earlier or cancelled) are
        skipped and do not count towards *n*.
        """
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all coroutines waiting on this condition. This method acts
        like notify(), but wakes up all waiting coroutines instead of one. If
        the calling coroutine has not acquired the lock when this method is
        called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))
352
353
class Semaphore(_ContextManagerMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call. The counter
    can never go below zero; when acquire() finds that it is zero, it blocks,
    waiting until some other coroutine calls release().

    Waiters are served in FIFO order. Once a coroutine is queued in
    acquire(), a later caller cannot overtake it, even if a permit is
    released before the queued coroutine gets a chance to run.

    Semaphores also support the asynchronous context management protocol.

    The optional argument gives the initial value for the internal
    counter; it defaults to 1. If the value given is less than 0,
    ValueError is raised.
    """

    def __init__(self, value=1, *, loop=None):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # Number of permits that are free and not yet handed to a waiter.
        self._value = value
        # Futures of coroutines blocked in acquire(), oldest first. A woken
        # future stays here until its acquire() call resumes and removes it.
        self._waiters = collections.deque()
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def _wake_up_next(self):
        """Hand one free permit to the oldest waiter that is not done yet.

        Returns True if a waiter was woken, False if there was none.
        """
        for fut in self._waiters:
            if not fut.done():
                # Account for the permit now rather than when the waiter
                # resumes, so nobody can grab it in between.
                self._value -= 1
                fut.set_result(True)
                return True
        return False

    def locked(self):
        """Returns True if semaphore can not be acquired immediately."""
        # Either no permit is free, or live waiters (possibly already woken
        # but not yet resumed) are queued and must be served first.
        return self._value <= 0 or any(
            not w.cancelled() for w in self._waiters)

    async def acquire(self):
        """Acquire a semaphore.

        If a permit is free and nobody is queued, decrement the internal
        counter by one and return True immediately. Otherwise join the
        queue, wait until release() hands this coroutine a permit, and then
        return True.
        """
        if not self.locked():
            self._value -= 1
            return True

        fut = self._loop.create_future()
        self._waiters.append(fut)
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except BaseException:
            # If release() already handed us a permit (our future holds a
            # result) we will not use it: give it back.
            if fut.done() and not fut.cancelled():
                self._value += 1
            raise
        finally:
            # Permits may be free while others are still queued (FIFO made
            # them wait); wake as many of them as there are permits.
            while self._value > 0:
                if not self._wake_up_next():
                    break
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.
        When another coroutine is waiting for a permit, the permit is
        handed to the oldest such coroutine and it is woken up.
        """
        self._value += 1
        self._wake_up_next()
430
431
class BoundedSemaphore(Semaphore):
    """A bounded semaphore implementation.

    Behaves like Semaphore, except that release() raises ValueError when
    the counter is already at (or above) its initial value.
    """

    def __init__(self, value=1, *, loop=None):
        if loop:
            # Warn from here so the warning is attributed to the caller;
            # Semaphore.__init__ warns again because loop is forwarded.
            warnings.warn("The loop argument is deprecated since Python 3.8, "
                          "and scheduled for removal in Python 3.10.",
                          DeprecationWarning, stacklevel=2)

        # Ceiling that release() must never push the counter past.
        self._bound_value = value
        super().__init__(value, loop=loop)

    def release(self):
        """Release the semaphore, refusing to exceed the initial value."""
        at_limit = self._value >= self._bound_value
        if at_limit:
            raise ValueError('BoundedSemaphore released too many times')
        super().release()
451 super().release()