diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index 0aab5a717..14d8d518d 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -2,7 +2,6 @@ from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass, field -from http import HTTPStatus from typing import Any, Optional, Protocol, Union from pydantic import ValidationError @@ -26,7 +25,7 @@ property_from_data, ) from .properties.schemas import parameter_from_reference -from .responses import Response, response_from_data +from .responses import HTTPStatusSpec, Response, http_status_spec, response_from_data _PATH_PARAM_REGEX = re.compile("{([a-zA-Z_-][a-zA-Z0-9_-]*)}") @@ -162,22 +161,13 @@ def _add_responses( ) -> tuple["Endpoint", Schemas]: endpoint = deepcopy(endpoint) for code, response_data in data.items(): - status_code: HTTPStatus - try: - status_code = HTTPStatus(int(code)) - except ValueError: - endpoint.errors.append( - ParseError( - detail=( - f"Invalid response status code {code} (not a valid HTTP " - f"status code), response will be omitted from generated " - f"client" - ) - ) - ) + status_code: HTTPStatusSpec | ParseError = http_status_spec(code) + if isinstance(status_code, ParseError): + endpoint.errors.append(status_code) continue response, schemas = response_from_data( + status_code_str=code, status_code=status_code, data=response_data, schemas=schemas, @@ -190,7 +180,7 @@ def _add_responses( endpoint.errors.append( ParseError( detail=( - f"Cannot parse response for status code {status_code}{detail_suffix}, " + f"Cannot parse response for status code {code}{detail_suffix}, " f"response will be omitted from generated client" ), data=response.data, diff --git a/openapi_python_client/parser/responses.py b/openapi_python_client/parser/responses.py index ec0f6136b..071142c8f 100644 --- a/openapi_python_client/parser/responses.py +++ b/openapi_python_client/parser/responses.py @@ -4,6 +4,7 @@ from typing import Optional, TypedDict, Union from attrs import define +from typing_extensions import TypeAlias from openapi_python_client import utils from openapi_python_client.parser.properties.schemas import get_reference_simple_name, parse_reference_path @@ -27,12 +28,48 @@ class _ResponseSource(TypedDict): TEXT_SOURCE = _ResponseSource(attribute="response.text", return_type="str") NONE_SOURCE = _ResponseSource(attribute="None", return_type="None") +HTTPStatusSpec: TypeAlias = Union[HTTPStatus, tuple[HTTPStatus, int]] +"""Either a single http status or a tuple representing an inclusive range. + +The second element of the tuple is also logically a status code but is typically 299 or similar which +is not contained in the enum. + +https://github.com/openapi-generators/openapi-python-client/blob/61b6c54994e2a6285bb422ee3b864c45b5d88c15/openapi_python_client/schema/3.1.0.md#responses-object +""" + + +def http_status_spec(code: str | int) -> HTTPStatusSpec | ParseError: + """Parses plain integer status codes such as 201 or patterned status codes such as 2XX.""" + + multiplier = 1 + if isinstance(code, str): + if code.endswith("XX"): + code = code.removesuffix("XX") + multiplier = 100 + + try: + status_code = int(code) + + if multiplier > 1: + start = status_code * multiplier + return (HTTPStatus(start), start + multiplier - 1) + + return HTTPStatus(status_code) + except ValueError: + return ParseError( + detail=( + f"Invalid response status code {code} (not a valid HTTP " + f"status code), response will be omitted from generated " + f"client" + ) + ) + @define class Response: """Describes a single response for an endpoint""" - status_code: HTTPStatus + status_code: HTTPStatusSpec prop: Property source: _ResponseSource data: Union[oai.Response, oai.Reference] # Original data which created this response, useful for custom templates @@ -59,7 +96,7 @@ def _source_by_content_type(content_type: str, config: Config) -> Optional[_Resp def empty_response( *, - status_code: HTTPStatus, + status_code: HTTPStatusSpec, response_name: str, config: Config, data: Union[oai.Response, oai.Reference], @@ -80,9 +117,22 @@ def empty_response( ) +def _status_code_str(status_code_str: str | None, status_code: HTTPStatusSpec) -> str: + if status_code_str is None: + if isinstance(status_code, HTTPStatus): + return str(status_code.value) + if isinstance(status_code, int): + return str(status_code) + + raise ValueError(f"status_code_str must be passed for {status_code!r}") + + return status_code_str + + def response_from_data( # noqa: PLR0911 *, - status_code: HTTPStatus, + status_code_str: str | None = None, + status_code: HTTPStatusSpec, data: Union[oai.Response, oai.Reference], schemas: Schemas, responses: dict[str, Union[oai.Response, oai.Reference]], @@ -90,8 +140,9 @@ def response_from_data( # noqa: PLR0911 config: Config, ) -> tuple[Union[Response, ParseError], Schemas]: """Generate a Response from the OpenAPI dictionary representation of it""" + status_code_str = _status_code_str(status_code_str, status_code) - response_name = f"response_{status_code}" + response_name = f"response_{status_code_str}" if isinstance(data, oai.Reference): ref_path = parse_reference_path(data.ref) if isinstance(ref_path, ParseError): diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index 802fcc2ea..6bf1173ce 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -67,7 +67,11 @@ def _get_kwargs( def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[{{ return_string }}]: {% for response in endpoint.responses %} + {% if response.status_code.value is defined %} if response.status_code == {{ response.status_code.value }}: + {% else %} + if {{ response.status_code[0].value }} <= response.status_code <= {{ response.status_code[1] }}: + {% endif %} {% if parsed_responses %}{% import "property_templates/" + response.prop.template as prop_template %} {% if prop_template.construct %} {{ prop_template.construct(response.prop, response.source.attribute) | indent(8) }} diff --git a/tests/test_parser/test_openapi.py b/tests/test_parser/test_openapi.py index 3d1391ae2..dcf993217 100644 --- a/tests/test_parser/test_openapi.py +++ b/tests/test_parser/test_openapi.py @@ -1,3 +1,4 @@ +from http import HTTPStatus from unittest.mock import MagicMock import pydantic @@ -7,6 +8,7 @@ from openapi_python_client.parser.errors import ParseError from openapi_python_client.parser.openapi import Endpoint, EndpointCollection, import_string_from_class from openapi_python_client.parser.properties import Class, IntProperty, Parameters, Schemas +from openapi_python_client.parser.responses import empty_response from openapi_python_client.schema import DataType MODULE_NAME = "openapi_python_client.parser.openapi" @@ -48,6 +50,44 @@ def test__add_responses_status_code_error(self, response_status_code, mocker): ] response_from_data.assert_not_called() + def test__add_response_with_patterned_status_code(self, mocker): + schemas = Schemas() + response_1_data = mocker.MagicMock() + data = { + "2XX": response_1_data, + } + endpoint = self.make_endpoint() + config = MagicMock() + response = empty_response( + status_code=(HTTPStatus(200), 299), + response_name="dummy", + config=config, + data=data, + ) + response_from_data = mocker.patch(f"{MODULE_NAME}.response_from_data", return_value=(response, schemas)) + + response, schemas = Endpoint._add_responses( + endpoint=endpoint, data=data, schemas=schemas, responses={}, config=config + ) + + assert response.errors == [] + + assert response.responses[0].status_code == (200, 299) + + response_from_data.assert_has_calls( + [ + mocker.call( + status_code_str="2XX", + status_code=(HTTPStatus(200), 299), + data=response_1_data, + schemas=schemas, + responses={}, + parent_name="name", + config=config, + ), + ] + ) + def test__add_responses_error(self, mocker): schemas = Schemas() response_1_data = mocker.MagicMock() @@ -68,6 +108,7 @@ def test__add_responses_error(self, mocker): response_from_data.assert_has_calls( [ mocker.call( + status_code_str="200", status_code=200, data=response_1_data, schemas=schemas, @@ -76,6 +117,7 @@ def test__add_responses_error(self, mocker): config=config, ), mocker.call( + status_code_str="404", status_code=404, data=response_2_data, schemas=schemas, diff --git a/tests/test_parser/test_responses.py b/tests/test_parser/test_responses.py index 8fb04d720..880945407 100644 --- a/tests/test_parser/test_responses.py +++ b/tests/test_parser/test_responses.py @@ -240,6 +240,7 @@ def test_response_from_data_content_type_overrides(any_property_factory): config = MagicMock() config.content_type_overrides = {"application/zip": "application/octet-stream"} response, schemas = response_from_data( + status_code_str="200", status_code=200, data=data, schemas=Schemas(),