|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import sys |
4 | | -from collections.abc import Iterator |
| 4 | +from collections.abc import Generator, Iterator |
5 | 5 | from contextlib import ExitStack, contextmanager |
6 | | -from inspect import isasyncgenfunction, iscoroutinefunction |
| 6 | +from inspect import isasyncgenfunction, iscoroutinefunction, ismethod |
7 | 7 | from typing import Any, cast |
8 | 8 |
|
9 | 9 | import pytest |
10 | 10 | import sniffio |
| 11 | +from _pytest.fixtures import SubRequest |
11 | 12 | from _pytest.outcomes import Exit |
12 | 13 |
|
13 | 14 | from ._core._eventloop import get_all_backends, get_async_backend |
@@ -70,28 +71,56 @@ def pytest_configure(config: Any) -> None: |
70 | 71 | ) |
71 | 72 |
|
72 | 73 |
|
73 | | -def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: |
74 | | - def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] |
| 74 | +@pytest.hookimpl(hookwrapper=True) |
| 75 | +def pytest_fixture_setup(fixturedef: Any, request: Any) -> Generator[Any]: |
| 76 | + def wrapper( |
| 77 | + *args: Any, anyio_backend: Any, request: SubRequest, **kwargs: Any |
| 78 | + ) -> Any: |
| 79 | + # Rebind any fixture methods to the request instance |
| 80 | + if ( |
| 81 | + request.instance |
| 82 | + and ismethod(func) |
| 83 | + and type(func.__self__) is type(request.instance) |
| 84 | + ): |
| 85 | + local_func = func.__func__.__get__(request.instance) |
| 86 | + else: |
| 87 | + local_func = func |
| 88 | + |
75 | 89 | backend_name, backend_options = extract_backend_and_options(anyio_backend) |
76 | 90 | if has_backend_arg: |
77 | 91 | kwargs["anyio_backend"] = anyio_backend |
78 | 92 |
|
| 93 | + if has_request_arg: |
| 94 | + kwargs["request"] = anyio_backend |
| 95 | + |
79 | 96 | with get_runner(backend_name, backend_options) as runner: |
80 | | - if isasyncgenfunction(func): |
81 | | - yield from runner.run_asyncgen_fixture(func, kwargs) |
| 97 | + if isasyncgenfunction(local_func): |
| 98 | + yield from runner.run_asyncgen_fixture(local_func, kwargs) |
82 | 99 | else: |
83 | | - yield runner.run_fixture(func, kwargs) |
| 100 | + yield runner.run_fixture(local_func, kwargs) |
84 | 101 |
|
85 | 102 | # Only apply this to coroutine functions and async generator functions in requests |
86 | 103 | # that involve the anyio_backend fixture |
87 | 104 | func = fixturedef.func |
88 | 105 | if isasyncgenfunction(func) or iscoroutinefunction(func): |
89 | 106 | if "anyio_backend" in request.fixturenames: |
90 | | - has_backend_arg = "anyio_backend" in fixturedef.argnames |
91 | 107 | fixturedef.func = wrapper |
92 | | - if not has_backend_arg: |
| 108 | + original_argname = fixturedef.argnames |
| 109 | + |
| 110 | + if not (has_backend_arg := "anyio_backend" in fixturedef.argnames): |
93 | 111 | fixturedef.argnames += ("anyio_backend",) |
94 | 112 |
|
| 113 | + if not (has_request_arg := "request" in fixturedef.argnames): |
| 114 | + fixturedef.argnames += ("request",) |
| 115 | + |
| 116 | + try: |
| 117 | + return (yield) |
| 118 | + finally: |
| 119 | + fixturedef.func = func |
| 120 | + fixturedef.argnames = original_argname |
| 121 | + |
| 122 | + return (yield) |
| 123 | + |
95 | 124 |
|
96 | 125 | @pytest.hookimpl(tryfirst=True) |
97 | 126 | def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: |
|
0 commit comments