diff --git a/src/finch/_base_client.py b/src/finch/_base_client.py index de73b186..58309b6b 100644 --- a/src/finch/_base_client.py +++ b/src/finch/_base_client.py @@ -24,7 +24,7 @@ overload, ) from functools import lru_cache -from typing_extensions import Literal, get_origin +from typing_extensions import Literal, get_args, get_origin import anyio import httpx @@ -458,6 +458,14 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o serialized[key] = value return serialized + def _extract_stream_chunk_type(self, stream_cls: type) -> type: + args = get_args(stream_cls) + if not args: + raise TypeError( + f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}", + ) + return cast(type, args[0]) + def _process_response( self, *, @@ -793,7 +801,10 @@ def _request( raise APIConnectionError(request=request) from err if stream: - stream_cls = stream_cls or cast("type[_StreamT] | None", self._default_stream_cls) + if stream_cls: + return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self) + + stream_cls = cast("type[_StreamT] | None", self._default_stream_cls) if stream_cls is None: raise MissingStreamClassError() return stream_cls(cast_to=cast_to, response=response, client=self) @@ -1156,7 +1167,10 @@ async def _request( raise APIConnectionError(request=request) from err if stream: - stream_cls = stream_cls or cast("type[_AsyncStreamT] | None", self._default_stream_cls) + if stream_cls: + return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self) + + stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls) if stream_cls is None: raise MissingStreamClassError() return stream_cls(cast_to=cast_to, response=response, client=self)