Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/finch/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,17 @@ def _process_response(
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type").split(";")
if content_type != "application/json":
raise ValueError(
f"Expected Content-Type response header to be `application/json` but received {content_type} instead."
)
if self._strict_response_validation:
raise exceptions.APIResponseValidationError(
response=response,
request=response.request,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
)

# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore

data = response.json()
return self._process_response_data(data=data, cast_to=cast_to, response=response)
Expand Down
6 changes: 4 additions & 2 deletions src/finch/_base_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing_extensions import Literal

from httpx import Request, Response
Expand All @@ -17,8 +19,8 @@ class APIResponseValidationError(APIError):
response: Response
status_code: int

def __init__(self, request: Request, response: Response) -> None:
super().__init__("Data returned by API invalid for expected schema.", request)
def __init__(self, request: Request, response: Response, *, message: str | None = None) -> None:
super().__init__(message or "Data returned by API invalid for expected schema.", request)
self.response = response
self.status_code = response.status_code

Expand Down
37 changes: 36 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from respx import MockRouter

from finch import Finch, AsyncFinch
from finch import Finch, AsyncFinch, APIResponseValidationError
from finch._types import Omit
from finch._models import BaseModel, FinalRequestOptions
from finch._base_client import BaseClient, make_request_options
Expand Down Expand Up @@ -385,6 +385,23 @@ def test_client_context_manager(self) -> None:
assert not client.is_closed()
assert client.is_closed()

@pytest.mark.respx(base_url=base_url)
def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
name: str

respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

strict_client = Finch(base_url=base_url, access_token=access_token, _strict_response_validation=True)

with pytest.raises(APIResponseValidationError):
strict_client.get("/foo", cast_to=Model)

client = Finch(base_url=base_url, access_token=access_token, _strict_response_validation=False)

response = client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]


class TestAsyncFinch:
client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=True)
Expand Down Expand Up @@ -744,3 +761,21 @@ async def test_client_context_manager(self) -> None:
assert not c2.is_closed()
assert not client.is_closed()
assert client.is_closed()

@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
name: str

respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

strict_client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=True)

with pytest.raises(APIResponseValidationError):
await strict_client.get("/foo", cast_to=Model)

client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=False)

response = await client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]