From 78afc2dcc5ffd62032b80d6a9da0cd148aa8559c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 14 May 2024 18:22:05 -0400 Subject: [PATCH 1/2] Use pytestmark --- tests/middleware/test_logging.py | 36 ++++++++--------- tests/protocols/test_http.py | 51 +----------------------- tests/protocols/test_websocket.py | 43 +-------------------- tests/test_auto_detection.py | 4 -- tests/test_default_headers.py | 64 ++++++++----------------------- tests/test_main.py | 11 +++--- tests/utils.py | 6 ++- 7 files changed, 48 insertions(+), 167 deletions(-) diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index 59bef1d37..7e9edd371 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -13,6 +13,7 @@ from tests.utils import run_server from uvicorn import Config +from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope if typing.TYPE_CHECKING: import sys @@ -27,9 +28,11 @@ WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]" +pytestmark = pytest.mark.anyio + @contextlib.contextmanager -def caplog_for_logger(caplog, logger_name): +def caplog_for_logger(caplog: pytest.LogCaptureFixture, logger_name: str) -> typing.Iterator[pytest.LogCaptureFixture]: logger = logging.getLogger(logger_name) logger.propagate, old_propagate = False, logger.propagate logger.addHandler(caplog.handler) @@ -40,14 +43,13 @@ def caplog_for_logger(caplog, logger_name): logger.propagate = old_propagate -async def app(scope, receive, send): +async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" await send({"type": "http.response.start", "status": 204, "headers": []}) await send({"type": "http.response.body", "body": b"", "more_body": False}) -@pytest.mark.anyio -async def test_trace_logging(caplog, logging_config, unused_tcp_port: int): +async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int): config = Config( app=app, log_level="trace", @@ -69,7 +71,6 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int): assert "ASGI [2] Completed" in messages.pop(0) -@pytest.mark.anyio async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging_config, unused_tcp_port: int): config = Config( app=app, @@ -88,14 +89,13 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging assert any(" - HTTP connection lost" in message for message in messages) -@pytest.mark.anyio async def test_trace_logging_on_ws_protocol( ws_protocol_cls: WSProtocol, caplog, logging_config, unused_tcp_port: int, ): - async def websocket_app(scope, receive, send): + async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" while True: message = await receive() @@ -125,9 +125,8 @@ async def open_connection(url): assert any(" - WebSocket connection lost" in message for message in messages) -@pytest.mark.anyio @pytest.mark.parametrize("use_colors", [(True), (False), (None)]) -async def test_access_logging(use_colors, caplog, logging_config, unused_tcp_port: int): +async def test_access_logging(use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int): config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port) with caplog_for_logger(caplog, "uvicorn.access"): async with run_server(config): @@ -139,9 +138,10 @@ async def test_access_logging(use_colors, caplog, logging_config, unused_tcp_por assert '"GET / HTTP/1.1" 204' in messages.pop() -@pytest.mark.anyio @pytest.mark.parametrize("use_colors", [(True), (False)]) -async def test_default_logging(use_colors, caplog, logging_config, unused_tcp_port: int): +async def test_default_logging( + use_colors: bool, caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int +): config = Config(app=app, use_colors=use_colors, log_config=logging_config, port=unused_tcp_port) with caplog_for_logger(caplog, "uvicorn.access"): async with run_server(config): @@ -158,9 +158,10 @@ async def test_default_logging(use_colors, caplog, logging_config, unused_tcp_po assert "Shutting down" in messages.pop(0) -@pytest.mark.anyio @pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system") -async def test_running_log_using_uds(caplog, short_socket_name, unused_tcp_port: int): # pragma: py-win32 +async def test_running_log_using_uds( + caplog: pytest.LogCaptureFixture, short_socket_name: str, unused_tcp_port: int +): # pragma: py-win32 config = Config(app=app, uds=short_socket_name, port=unused_tcp_port) with caplog_for_logger(caplog, "uvicorn.access"): async with run_server(config): @@ -170,9 +171,8 @@ async def test_running_log_using_uds(caplog, short_socket_name, unused_tcp_port: assert f"Uvicorn running on unix socket {short_socket_name} (Press CTRL+C to quit)" in messages -@pytest.mark.anyio @pytest.mark.skipif(sys.platform == "win32", reason="require unix-like system") -async def test_running_log_using_fd(caplog, unused_tcp_port: int): # pragma: py-win32 +async def test_running_log_using_fd(caplog: pytest.LogCaptureFixture, unused_tcp_port: int): # pragma: py-win32 with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: fd = sock.fileno() config = Config(app=app, fd=fd, port=unused_tcp_port) @@ -184,9 +184,8 @@ async def test_running_log_using_fd(caplog, unused_tcp_port: int): # pragma: py assert f"Uvicorn running on socket {sockname} (Press CTRL+C to quit)" in messages -@pytest.mark.anyio -async def test_unknown_status_code(caplog, unused_tcp_port: int): - async def app(scope, receive, send): +async def test_unknown_status_code(caplog: pytest.LogCaptureFixture, unused_tcp_port: int): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" await send({"type": "http.response.start", "status": 599, "headers": []}) await send({"type": "http.response.body", "body": b"", "more_body": False}) @@ -202,7 +201,6 @@ async def app(scope, receive, send): assert '"GET / HTTP/1.1" 599' in messages.pop() -@pytest.mark.anyio async def test_server_start_with_port_zero(caplog: pytest.LogCaptureFixture): config = Config(app=app, port=0) async with run_server(config) as server: diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 677e72f9f..7bb110e00 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -39,6 +39,8 @@ HTTPProtocol: TypeAlias = "type[HttpToolsProtocol | H11Protocol]" WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]" +pytestmark = pytest.mark.anyio + WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys() @@ -239,7 +241,6 @@ def get_connected_protocol( return protocol -@pytest.mark.anyio async def test_get_request(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -250,7 +251,6 @@ async def test_get_request(http_protocol_cls: HTTPProtocol): assert b"Hello, world" in protocol.transport.buffer -@pytest.mark.anyio @pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"]) async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture): get_request_with_query_string = b"\r\n".join( @@ -267,7 +267,6 @@ async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplo assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message -@pytest.mark.anyio async def test_head_request(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -278,7 +277,6 @@ async def test_head_request(http_protocol_cls: HTTPProtocol): assert b"Hello, world" not in protocol.transport.buffer -@pytest.mark.anyio async def test_post_request(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" @@ -298,7 +296,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b'Body: {"hello": "world"}' in protocol.transport.buffer -@pytest.mark.anyio async def test_keepalive(http_protocol_cls: HTTPProtocol): app = Response(b"", status_code=204) @@ -310,7 +307,6 @@ async def test_keepalive(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -@pytest.mark.anyio async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol): app = Response(b"", status_code=204) @@ -325,7 +321,6 @@ async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_keepalive_timeout_with_pipelined_requests( http_protocol_cls: HTTPProtocol, ): @@ -351,7 +346,6 @@ async def test_keepalive_timeout_with_pipelined_requests( assert protocol.timeout_keep_alive_task is not None -@pytest.mark.anyio async def test_close(http_protocol_cls: HTTPProtocol): app = Response(b"", status_code=204, headers={"connection": "close"}) @@ -362,7 +356,6 @@ async def test_close(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) @@ -374,7 +367,6 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -@pytest.mark.anyio async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) @@ -386,7 +378,6 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -@pytest.mark.anyio async def test_chunked_encoding_head_request( http_protocol_cls: HTTPProtocol, ): @@ -399,7 +390,6 @@ async def test_chunked_encoding_head_request( assert not protocol.transport.is_closing() -@pytest.mark.anyio async def test_pipelined_requests(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -421,7 +411,6 @@ async def test_pipelined_requests(http_protocol_cls: HTTPProtocol): protocol.transport.clear_buffer() -@pytest.mark.anyio async def test_undersized_request(http_protocol_cls: HTTPProtocol): app = Response(b"xxx", headers={"content-length": "10"}) @@ -431,7 +420,6 @@ async def test_undersized_request(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_oversized_request(http_protocol_cls: HTTPProtocol): app = Response(b"xxx" * 20, headers={"content-length": "10"}) @@ -441,7 +429,6 @@ async def test_oversized_request(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_large_post_request(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -452,7 +439,6 @@ async def test_large_post_request(http_protocol_cls: HTTPProtocol): assert not protocol.transport.read_paused -@pytest.mark.anyio async def test_invalid_http(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -461,7 +447,6 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_app_exception(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): raise Exception() @@ -473,7 +458,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_exception_during_response(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -487,7 +471,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_no_response_returned(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): ... @@ -498,7 +481,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_partial_response_returned(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -510,7 +492,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -523,7 +504,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_missing_start_message(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.body", "body": b""}) @@ -535,7 +515,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -549,7 +528,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_value_returned(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -563,7 +541,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_early_disconnect(http_protocol_cls: HTTPProtocol): got_disconnect_event = False @@ -585,7 +562,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert got_disconnect_event -@pytest.mark.anyio async def test_early_response(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -597,7 +573,6 @@ async def test_early_response(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -@pytest.mark.anyio async def test_read_after_response(http_protocol_cls: HTTPProtocol): message_after_response = None @@ -615,7 +590,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert message_after_response == {"type": "http.disconnect"} -@pytest.mark.anyio async def test_http10_request(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" @@ -630,7 +604,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"Version: 1.0" in protocol.transport.buffer -@pytest.mark.anyio async def test_root_path(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" @@ -646,7 +619,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"root_path=/app path=/app/" in protocol.transport.buffer -@pytest.mark.anyio async def test_raw_path(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" @@ -664,7 +636,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"Done" in protocol.transport.buffer -@pytest.mark.anyio async def test_max_concurrency(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -686,7 +657,6 @@ async def test_max_concurrency(http_protocol_cls: HTTPProtocol): ) -@pytest.mark.anyio async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol): app = Response(b"", status_code=204) @@ -698,7 +668,6 @@ async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -708,7 +677,6 @@ async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -@pytest.mark.anyio async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" @@ -740,7 +708,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b'Body: {"hello": "world"}' in protocol.transport.buffer -@pytest.mark.anyio async def test_100_continue_not_sent_when_body_not_consumed( http_protocol_cls: HTTPProtocol, ): @@ -764,7 +731,6 @@ async def test_100_continue_not_sent_when_body_not_consumed( assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer -@pytest.mark.anyio async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol): pytest.importorskip("wsproto") @@ -775,7 +741,6 @@ async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol): assert b"HTTP/1.1 426 " in protocol.transport.buffer -@pytest.mark.anyio async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -786,7 +751,6 @@ async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol): assert b"Hello, world" in protocol.transport.buffer -@pytest.mark.anyio async def test_unsupported_ws_upgrade_request_warn_on_auto( caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol ): @@ -804,7 +768,6 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto( assert msg in warnings -@pytest.mark.anyio async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol): app = Response("Hello, world", media_type="text/plain") @@ -826,7 +789,6 @@ async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable): return asgi -@pytest.mark.anyio @pytest.mark.parametrize( "asgi2or3_app, expected_scopes", [ @@ -845,7 +807,6 @@ async def test_scopes( assert expected_scopes == protocol.scope.get("asgi") -@pytest.mark.anyio @pytest.mark.parametrize( "request_line", [ @@ -915,7 +876,6 @@ def send_fragmented_req(path: str): t.join() -@pytest.mark.anyio async def test_huge_headers_h11protocol_failure(): app = Response("Hello, world", media_type="text/plain") @@ -928,7 +888,6 @@ async def test_huge_headers_h11protocol_failure(): assert b"Invalid HTTP request received." in protocol.transport.buffer -@pytest.mark.anyio @skip_if_no_httptools async def test_huge_headers_httptools_will_pass(): app = Response("Hello, world", media_type="text/plain") @@ -943,7 +902,6 @@ async def test_huge_headers_httptools_will_pass(): assert b"Hello, world" in protocol.transport.buffer -@pytest.mark.anyio async def test_huge_headers_h11protocol_failure_with_setting(): app = Response("Hello, world", media_type="text/plain") @@ -956,7 +914,6 @@ async def test_huge_headers_h11protocol_failure_with_setting(): assert b"Invalid HTTP request received." in protocol.transport.buffer -@pytest.mark.anyio @skip_if_no_httptools async def test_huge_headers_httptools(): app = Response("Hello, world", media_type="text/plain") @@ -971,7 +928,6 @@ async def test_huge_headers_httptools(): assert b"Hello, world" in protocol.transport.buffer -@pytest.mark.anyio async def test_huge_headers_h11_max_incomplete(): app = Response("Hello, world", media_type="text/plain") @@ -983,7 +939,6 @@ async def test_huge_headers_h11_max_incomplete(): assert b"Hello, world" in protocol.transport.buffer -@pytest.mark.anyio async def test_return_close_header(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") @@ -998,7 +953,6 @@ async def test_return_close_header(http_protocol_cls: HTTPProtocol): assert b"connection: close" in protocol.transport.buffer.lower() -@pytest.mark.anyio async def test_iterator_headers(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): headers = iter([(b"x-test-header", b"test value")]) @@ -1011,7 +965,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"x-test-header: test value" in protocol.transport.buffer -@pytest.mark.anyio async def test_lifespan_state(http_protocol_cls: HTTPProtocol): expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}] diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 2d1e667de..471c9b435 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -48,6 +48,8 @@ HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]" WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]" +pytestmark = pytest.mark.anyio + class WebSocketResponse: def __init__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -84,7 +86,6 @@ async def wsresponse(url): return await client.get(url, headers=headers) -@pytest.mark.anyio async def test_invalid_upgrade(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): def app(scope: Scope): return None @@ -117,7 +118,6 @@ def app(scope: Scope): ) -@pytest.mark.anyio async def test_accept_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -139,7 +139,6 @@ async def open_connection(url): assert is_open -@pytest.mark.anyio async def test_shutdown(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -158,7 +157,6 @@ async def websocket_connect(self, message): await server.shutdown() -@pytest.mark.anyio async def test_supports_permessage_deflate_extension( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -183,7 +181,6 @@ async def open_connection(url): assert "permessage-deflate" in extension_names -@pytest.mark.anyio async def test_can_disable_permessage_deflate_extension( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -211,7 +208,6 @@ async def open_connection(url: str): assert "permessage-deflate" not in extension_names -@pytest.mark.anyio async def test_close_connection(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -236,7 +232,6 @@ async def open_connection(url: str): assert not is_open -@pytest.mark.anyio async def test_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -262,7 +257,6 @@ async def open_connection(url: str): assert is_open -@pytest.mark.anyio async def test_extra_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -284,7 +278,6 @@ async def open_connection(url: str): assert extra_headers.get("extra") == "header" -@pytest.mark.anyio async def test_path_and_raw_path(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -310,7 +303,6 @@ async def open_connection(url: str): assert is_open -@pytest.mark.anyio async def test_send_text_data_to_client( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -335,7 +327,6 @@ async def get_data(url: str): assert data == "123" -@pytest.mark.anyio async def test_send_binary_data_to_client( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -360,7 +351,6 @@ async def get_data(url: str): assert data == b"123" -@pytest.mark.anyio async def test_send_and_close_connection( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -393,7 +383,6 @@ async def get_data(url: str): assert not is_open -@pytest.mark.anyio async def test_send_text_data_to_server( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -422,7 +411,6 @@ async def send_text(url: str): assert data == "abc" -@pytest.mark.anyio async def test_send_binary_data_to_server( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -451,7 +439,6 @@ async def send_text(url: str): assert data == b"abc" -@pytest.mark.anyio async def test_send_after_protocol_close( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -486,7 +473,6 @@ async def get_data(url: str): assert not is_open -@pytest.mark.anyio async def test_missing_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): pass @@ -507,7 +493,6 @@ async def connect(url: str): assert exc_info.value.status_code == 500 -@pytest.mark.anyio async def test_send_before_handshake( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -530,7 +515,6 @@ async def connect(url: str): assert exc_info.value.status_code == 500 -@pytest.mark.anyio async def test_duplicate_handshake(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "websocket.accept"}) @@ -553,7 +537,6 @@ async def connect(url: str): assert exc_info.value.code == 1006 -@pytest.mark.anyio async def test_asgi_return_value(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): """ The ASGI callable should return 'None'. If it doesn't, make sure that @@ -581,7 +564,6 @@ async def connect(url: str): assert exc_info.value.code == 1006 -@pytest.mark.anyio @pytest.mark.parametrize("code", [None, 1000, 1001]) @pytest.mark.parametrize( "reason", @@ -633,7 +615,6 @@ async def websocket_session(url: str): assert exc_info.value.reason == (reason or "") -@pytest.mark.anyio async def test_client_close(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: @@ -661,7 +642,6 @@ async def websocket_session(url: str): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") -@pytest.mark.anyio async def test_client_connection_lost( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -695,7 +675,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert got_disconnect_event_before_shutdown is True -@pytest.mark.anyio async def test_client_connection_lost_on_send( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -729,7 +708,6 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert got_disconnect_event is True -@pytest.mark.anyio async def test_connection_lost_before_handshake_complete( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -778,7 +756,6 @@ async def websocket_session(uri: str): await task -@pytest.mark.anyio async def test_send_close_on_server_shutdown( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -823,7 +800,6 @@ async def websocket_session(uri: str): task.cancel() -@pytest.mark.anyio @pytest.mark.parametrize("subprotocol", ["proto1", "proto2"]) async def test_subprotocols( ws_protocol_cls: WSProtocol, @@ -857,7 +833,6 @@ async def get_subprotocol(url: str): MAX_WS_BYTES_PLUS1 = MAX_WS_BYTES + 1 -@pytest.mark.anyio @pytest.mark.parametrize( "client_size_sent, server_size_max, expected_result", [ @@ -911,7 +886,6 @@ async def send_text(url: str): assert e.value.code == expected_result -@pytest.mark.anyio async def test_server_reject_connection( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -952,7 +926,6 @@ async def websocket_session(url): assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} -@pytest.mark.anyio async def test_server_reject_connection_with_response( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -990,7 +963,6 @@ async def websocket_session(url): assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} -@pytest.mark.anyio async def test_server_reject_connection_with_multibody_response( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1043,7 +1015,6 @@ async def websocket_session(url: str): assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} -@pytest.mark.anyio async def test_server_reject_connection_with_invalid_status( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1085,7 +1056,6 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") -@pytest.mark.anyio async def test_server_reject_connection_with_body_nolength( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1130,7 +1100,6 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") -@pytest.mark.anyio async def test_server_reject_connection_with_invalid_msg( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1168,7 +1137,6 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") -@pytest.mark.anyio async def test_server_reject_connection_with_missing_body( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1205,7 +1173,6 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") -@pytest.mark.anyio async def test_server_multiple_websocket_http_response_start_events( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1257,7 +1224,6 @@ async def websocket_session(url: str): ) -@pytest.mark.anyio async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1299,7 +1265,6 @@ async def send_text(url: str): assert disconnect_message == {"type": "websocket.disconnect", "code": 1000} -@pytest.mark.anyio async def test_default_server_headers( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1323,7 +1288,6 @@ async def open_connection(url: str): assert headers.get("server") == "uvicorn" and "date" in headers -@pytest.mark.anyio async def test_no_server_headers(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): async def websocket_connect(self, message): @@ -1346,7 +1310,6 @@ async def open_connection(url: str): assert "server" not in headers -@pytest.mark.anyio @skip_if_no_wsproto async def test_no_date_header_on_wsproto(http_protocol_cls: HTTPProtocol, unused_tcp_port: int): class App(WebSocketResponse): @@ -1370,7 +1333,6 @@ async def open_connection(url: str): assert "date" not in headers -@pytest.mark.anyio async def test_multiple_server_header( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): @@ -1402,7 +1364,6 @@ async def open_connection(url: str): assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] -@pytest.mark.anyio async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int): expected_states = [ {"a": 123, "b": [1]}, diff --git a/tests/test_auto_detection.py b/tests/test_auto_detection.py index 54ab904e5..1f79b3786 100644 --- a/tests/test_auto_detection.py +++ b/tests/test_auto_detection.py @@ -32,10 +32,6 @@ async def app(scope, receive, send): pass # pragma: no cover -# TODO: Add pypy to our testing matrix, and assert we get the correct classes -# dependent on the platform we're running the tests under. - - def test_loop_auto(): auto_loop_setup() policy = asyncio.get_event_loop_policy() diff --git a/tests/test_default_headers.py b/tests/test_default_headers.py index ab3375ea6..4f75dbded 100644 --- a/tests/test_default_headers.py +++ b/tests/test_default_headers.py @@ -1,17 +1,21 @@ +from __future__ import annotations + import httpx import pytest from tests.utils import run_server from uvicorn import Config +from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope + +pytestmark = pytest.mark.anyio -async def app(scope, receive, send): +async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: assert scope["type"] == "http" await send({"type": "http.response.start", "status": 200, "headers": []}) await send({"type": "http.response.body", "body": b"", "more_body": False}) -@pytest.mark.anyio async def test_default_default_headers(unused_tcp_port: int): config = Config(app=app, loop="asyncio", limit_max_requests=1, port=unused_tcp_port) async with run_server(config): @@ -20,79 +24,45 @@ async def test_default_default_headers(unused_tcp_port: int): assert response.headers["server"] == "uvicorn" and response.headers["date"] -@pytest.mark.anyio async def test_override_server_header(unused_tcp_port: int): - config = Config( - app=app, - loop="asyncio", - limit_max_requests=1, - headers=[("Server", "over-ridden")], - port=unused_tcp_port, - ) + headers: list[tuple[str, str]] = [("Server", "over-ridden")] + config = Config(app=app, loop="asyncio", limit_max_requests=1, headers=headers, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.headers["server"] == "over-ridden" and response.headers["date"] -@pytest.mark.anyio async def test_disable_default_server_header(unused_tcp_port: int): - config = Config( - app=app, - loop="asyncio", - limit_max_requests=1, - server_header=False, - port=unused_tcp_port, - ) + config = Config(app=app, loop="asyncio", limit_max_requests=1, server_header=False, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert "server" not in response.headers -@pytest.mark.anyio async def test_override_server_header_multiple_times(unused_tcp_port: int): - config = Config( - app=app, - loop="asyncio", - limit_max_requests=1, - headers=[("Server", "over-ridden"), ("Server", "another-value")], - port=unused_tcp_port, - ) + headers: list[tuple[str, str]] = [("Server", "over-ridden"), ("Server", "another-value")] + config = Config(app=app, loop="asyncio", limit_max_requests=1, headers=headers, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") assert response.headers["server"] == "over-ridden, another-value" and response.headers["date"] -@pytest.mark.anyio async def test_add_additional_header(unused_tcp_port: int): - config = Config( - app=app, - loop="asyncio", - limit_max_requests=1, - headers=[("X-Additional", "new-value")], - port=unused_tcp_port, - ) + headers: list[tuple[str, str]] = [("X-Additional", "new-value")] + config = Config(app=app, loop="asyncio", limit_max_requests=1, headers=headers, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") - assert ( - response.headers["x-additional"] == "new-value" - and response.headers["server"] == "uvicorn" - and response.headers["date"] - ) + assert response.headers["x-additional"] == "new-value" + assert response.headers["server"] == "uvicorn" + assert response.headers["date"] -@pytest.mark.anyio async def test_disable_default_date_header(unused_tcp_port: int): - config = Config( - app=app, - loop="asyncio", - limit_max_requests=1, - date_header=False, - port=unused_tcp_port, - ) + config = Config(app=app, loop="asyncio", limit_max_requests=1, date_header=False, port=unused_tcp_port) async with run_server(config): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{unused_tcp_port}") diff --git a/tests/test_main.py b/tests/test_main.py index fc2532749..b520f27dd 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,17 +7,20 @@ from tests.utils import run_server from uvicorn import Server +from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import Config from uvicorn.main import run +pytestmark = pytest.mark.anyio -async def app(scope, receive, send): + +async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: assert scope["type"] == "http" await send({"type": "http.response.start", "status": 204, "headers": []}) await send({"type": "http.response.body", "body": b"", "more_body": False}) -def _has_ipv6(host): +def _has_ipv6(host: str): sock = None has_ipv6 = False if socket.has_ipv6: @@ -32,7 +35,6 @@ def _has_ipv6(host): return has_ipv6 -@pytest.mark.anyio @pytest.mark.parametrize( "host, url", [ @@ -54,7 +56,6 @@ async def test_run(host, url: str, unused_tcp_port: int): assert response.status_code == 204 -@pytest.mark.anyio async def test_run_multiprocess(unused_tcp_port: int): config = Config(app=app, loop="asyncio", workers=2, limit_max_requests=1, port=unused_tcp_port) async with run_server(config): @@ -63,7 +64,6 @@ async def test_run_multiprocess(unused_tcp_port: int): assert response.status_code == 204 -@pytest.mark.anyio async def test_run_reload(unused_tcp_port: int): config = Config(app=app, loop="asyncio", reload=True, limit_max_requests=1, port=unused_tcp_port) async with run_server(config): @@ -107,7 +107,6 @@ def test_run_match_config_params() -> None: assert config_params == run_params -@pytest.mark.anyio async def test_exit_on_create_server_with_invalid_host() -> None: with pytest.raises(SystemExit) as exc_info: config = Config(app=app, host="illegal_host") diff --git a/tests/utils.py b/tests/utils.py index 909064651..32dd76b1a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import asyncio import os +from collections.abc import AsyncIterator from contextlib import asynccontextmanager, contextmanager from pathlib import Path +from socket import socket from uvicorn import Config, Server @asynccontextmanager -async def run_server(config: Config, sockets=None): +async def run_server(config: Config, sockets: list[socket] | None = None) -> AsyncIterator[Server]: server = Server(config=config) task = asyncio.create_task(server.serve(sockets=sockets)) await asyncio.sleep(0.1) From 775597f0d4bd226873c1f16e6e788086c7a575e9 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 14 May 2024 18:28:45 -0400 Subject: [PATCH 2/2] Fix linter --- tests/middleware/test_logging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index 7e9edd371..f27633aa5 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -203,8 +203,8 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def test_server_start_with_port_zero(caplog: pytest.LogCaptureFixture): config = Config(app=app, port=0) - async with run_server(config) as server: - server = server.servers[0] + async with run_server(config) as _server: + server = _server.servers[0] sock = server.sockets[0] host, port = sock.getsockname() messages = [record.message for record in caplog.records if "uvicorn" in record.name]