1010import inspect
1111import socket
1212import sys
13- import threading
1413import traceback
1514import warnings
1615from asyncio import AbstractEventLoop , AbstractEventLoopPolicy
5049 PytestPluginManager ,
5150)
5251
52+ _seen_markers : set [int ] = set ()
53+
54+
55+ def _warn_scope_deprecation_once (marker_id : int ) -> None :
56+ """Issues deprecation warning exactly once per marker ID."""
57+ if marker_id not in _seen_markers :
58+ _seen_markers .add (marker_id )
59+ warnings .warn (PytestDeprecationWarning (_MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
60+
61+
5362if sys .version_info >= (3 , 10 ):
5463 from typing import ParamSpec
5564else :
6372_ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
6473_R = TypeVar ("_R" , bound = Union [Awaitable [Any ], AsyncIterator [Any ]])
6574_P = ParamSpec ("_P" )
75+ T = TypeVar ("T" )
6676FixtureFunction = Callable [_P , _R ]
77+ CoroutineFunction = Callable [_P , Awaitable [T ]]
6778
6879
6980class PytestAsyncioError (Exception ):
@@ -292,7 +303,7 @@ def _asyncgen_fixture_wrapper(
292303 gen_obj = fixture_function (* args , ** kwargs )
293304
294305 async def setup ():
295- res = await gen_obj .__anext__ () # type: ignore[union-attr]
306+ res = await gen_obj .__anext__ ()
296307 return res
297308
298309 context = contextvars .copy_context ()
@@ -305,7 +316,7 @@ def finalizer() -> None:
305316
306317 async def async_finalizer () -> None :
307318 try :
308- await gen_obj .__anext__ () # type: ignore[union-attr]
319+ await gen_obj .__anext__ ()
309320 except StopAsyncIteration :
310321 pass
311322 else :
@@ -334,8 +345,7 @@ def _wrap_async_fixture(
334345 runner : Runner ,
335346 request : FixtureRequest ,
336347) -> Callable [AsyncFixtureParams , AsyncFixtureReturnType ]:
337-
338- @functools .wraps (fixture_function ) # type: ignore[arg-type]
348+ @functools .wraps (fixture_function )
339349 def _async_fixture_wrapper (
340350 * args : AsyncFixtureParams .args ,
341351 ** kwargs : AsyncFixtureParams .kwargs ,
@@ -448,7 +458,7 @@ def _can_substitute(item: Function) -> bool:
448458 return inspect .iscoroutinefunction (func )
449459
450460 def runtest (self ) -> None :
451- synchronized_obj = wrap_in_sync ( self .obj )
461+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
452462 with MonkeyPatch .context () as c :
453463 c .setattr (self , "obj" , synchronized_obj )
454464 super ().runtest ()
@@ -490,7 +500,7 @@ def _can_substitute(item: Function) -> bool:
490500 )
491501
492502 def runtest (self ) -> None :
493- synchronized_obj = wrap_in_sync ( self .obj )
503+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
494504 with MonkeyPatch .context () as c :
495505 c .setattr (self , "obj" , synchronized_obj )
496506 super ().runtest ()
@@ -512,7 +522,10 @@ def _can_substitute(item: Function) -> bool:
512522 )
513523
514524 def runtest (self ) -> None :
515- synchronized_obj = wrap_in_sync (self .obj .hypothesis .inner_test )
525+ synchronized_obj = get_async_test_wrapper (
526+ self ,
527+ self .obj .hypothesis .inner_test ,
528+ )
516529 with MonkeyPatch .context () as c :
517530 c .setattr (self .obj .hypothesis , "inner_test" , synchronized_obj )
518531 super ().runtest ()
@@ -603,10 +616,68 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603616 asyncio .set_event_loop (loop )
604617
605618
606- def _reinstate_event_loop_on_main_thread () -> None :
607- if threading .current_thread () is threading .main_thread ():
608- policy = _get_event_loop_policy ()
609- policy .set_event_loop (policy .new_event_loop ())
619+ _session_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
620+ contextvars .ContextVar (
621+ "_session_loop" ,
622+ default = None ,
623+ )
624+ )
625+ _package_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
626+ contextvars .ContextVar (
627+ "_package_loop" ,
628+ default = None ,
629+ )
630+ )
631+ _module_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
632+ contextvars .ContextVar (
633+ "_module_loop" ,
634+ default = None ,
635+ )
636+ )
637+ _class_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
638+ contextvars .ContextVar (
639+ "_class_loop" ,
640+ default = None ,
641+ )
642+ )
643+ _function_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
644+ contextvars .ContextVar (
645+ "_function_loop" ,
646+ default = None ,
647+ )
648+ )
649+
650+ _SCOPE_TO_CONTEXTVAR = {
651+ "session" : _session_loop ,
652+ "package" : _package_loop ,
653+ "module" : _module_loop ,
654+ "class" : _class_loop ,
655+ "function" : _function_loop ,
656+ }
657+
658+
659+ def _get_or_restore_event_loop (loop_scope : _ScopeName ) -> asyncio .AbstractEventLoop :
660+ """
661+ Get or restore the appropriate event loop for the given scope.
662+
663+ If we have a shared loop for this scope, restore and return it.
664+ Otherwise, get the current event loop or create a new one.
665+ """
666+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
667+ if shared_loop is not None :
668+ _reinstate_event_loop_on_main_thread (loop_scope )
669+ return shared_loop
670+ else :
671+ return _get_event_loop_no_warn ()
672+
673+
674+ def _reinstate_event_loop_on_main_thread (loop_scope : _ScopeName ) -> None :
675+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
676+ if shared_loop is None :
677+ return
678+
679+ policy = _get_event_loop_policy ()
680+ policy .set_event_loop (shared_loop )
610681
611682
612683@pytest .hookimpl (tryfirst = True , hookwrapper = True )
@@ -659,9 +730,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659730 return None
660731
661732
662- def wrap_in_sync (
663- func : Callable [..., Awaitable [Any ]],
664- ):
733+ def get_async_test_wrapper (
734+ item : Function ,
735+ func : CoroutineFunction [_P , T ],
736+ ) -> Callable [_P , T ]:
737+ """Returns a synchronous wrapper for the specified async test function."""
738+ marker = item .get_closest_marker ("asyncio" )
739+ assert marker is not None
740+ default_loop_scope = _get_default_test_loop_scope (item .config )
741+ loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
742+ return _wrap_in_sync (func , loop_scope )
743+
744+
745+ def _wrap_in_sync (
746+ func : CoroutineFunction [_P , T ],
747+ loop_scope : _ScopeName ,
748+ ) -> Callable [_P , T ]:
665749 """
666750 Return a sync wrapper around an async function executing it in the
667751 current event loop.
@@ -670,12 +754,7 @@ def wrap_in_sync(
670754 @functools .wraps (func )
671755 def inner (* args , ** kwargs ):
672756 coro = func (* args , ** kwargs )
673- try :
674- _loop = _get_event_loop_no_warn ()
675- except RuntimeError :
676- # Handle situation where asyncio.set_event_loop(None) removes shared loops.
677- _reinstate_event_loop_on_main_thread ()
678- _loop = _get_event_loop_no_warn ()
757+ _loop = _get_or_restore_event_loop (loop_scope )
679758 task = asyncio .ensure_future (coro , loop = _loop )
680759 try :
681760 _loop .run_until_complete (task )
@@ -758,7 +837,7 @@ def _get_marked_loop_scope(
758837 if "scope" in asyncio_marker .kwargs :
759838 if "loop_scope" in asyncio_marker .kwargs :
760839 raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
761- warnings . warn ( PytestDeprecationWarning ( _MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
840+ _warn_scope_deprecation_once ( id ( asyncio_marker ))
762841 scope = asyncio_marker .kwargs .get ("loop_scope" ) or asyncio_marker .kwargs .get (
763842 "scope"
764843 )
@@ -796,6 +875,8 @@ def _scoped_runner(
796875 debug_mode = _get_asyncio_debug (request .config )
797876 with _temporary_event_loop_policy (new_loop_policy ):
798877 runner = Runner (debug = debug_mode ).__enter__ ()
878+ shared_loop = runner .get_loop ()
879+ _SCOPE_TO_CONTEXTVAR [scope ].set (shared_loop )
799880 try :
800881 yield runner
801882 except Exception as e :
0 commit comments