Skip to content

Commit b9c03a8

Browse files
authored
Improve type hints on WebSockets implementations (#2335)
1 parent 14bdf04 commit b9c03a8

File tree

4 files changed

+29
-51
lines changed

4 files changed

+29
-51
lines changed

uvicorn/protocols/http/h11_impl.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,8 @@
2020
)
2121
from uvicorn.config import Config
2222
from uvicorn.logging import TRACE_LOG_LEVEL
23-
from uvicorn.protocols.http.flow_control import (
24-
CLOSE_HEADER,
25-
HIGH_WATER_LIMIT,
26-
FlowControl,
27-
service_unavailable,
28-
)
29-
from uvicorn.protocols.utils import (
30-
get_client_addr,
31-
get_local_addr,
32-
get_path_with_query_string,
33-
get_remote_addr,
34-
is_ssl,
35-
)
23+
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
24+
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
3625
from uvicorn.server import ServerState
3726

3827

uvicorn/protocols/http/httptools_impl.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,8 @@
2121
)
2222
from uvicorn.config import Config
2323
from uvicorn.logging import TRACE_LOG_LEVEL
24-
from uvicorn.protocols.http.flow_control import (
25-
CLOSE_HEADER,
26-
HIGH_WATER_LIMIT,
27-
FlowControl,
28-
service_unavailable,
29-
)
30-
from uvicorn.protocols.utils import (
31-
get_client_addr,
32-
get_local_addr,
33-
get_path_with_query_string,
34-
get_remote_addr,
35-
is_ssl,
36-
)
24+
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
25+
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
3726
from uvicorn.server import ServerState
3827

3928
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')

uvicorn/protocols/websockets/websockets_impl.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77
from urllib.parse import unquote
88

99
import websockets
10+
import websockets.legacy.handshake
1011
from websockets.datastructures import Headers
1112
from websockets.exceptions import ConnectionClosed
13+
from websockets.extensions.base import ServerExtensionFactory
1214
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
1315
from websockets.legacy.server import HTTPResponse
1416
from websockets.server import WebSocketServerProtocol
1517
from websockets.typing import Subprotocol
1618

1719
from uvicorn._types import (
20+
ASGI3Application,
1821
ASGISendEvent,
1922
WebSocketAcceptEvent,
2023
WebSocketCloseEvent,
@@ -53,6 +56,7 @@ def is_serving(self) -> bool:
5356

5457
class WebSocketProtocol(WebSocketServerProtocol):
5558
extra_headers: list[tuple[str, str]]
59+
logger: logging.Logger | logging.LoggerAdapter[Any]
5660

5761
def __init__(
5862
self,
@@ -65,7 +69,7 @@ def __init__(
6569
config.load()
6670

6771
self.config = config
68-
self.app = config.loaded_app
72+
self.app = cast(ASGI3Application, config.loaded_app)
6973
self.loop = _loop or asyncio.get_event_loop()
7074
self.root_path = config.root_path
7175
self.app_state = app_state
@@ -92,7 +96,7 @@ def __init__(
9296

9397
self.ws_server: Server = Server() # type: ignore[assignment]
9498

95-
extensions = []
99+
extensions: list[ServerExtensionFactory] = []
96100
if self.config.ws_per_message_deflate:
97101
extensions.append(ServerPerMessageDeflateFactory())
98102

@@ -147,10 +151,10 @@ def shutdown(self) -> None:
147151
self.send_500_response()
148152
self.transport.close()
149153

150-
def on_task_complete(self, task: asyncio.Task) -> None:
154+
def on_task_complete(self, task: asyncio.Task[None]) -> None:
151155
self.tasks.discard(task)
152156

153-
async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None:
157+
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
154158
"""
155159
This hook is called to determine if the websocket should return
156160
an HTTP response and close.
@@ -161,15 +165,15 @@ async def process_request(self, path: str, headers: Headers) -> HTTPResponse | N
161165
"""
162166
path_portion, _, query_string = path.partition("?")
163167

164-
websockets.legacy.handshake.check_request(headers)
168+
websockets.legacy.handshake.check_request(request_headers)
165169

166-
subprotocols = []
167-
for header in headers.get_all("Sec-WebSocket-Protocol"):
170+
subprotocols: list[str] = []
171+
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
168172
subprotocols.extend([token.strip() for token in header.split(",")])
169173

170174
asgi_headers = [
171175
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
172-
for name, value in headers.raw_items()
176+
for name, value in request_headers.raw_items()
173177
]
174178
path = unquote(path_portion)
175179
full_path = self.root_path + path
@@ -237,14 +241,13 @@ async def run_asgi(self) -> None:
237241
termination states.
238242
"""
239243
try:
240-
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
244+
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
241245
except ClientDisconnected:
242246
self.closed_event.set()
243247
self.transport.close()
244-
except BaseException as exc:
248+
except BaseException:
245249
self.closed_event.set()
246-
msg = "Exception in ASGI application\n"
247-
self.logger.error(msg, exc_info=exc)
250+
self.logger.exception("Exception in ASGI application\n")
248251
if not self.handshake_started_event.is_set():
249252
self.send_500_response()
250253
else:
@@ -253,13 +256,11 @@ async def run_asgi(self) -> None:
253256
else:
254257
self.closed_event.set()
255258
if not self.handshake_started_event.is_set():
256-
msg = "ASGI callable returned without sending handshake."
257-
self.logger.error(msg)
259+
self.logger.error("ASGI callable returned without sending handshake.")
258260
self.send_500_response()
259261
self.transport.close()
260262
elif result is not None:
261-
msg = "ASGI callable should return None, but returned '%s'."
262-
self.logger.error(msg, result)
263+
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
263264
await self.handshake_completed_event.wait()
264265
self.transport.close()
265266

uvicorn/protocols/websockets/wsproto_impl.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import logging
55
import typing
6-
from typing import Literal
6+
from typing import Literal, cast
77
from urllib.parse import unquote
88

99
import wsproto
@@ -13,6 +13,7 @@
1313
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
1414

1515
from uvicorn._types import (
16+
ASGI3Application,
1617
ASGISendEvent,
1718
WebSocketAcceptEvent,
1819
WebSocketCloseEvent,
@@ -46,7 +47,7 @@ def __init__(
4647
config.load()
4748

4849
self.config = config
49-
self.app = config.loaded_app
50+
self.app = cast(ASGI3Application, config.loaded_app)
5051
self.loop = _loop or asyncio.get_event_loop()
5152
self.logger = logging.getLogger("uvicorn.error")
5253
self.root_path = config.root_path
@@ -156,7 +157,7 @@ def shutdown(self) -> None:
156157
self.send_500_response()
157158
self.transport.close()
158159

159-
def on_task_complete(self, task: asyncio.Task) -> None:
160+
def on_task_complete(self, task: asyncio.Task[None]) -> None:
160161
self.tasks.discard(task)
161162

162163
# Event handlers
@@ -220,7 +221,7 @@ def handle_ping(self, event: events.Ping) -> None:
220221
def send_500_response(self) -> None:
221222
if self.response_started or self.handshake_complete:
222223
return # we cannot send responses anymore
223-
headers = [
224+
headers: list[tuple[bytes, bytes]] = [
224225
(b"content-type", b"text/plain; charset=utf-8"),
225226
(b"connection", b"close"),
226227
]
@@ -230,7 +231,7 @@ def send_500_response(self) -> None:
230231

231232
async def run_asgi(self) -> None:
232233
try:
233-
result = await self.app(self.scope, self.receive, self.send)
234+
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
234235
except ClientDisconnected:
235236
self.transport.close()
236237
except BaseException:
@@ -239,13 +240,11 @@ async def run_asgi(self) -> None:
239240
self.transport.close()
240241
else:
241242
if not self.handshake_complete:
242-
msg = "ASGI callable returned without completing handshake."
243-
self.logger.error(msg)
243+
self.logger.error("ASGI callable returned without completing handshake.")
244244
self.send_500_response()
245245
self.transport.close()
246246
elif result is not None:
247-
msg = "ASGI callable should return None, but returned '%s'."
248-
self.logger.error(msg, result)
247+
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
249248
self.transport.close()
250249

251250
async def send(self, message: ASGISendEvent) -> None:

0 commit comments

Comments
 (0)