1010import inspect
1111import socket
1212import sys
13- import threading
1413import traceback
1514import warnings
1615from asyncio import AbstractEventLoop , AbstractEventLoopPolicy
5049 PytestPluginManager ,
5150)
5251
52+ _seen_markers : set [int ] = set ()
53+
54+
55+ def _warn_scope_deprecation_once (marker_id : int ) -> None :
56+ """Issues deprecation warning exactly once per marker ID."""
57+ if marker_id not in _seen_markers :
58+ _seen_markers .add (marker_id )
59+ warnings .warn (PytestDeprecationWarning (_MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
60+
61+
5362if sys .version_info >= (3 , 10 ):
5463 from typing import ParamSpec
5564else :
6372_ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
6473_R = TypeVar ("_R" , bound = Union [Awaitable [Any ], AsyncIterator [Any ]])
6574_P = ParamSpec ("_P" )
75+ T = TypeVar ("T" )
6676FixtureFunction = Callable [_P , _R ]
77+ CoroutineFunction = Callable [_P , Awaitable [T ]]
6778
6879
6980class PytestAsyncioError (Exception ):
@@ -292,7 +303,7 @@ def _asyncgen_fixture_wrapper(
292303 gen_obj = fixture_function (* args , ** kwargs )
293304
294305 async def setup ():
295- res = await gen_obj .__anext__ () # type: ignore[union-attr]
306+ res = await gen_obj .__anext__ ()
296307 return res
297308
298309 context = contextvars .copy_context ()
@@ -305,7 +316,7 @@ def finalizer() -> None:
305316
306317 async def async_finalizer () -> None :
307318 try :
308- await gen_obj .__anext__ () # type: ignore[union-attr]
319+ await gen_obj .__anext__ ()
309320 except StopAsyncIteration :
310321 pass
311322 else :
@@ -334,8 +345,7 @@ def _wrap_async_fixture(
334345 runner : Runner ,
335346 request : FixtureRequest ,
336347) -> Callable [AsyncFixtureParams , AsyncFixtureReturnType ]:
337-
338- @functools .wraps (fixture_function ) # type: ignore[arg-type]
348+ @functools .wraps (fixture_function )
339349 def _async_fixture_wrapper (
340350 * args : AsyncFixtureParams .args ,
341351 ** kwargs : AsyncFixtureParams .kwargs ,
@@ -448,7 +458,7 @@ def _can_substitute(item: Function) -> bool:
448458 return inspect .iscoroutinefunction (func )
449459
450460 def runtest (self ) -> None :
451- synchronized_obj = wrap_in_sync ( self .obj )
461+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
452462 with MonkeyPatch .context () as c :
453463 c .setattr (self , "obj" , synchronized_obj )
454464 super ().runtest ()
@@ -490,7 +500,7 @@ def _can_substitute(item: Function) -> bool:
490500 )
491501
492502 def runtest (self ) -> None :
493- synchronized_obj = wrap_in_sync ( self .obj )
503+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
494504 with MonkeyPatch .context () as c :
495505 c .setattr (self , "obj" , synchronized_obj )
496506 super ().runtest ()
@@ -512,7 +522,10 @@ def _can_substitute(item: Function) -> bool:
512522 )
513523
514524 def runtest (self ) -> None :
515- synchronized_obj = wrap_in_sync (self .obj .hypothesis .inner_test )
525+ synchronized_obj = get_async_test_wrapper (
526+ self ,
527+ self .obj .hypothesis .inner_test ,
528+ )
516529 with MonkeyPatch .context () as c :
517530 c .setattr (self .obj .hypothesis , "inner_test" , synchronized_obj )
518531 super ().runtest ()
@@ -603,10 +616,60 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603616 asyncio .set_event_loop (loop )
604617
605618
606- def _reinstate_event_loop_on_main_thread () -> None :
607- if threading .current_thread () is threading .main_thread ():
619+ _session_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
620+ contextvars .ContextVar (
621+ "_session_loop" ,
622+ default = None ,
623+ )
624+ )
625+ _package_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
626+ contextvars .ContextVar (
627+ "_package_loop" ,
628+ default = None ,
629+ )
630+ )
631+ _module_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
632+ contextvars .ContextVar (
633+ "_module_loop" ,
634+ default = None ,
635+ )
636+ )
637+ _class_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
638+ contextvars .ContextVar (
639+ "_class_loop" ,
640+ default = None ,
641+ )
642+ )
643+ _function_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
644+ contextvars .ContextVar (
645+ "_function_loop" ,
646+ default = None ,
647+ )
648+ )
649+
650+ _SCOPE_TO_CONTEXTVAR = {
651+ "session" : _session_loop ,
652+ "package" : _package_loop ,
653+ "module" : _module_loop ,
654+ "class" : _class_loop ,
655+ "function" : _function_loop ,
656+ }
657+
658+
659+ def _get_or_restore_event_loop (loop_scope : _ScopeName ) -> asyncio .AbstractEventLoop :
660+ """
661+ Get or restore the appropriate event loop for the given scope.
662+
663+ If we have a shared loop for this scope, restore and return it.
664+ Otherwise, get the current event loop or create a new one.
665+ """
666+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
667+ if shared_loop is not None :
608668 policy = _get_event_loop_policy ()
609- policy .set_event_loop (policy .new_event_loop ())
669+ policy .set_event_loop (shared_loop )
670+ return shared_loop
671+ else :
672+ return _get_event_loop_no_warn ()
610673
611674
612675@pytest .hookimpl (tryfirst = True , hookwrapper = True )
@@ -659,9 +722,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659722 return None
660723
661724
662- def wrap_in_sync (
663- func : Callable [..., Awaitable [Any ]],
664- ):
725+ def get_async_test_wrapper (
726+ item : Function ,
727+ func : CoroutineFunction [_P , T ],
728+ ) -> Callable [_P , None ]:
729+ """Returns a synchronous wrapper for the specified async test function."""
730+ marker = item .get_closest_marker ("asyncio" )
731+ assert marker is not None
732+ default_loop_scope = _get_default_test_loop_scope (item .config )
733+ loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
734+ return _wrap_in_sync (func , loop_scope )
735+
736+
737+ def _wrap_in_sync (
738+ func : CoroutineFunction [_P , T ],
739+ loop_scope : _ScopeName ,
740+ ) -> Callable [_P , None ]:
665741 """
666742 Return a sync wrapper around an async function executing it in the
667743 current event loop.
@@ -670,12 +746,7 @@ def wrap_in_sync(
670746 @functools .wraps (func )
671747 def inner (* args , ** kwargs ):
672748 coro = func (* args , ** kwargs )
673- try :
674- _loop = _get_event_loop_no_warn ()
675- except RuntimeError :
676- # Handle situation where asyncio.set_event_loop(None) removes shared loops.
677- _reinstate_event_loop_on_main_thread ()
678- _loop = _get_event_loop_no_warn ()
749+ _loop = _get_or_restore_event_loop (loop_scope )
679750 task = asyncio .ensure_future (coro , loop = _loop )
680751 try :
681752 _loop .run_until_complete (task )
@@ -758,7 +829,7 @@ def _get_marked_loop_scope(
758829 if "scope" in asyncio_marker .kwargs :
759830 if "loop_scope" in asyncio_marker .kwargs :
760831 raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
761- warnings . warn ( PytestDeprecationWarning ( _MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
832+ _warn_scope_deprecation_once ( id ( asyncio_marker ))
762833 scope = asyncio_marker .kwargs .get ("loop_scope" ) or asyncio_marker .kwargs .get (
763834 "scope"
764835 )
@@ -768,7 +839,7 @@ def _get_marked_loop_scope(
768839 return scope
769840
770841
771- def _get_default_test_loop_scope (config : Config ) -> _ScopeName :
842+ def _get_default_test_loop_scope (config : Config ) -> Any :
772843 return config .getini ("asyncio_default_test_loop_scope" )
773844
774845
@@ -796,6 +867,8 @@ def _scoped_runner(
796867 debug_mode = _get_asyncio_debug (request .config )
797868 with _temporary_event_loop_policy (new_loop_policy ):
798869 runner = Runner (debug = debug_mode ).__enter__ ()
870+ shared_loop = runner .get_loop ()
871+ _SCOPE_TO_CONTEXTVAR [scope ].set (shared_loop )
799872 try :
800873 yield runner
801874 except Exception as e :
0 commit comments