77from urllib .parse import unquote
88
99import websockets
10+ import websockets .legacy .handshake
1011from websockets .datastructures import Headers
1112from websockets .exceptions import ConnectionClosed
13+ from websockets .extensions .base import ServerExtensionFactory
1214from websockets .extensions .permessage_deflate import ServerPerMessageDeflateFactory
1315from websockets .legacy .server import HTTPResponse
1416from websockets .server import WebSocketServerProtocol
1517from websockets .typing import Subprotocol
1618
1719from uvicorn ._types import (
20+ ASGI3Application ,
1821 ASGISendEvent ,
1922 WebSocketAcceptEvent ,
2023 WebSocketCloseEvent ,
@@ -53,6 +56,7 @@ def is_serving(self) -> bool:
5356
5457class WebSocketProtocol (WebSocketServerProtocol ):
5558 extra_headers : list [tuple [str , str ]]
59+ logger : logging .Logger | logging .LoggerAdapter [Any ]
5660
5761 def __init__ (
5862 self ,
@@ -65,7 +69,7 @@ def __init__(
6569 config .load ()
6670
6771 self .config = config
68- self .app = config .loaded_app
72+ self .app = cast ( ASGI3Application , config .loaded_app )
6973 self .loop = _loop or asyncio .get_event_loop ()
7074 self .root_path = config .root_path
7175 self .app_state = app_state
@@ -92,7 +96,7 @@ def __init__(
9296
9397 self .ws_server : Server = Server () # type: ignore[assignment]
9498
95- extensions = []
99+ extensions : list [ ServerExtensionFactory ] = []
96100 if self .config .ws_per_message_deflate :
97101 extensions .append (ServerPerMessageDeflateFactory ())
98102
@@ -147,10 +151,10 @@ def shutdown(self) -> None:
147151 self .send_500_response ()
148152 self .transport .close ()
149153
150- def on_task_complete (self , task : asyncio .Task ) -> None :
154+ def on_task_complete (self , task : asyncio .Task [ None ] ) -> None :
151155 self .tasks .discard (task )
152156
153- async def process_request (self , path : str , headers : Headers ) -> HTTPResponse | None :
157+ async def process_request (self , path : str , request_headers : Headers ) -> HTTPResponse | None :
154158 """
155159 This hook is called to determine if the websocket should return
156160 an HTTP response and close.
@@ -161,15 +165,15 @@ async def process_request(self, path: str, headers: Headers) -> HTTPResponse | N
161165 """
162166 path_portion , _ , query_string = path .partition ("?" )
163167
164- websockets .legacy .handshake .check_request (headers )
168+ websockets .legacy .handshake .check_request (request_headers )
165169
166- subprotocols = []
167- for header in headers .get_all ("Sec-WebSocket-Protocol" ):
170+ subprotocols : list [ str ] = []
171+ for header in request_headers .get_all ("Sec-WebSocket-Protocol" ):
168172 subprotocols .extend ([token .strip () for token in header .split ("," )])
169173
170174 asgi_headers = [
171175 (name .encode ("ascii" ), value .encode ("ascii" , errors = "surrogateescape" ))
172- for name , value in headers .raw_items ()
176+ for name , value in request_headers .raw_items ()
173177 ]
174178 path = unquote (path_portion )
175179 full_path = self .root_path + path
@@ -237,14 +241,13 @@ async def run_asgi(self) -> None:
237241 termination states.
238242 """
239243 try :
240- result = await self .app (self .scope , self .asgi_receive , self .asgi_send )
244+ result = await self .app (self .scope , self .asgi_receive , self .asgi_send ) # type: ignore[func-returns-value]
241245 except ClientDisconnected :
242246 self .closed_event .set ()
243247 self .transport .close ()
244- except BaseException as exc :
248+ except BaseException :
245249 self .closed_event .set ()
246- msg = "Exception in ASGI application\n "
247- self .logger .error (msg , exc_info = exc )
250+ self .logger .exception ("Exception in ASGI application\n " )
248251 if not self .handshake_started_event .is_set ():
249252 self .send_500_response ()
250253 else :
@@ -253,13 +256,11 @@ async def run_asgi(self) -> None:
253256 else :
254257 self .closed_event .set ()
255258 if not self .handshake_started_event .is_set ():
256- msg = "ASGI callable returned without sending handshake."
257- self .logger .error (msg )
259+ self .logger .error ("ASGI callable returned without sending handshake." )
258260 self .send_500_response ()
259261 self .transport .close ()
260262 elif result is not None :
261- msg = "ASGI callable should return None, but returned '%s'."
262- self .logger .error (msg , result )
263+ self .logger .error ("ASGI callable should return None, but returned '%s'." , result )
263264 await self .handshake_completed_event .wait ()
264265 self .transport .close ()
265266
0 commit comments