1414from __future__ import annotations
1515
1616import asyncio
17+ import collections
1718import os
1819import threading
1920import time
2021import weakref
21- from typing import Any , Callable , Optional
22+ from typing import Any , Callable , Optional , TypeVar
2223
2324_HAS_REGISTER_AT_FORK = hasattr (os , "register_at_fork" )
2425
2526# References to instances of _create_lock
2627_forkable_locks : weakref .WeakSet [threading .Lock ] = weakref .WeakSet ()
2728
29+ _T = TypeVar ("_T" )
30+
2831
2932def _create_lock () -> threading .Lock :
3033 """Represents a lock that is tracked upon instantiation using a WeakSet and
@@ -43,7 +46,14 @@ def _release_locks() -> None:
4346 lock .release ()
4447
4548
49+ # Needed only for synchro.py compat.
50+ def _Lock (lock : threading .Lock ) -> threading .Lock :
51+ return lock
52+
53+
4654class _ALock :
55+ __slots__ = ("_lock" ,)
56+
4757 def __init__ (self , lock : threading .Lock ) -> None :
4858 self ._lock = lock
4959
@@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
8191 self .release ()
8292
8393
94+ def _safe_set_result (fut : asyncio .Future ) -> None :
95+ # Ensure the future hasn't been cancelled before calling set_result.
96+ if not fut .done ():
97+ fut .set_result (False )
98+
99+
84100class _ACondition :
101+ __slots__ = ("_condition" , "_waiters" )
102+
85103 def __init__ (self , condition : threading .Condition ) -> None :
86104 self ._condition = condition
105+ self ._waiters : collections .deque = collections .deque ()
87106
88107 async def acquire (self , blocking : bool = True , timeout : float = - 1 ) -> bool :
89108 if timeout > 0 :
@@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
99118 await asyncio .sleep (0 )
100119
101120 async def wait (self , timeout : Optional [float ] = None ) -> bool :
102- if timeout is not None :
103- tstart = time .monotonic ()
104- while True :
105- notified = self ._condition .wait (0.001 )
106- if notified :
107- return True
108- if timeout is not None and (time .monotonic () - tstart ) > timeout :
109- return False
110-
111- async def wait_for (self , predicate : Callable , timeout : Optional [float ] = None ) -> bool :
112- if timeout is not None :
113- tstart = time .monotonic ()
114- while True :
115- notified = self ._condition .wait_for (predicate , 0.001 )
116- if notified :
117- return True
118- if timeout is not None and (time .monotonic () - tstart ) > timeout :
119- return False
121+ """Wait until notified.
122+
123+ If the calling task has not acquired the lock when this
124+ method is called, a RuntimeError is raised.
125+
126+ This method releases the underlying lock, and then blocks
127+ until it is awakened by a notify() or notify_all() call for
128+ the same condition variable in another task. Once
129+ awakened, it re-acquires the lock and returns True.
130+
131+ This method may return spuriously,
132+ which is why the caller should always
133+ re-check the state and be prepared to wait() again.
134+ """
135+ loop = asyncio .get_running_loop ()
136+ fut = loop .create_future ()
137+ self ._waiters .append ((loop , fut ))
138+ self .release ()
139+ try :
140+ try :
141+ try :
142+ await asyncio .wait_for (fut , timeout )
143+ return True
144+ except asyncio .TimeoutError :
145+ return False # Return false on timeout for sync pool compat.
146+ finally :
147+ # Must re-acquire lock even if wait is cancelled.
148+ # We only catch CancelledError here, since we don't want any
149+ # other (fatal) errors with the future to cause us to spin.
150+ err = None
151+ while True :
152+ try :
153+ await self .acquire ()
154+ break
155+ except asyncio .exceptions .CancelledError as e :
156+ err = e
157+
158+ self ._waiters .remove ((loop , fut ))
159+ if err is not None :
160+ try :
161+ raise err # Re-raise most recent exception instance.
162+ finally :
163+ err = None # Break reference cycles.
164+ except BaseException :
165+ # Any error raised out of here _may_ have occurred after this Task
166+ # believed to have been successfully notified.
167+ # Make sure to notify another Task instead. This may result
168+ # in a "spurious wakeup", which is allowed as part of the
169+ # Condition Variable protocol.
170+ self .notify (1 )
171+ raise
172+
173+ async def wait_for (self , predicate : Callable [[], _T ]) -> _T :
174+ """Wait until a predicate becomes true.
175+
176+ The predicate should be a callable whose result will be
177+ interpreted as a boolean value. The method will repeatedly
178+ wait() until it evaluates to true. The final predicate value is
179+ the return value.
180+ """
181+ result = predicate ()
182+ while not result :
183+ await self .wait ()
184+ result = predicate ()
185+ return result
120186
121187 def notify (self , n : int = 1 ) -> None :
122- self ._condition .notify (n )
188+ """By default, wake up one coroutine waiting on this condition, if any.
189+ If the calling coroutine has not acquired the lock when this method
190+ is called, a RuntimeError is raised.
191+
192+ This method wakes up at most n of the coroutines waiting for the
193+ condition variable; it is a no-op if no coroutines are waiting.
194+
195+ Note: an awakened coroutine does not actually return from its
196+ wait() call until it can reacquire the lock. Since notify() does
197+ not release the lock, its caller should.
198+ """
199+ idx = 0
200+ to_remove = []
201+ for loop , fut in self ._waiters :
202+ if idx >= n :
203+ break
204+
205+ if fut .done ():
206+ continue
207+
208+ try :
209+ loop .call_soon_threadsafe (_safe_set_result , fut )
210+ except RuntimeError :
211+ # Loop was closed, ignore.
212+ to_remove .append ((loop , fut ))
213+ continue
214+
215+ idx += 1
216+
217+ for waiter in to_remove :
218+ self ._waiters .remove (waiter )
123219
124220 def notify_all (self ) -> None :
125- self ._condition .notify_all ()
221+ """Wake up all threads waiting on this condition. This method acts
222+ like notify(), but wakes up all waiting threads instead of one. If the
223+ calling thread has not acquired the lock when this method is called,
224+ a RuntimeError is raised.
225+ """
226+ self .notify (len (self ._waiters ))
227+
228+ def locked (self ) -> bool :
229+ """Only needed for tests in test_locks."""
230+ return self ._condition ._lock .locked () # type: ignore[attr-defined]
126231
127232 def release (self ) -> None :
128233 self ._condition .release ()
0 commit comments