diff --git a/h11/_connection.py b/h11/_connection.py index 6f796ef..bcd3089 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -483,7 +483,7 @@ def send_with_data_passthrough(self, event): raise LocalProtocolError("Can't send data when our state is ERROR") try: if type(event) is Response: - self._clean_up_response_headers_for_sending(event) + event = self._clean_up_response_headers_for_sending(event) # We want to call _process_event before calling the writer, # because if someone tries to do something invalid then this will # give a sensible error message, while our writers all just assume @@ -528,8 +528,7 @@ def send_failed(self): # # This function's *only* responsibility is making sure headers are set up # right -- everything downstream just looks at the headers. There are no - # side channels. It mutates the response event in-place (but not the - # response.headers list object). + # side channels. def _clean_up_response_headers_for_sending(self, response): assert type(response) is Response @@ -582,4 +581,9 @@ def _clean_up_response_headers_for_sending(self, response): connection.add(b"close") headers = set_comma_header(headers, b"connection", sorted(connection)) - response.headers = headers + return Response( + headers=headers, + status_code=response.status_code, + http_version=response.http_version, + reason=response.reason, + ) diff --git a/h11/_events.py b/h11/_events.py index 1827930..ebbf10f 100644 --- a/h11/_events.py +++ b/h11/_events.py @@ -6,13 +6,17 @@ # Don't subclass these. Stuff will break. import re +from abc import ABC +from dataclasses import dataclass, field +from typing import Any, cast, Dict, List, Tuple, Union -from . import _headers from ._abnf import request_target +from ._headers import Headers, normalize_and_validate from ._util import bytesify, LocalProtocolError, validate # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = [ + "Event", "Request", "InformationalResponse", "Response", @@ -24,72 +28,16 @@ request_target_re = re.compile(request_target.encode("ascii")) -class _EventBundle: - _fields = [] - _defaults = {} - - def __init__(self, **kwargs): - _parsed = kwargs.pop("_parsed", False) - allowed = set(self._fields) - for kwarg in kwargs: - if kwarg not in allowed: - raise TypeError( - "unrecognized kwarg {} for {}".format( - kwarg, self.__class__.__name__ - ) - ) - required = allowed.difference(self._defaults) - for field in required: - if field not in kwargs: - raise TypeError( - "missing required kwarg {} for {}".format( - field, self.__class__.__name__ - ) - ) - self.__dict__.update(self._defaults) - self.__dict__.update(kwargs) - - # Special handling for some fields - - if "headers" in self.__dict__: - self.headers = _headers.normalize_and_validate( - self.headers, _parsed=_parsed - ) - - if not _parsed: - for field in ["method", "target", "http_version", "reason"]: - if field in self.__dict__: - self.__dict__[field] = bytesify(self.__dict__[field]) - - if "status_code" in self.__dict__: - if not isinstance(self.status_code, int): - raise LocalProtocolError("status code must be integer") - # Because IntEnum objects are instances of int, but aren't - # duck-compatible (sigh), see gh-72. - self.status_code = int(self.status_code) - - self._validate() - - def _validate(self): - pass - - def __repr__(self): - name = self.__class__.__name__ - kwarg_strs = [ - "{}={}".format(field, self.__dict__[field]) for field in self._fields - ] - kwarg_str = ", ".join(kwarg_strs) - return "{}({})".format(name, kwarg_str) - - # Useful for tests - def __eq__(self, other): - return self.__class__ == other.__class__ and self.__dict__ == other.__dict__ +class Event(ABC): + """ + Base class for h11 events. + """ - # This is an unhashable type. - __hash__ = None + __slots__ = () -class Request(_EventBundle): +@dataclass(init=False, frozen=True) +class Request(Event): """The beginning of an HTTP request. Fields: @@ -123,10 +71,38 @@ class Request(_EventBundle): """ - _fields = ["method", "target", "headers", "http_version"] - _defaults = {"http_version": b"1.1"} + __slots__ = ("method", "headers", "target", "http_version") + + method: bytes + headers: Headers + target: bytes + http_version: bytes + + def __init__( + self, + *, + method: Union[bytes, str], + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + target: Union[bytes, str], + http_version: Union[bytes, str] = b"1.1", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "method", bytesify(method)) + object.__setattr__(self, "target", bytesify(target)) + object.__setattr__(self, "http_version", bytesify(http_version)) + else: + object.__setattr__(self, "method", method) + object.__setattr__(self, "target", target) + object.__setattr__(self, "http_version", http_version) - def _validate(self): # "A server MUST respond with a 400 (Bad Request) status code to any # HTTP/1.1 request message that lacks a Host header field and to any # request message that contains more than one Host header field or a @@ -143,12 +119,58 @@ def _validate(self): validate(request_target_re, self.target, "Illegal target characters") + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class _ResponseBase(Event): + __slots__ = ("headers", "http_version", "reason", "status_code") + + headers: Headers + http_version: bytes + reason: bytes + status_code: int + + def __init__( + self, + *, + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + status_code: int, + http_version: Union[bytes, str] = b"1.1", + reason: Union[bytes, str] = b"", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "reason", bytesify(reason)) + object.__setattr__(self, "http_version", bytesify(http_version)) + if not isinstance(status_code, int): + raise LocalProtocolError("status code must be integer") + # Because IntEnum objects are instances of int, but aren't + # duck-compatible (sigh), see gh-72. + object.__setattr__(self, "status_code", int(status_code)) + else: + object.__setattr__(self, "reason", reason) + object.__setattr__(self, "http_version", http_version) + object.__setattr__(self, "status_code", status_code) + + self.__post_init__() + + def __post_init__(self) -> None: + pass -class _ResponseBase(_EventBundle): - _fields = ["status_code", "headers", "http_version", "reason"] - _defaults = {"http_version": b"1.1", "reason": b""} + # This is an unhashable type. + __hash__ = None # type: ignore +@dataclass(init=False, frozen=True) class InformationalResponse(_ResponseBase): """An HTTP informational response. @@ -179,14 +201,18 @@ class InformationalResponse(_ResponseBase): """ - def _validate(self): + def __post_init__(self) -> None: if not (100 <= self.status_code < 200): raise LocalProtocolError( "InformationalResponse status_code should be in range " "[100, 200), not {}".format(self.status_code) ) + # This is an unhashable type. + __hash__ = None # type: ignore + +@dataclass(init=False, frozen=True) class Response(_ResponseBase): """The beginning of an HTTP response. @@ -216,7 +242,7 @@ class Response(_ResponseBase): """ - def _validate(self): + def __post_init__(self) -> None: if not (200 <= self.status_code < 600): raise LocalProtocolError( "Response status_code should be in range [200, 600), not {}".format( @@ -224,8 +250,12 @@ def _validate(self): ) ) + # This is an unhashable type. + __hash__ = None # type: ignore + -class Data(_EventBundle): +@dataclass(init=False, frozen=True) +class Data(Event): """Part of an HTTP message body. Fields: @@ -258,8 +288,21 @@ class Data(_EventBundle): """ - _fields = ["data", "chunk_start", "chunk_end"] - _defaults = {"chunk_start": False, "chunk_end": False} + __slots__ = ("data", "chunk_start", "chunk_end") + + data: bytes + chunk_start: bool + chunk_end: bool + + def __init__( + self, data: bytes, chunk_start: bool = False, chunk_end: bool = False + ) -> None: + object.__setattr__(self, "data", data) + object.__setattr__(self, "chunk_start", chunk_start) + object.__setattr__(self, "chunk_end", chunk_end) + + # This is an unhashable type. + __hash__ = None # type: ignore # XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that @@ -267,7 +310,8 @@ class Data(_EventBundle): # present in the header section might bypass external security filters." # https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part # Unfortunately, the list of forbidden fields is long and vague :-/ -class EndOfMessage(_EventBundle): +@dataclass(init=False, frozen=True) +class EndOfMessage(Event): """The end of an HTTP message. Fields: @@ -284,11 +328,32 @@ class EndOfMessage(_EventBundle): """ - _fields = ["headers"] - _defaults = {"headers": []} + __slots__ = ("headers",) + + headers: Headers + + def __init__( + self, + *, + headers: Union[ + Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None + ] = None, + _parsed: bool = False, + ) -> None: + super().__init__() + if headers is None: + headers = Headers([]) + elif not isinstance(headers, Headers): + headers = normalize_and_validate(headers, _parsed=_parsed) + + object.__setattr__(self, "headers", headers) + + # This is an unhashable type. + __hash__ = None # type: ignore -class ConnectionClosed(_EventBundle): +@dataclass(frozen=True) +class ConnectionClosed(Event): """This event indicates that the sender has closed their outgoing connection. diff --git a/h11/tests/helpers.py b/h11/tests/helpers.py index 9d2cf38..5f53457 100644 --- a/h11/tests/helpers.py +++ b/h11/tests/helpers.py @@ -26,11 +26,13 @@ def normalize_data_events(in_events): out_events = [] for event in in_events: if type(event) is Data: - event.data = bytes(event.data) - event.chunk_start = False - event.chunk_end = False + event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False) if out_events and type(out_events[-1]) is type(event) is Data: - out_events[-1].data += event.data + out_events[-1] = Data( + data=out_events[-1].data + event.data, + chunk_start=out_events[-1].chunk_start, + chunk_end=out_events[-1].chunk_end, + ) else: out_events.append(event) return out_events diff --git a/h11/tests/test_connection.py b/h11/tests/test_connection.py index baadec8..de175de 100644 --- a/h11/tests/test_connection.py +++ b/h11/tests/test_connection.py @@ -467,7 +467,13 @@ def test_reuse_simple(): CLIENT, [Request(method="GET", target="/", headers=[("Host", "a")]), EndOfMessage()], ) - p.send(SERVER, [Response(status_code=200, headers=[]), EndOfMessage()]) + p.send( + SERVER, + [ + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), + EndOfMessage(), + ], + ) for conn in p.conns: assert conn.states == {CLIENT: DONE, SERVER: DONE} conn.start_next_cycle() @@ -479,7 +485,13 @@ def test_reuse_simple(): EndOfMessage(), ], ) - p.send(SERVER, [Response(status_code=404, headers=[]), EndOfMessage()]) + p.send( + SERVER, + [ + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), + EndOfMessage(), + ], + ) def test_pipelining(): @@ -562,8 +574,8 @@ def test_protocol_switch(): target="example.com:443", headers=[("Host", "foo"), ("Content-Length", "1")], ), - Response(status_code=404, headers=[]), - Response(status_code=200, headers=[]), + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), ), ( Request( @@ -571,7 +583,7 @@ def test_protocol_switch(): target="/", headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], ), - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), InformationalResponse(status_code=101, headers=[("Upgrade", "a")]), ), ( @@ -580,9 +592,9 @@ def test_protocol_switch(): target="example.com:443", headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], ), - Response(status_code=404, headers=[]), + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), # Accept CONNECT, not upgrade - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), ), ( Request( @@ -590,7 +602,7 @@ def test_protocol_switch(): target="example.com:443", headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], ), - Response(status_code=404, headers=[]), + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), # Accept Upgrade, not CONNECT InformationalResponse(status_code=101, headers=[("Upgrade", "b")]), ), @@ -725,7 +737,10 @@ def test_close_different_states(): Request(method="GET", target="/foo", headers=[("Host", "a")]), EndOfMessage(), ] - resp = [Response(status_code=200, headers=[]), EndOfMessage()] + resp = [ + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), + EndOfMessage(), + ] # Client before request p = ConnectionPair() @@ -949,7 +964,7 @@ def test_408_request_timeout(): # Should be able to send this spontaneously as a server without seeing # anything from client p = ConnectionPair() - p.send(SERVER, Response(status_code=408, headers=[])) + p.send(SERVER, Response(status_code=408, headers=[(b"connection", b"close")])) # This used to raise IndexError diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index e20f741..4748c4b 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -7,52 +7,6 @@ from .._util import LocalProtocolError -def test_event_bundle(): - class T(_events._EventBundle): - _fields = ["a", "b"] - _defaults = {"b": 1} - - def _validate(self): - if self.a == 0: - raise ValueError - - # basic construction and methods - t = T(a=1, b=0) - assert repr(t) == "T(a=1, b=0)" - assert t == T(a=1, b=0) - assert not (t == T(a=2, b=0)) - assert not (t != T(a=1, b=0)) - assert t != T(a=2, b=0) - with pytest.raises(TypeError): - hash(t) - - # check defaults - t = T(a=10) - assert t.a == 10 - assert t.b == 1 - - # no positional args - with pytest.raises(TypeError): - T(1) - - with pytest.raises(TypeError): - T(1, a=1, b=0) - - # unknown field - with pytest.raises(TypeError): - T(a=1, b=0, c=10) - - # missing required field - with pytest.raises(TypeError) as exc: - T(b=0) - # make sure we error on the right missing kwarg - assert "kwarg a" in str(exc.value) - - # _validate is called - with pytest.raises(ValueError): - T(a=0, b=0) - - def test_events(): with pytest.raises(LocalProtocolError): # Missing Host: diff --git a/newsfragments/124.feature.rst b/newsfragments/124.feature.rst new file mode 100644 index 0000000..e89c221 --- /dev/null +++ b/newsfragments/124.feature.rst @@ -0,0 +1,2 @@ +Switch event classes to dataclasses for easier typing and slightly +improved performance. diff --git a/setup.py b/setup.py index 93b6b62..eab298e 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ # doesn't look like a source file, so long as it appears in MANIFEST.in: include_package_data=True, python_requires=">=3.6", + install_requires=["dataclasses; python_version < '3.7'"], classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers",