Skip to content

fix: Add aiohttp conditional type checking #1234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,11 @@ async def _aiter_response_stream(self) -> AsyncIterator[str]:
has_aiohttp and isinstance(self.response_stream, aiohttp.ClientResponse)
)
if not is_valid_response:
expected_types = 'an httpx.Response'
if has_aiohttp:
expected_types = 'an httpx.Response or aiohttp.ClientResponse'
raise TypeError(
'Expected self.response_stream to be an httpx.Response or'
' aiohttp.ClientResponse object, but got'
f'Expected self.response_stream to be {expected_types} object, but got'
f' {type(self.response_stream).__name__}.'
)

Expand Down Expand Up @@ -384,9 +386,7 @@ async def _aiter_response_stream(self) -> AsyncIterator[str]:
chunk = ''

# aiohttp.ClientResponse uses a content stream that we read line by line.
elif has_aiohttp and isinstance(
self.response_stream, aiohttp.ClientResponse
):
elif has_aiohttp and isinstance(self.response_stream, aiohttp.ClientResponse):
while True:
# Read a line from the stream. This returns bytes.
line_bytes = await self.response_stream.content.readline()
Expand Down Expand Up @@ -486,7 +486,7 @@ def _retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
}


class SyncHttpxClient(httpx.Client):
class SyncHttpxClient(httpx.Client): # type: ignore[misc]
"""Sync httpx client."""

def __init__(self, **kwargs: Any) -> None:
Expand All @@ -507,7 +507,7 @@ def __del__(self) -> None:
pass


class AsyncHttpxClient(httpx.AsyncClient):
class AsyncHttpxClient(httpx.AsyncClient): # type: ignore[misc]
"""Async httpx client."""

def __init__(self, **kwargs: Any) -> None:
Expand Down Expand Up @@ -1123,7 +1123,7 @@ async def _async_request_once(
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
**self._async_client_session_request_args,
)
except (
except ( # type: ignore[misc]
aiohttp.ClientConnectorError,
aiohttp.ClientConnectorDNSError,
aiohttp.ClientOSError,
Expand Down Expand Up @@ -1185,7 +1185,7 @@ async def _async_request_once(
)
await errors.APIError.raise_for_async_response(response)
return HttpResponse(response.headers, [await response.text()])
except (
except ( # type: ignore[misc]
aiohttp.ClientConnectorError,
aiohttp.ClientConnectorDNSError,
aiohttp.ClientOSError,
Expand Down
6 changes: 4 additions & 2 deletions google/genai/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ async def raise_for_async_response(
}
status_code = response.status
else:
response_json = response.body_segments[0].get('error', {})
# Handle ReplayResponse or other types with body_segments
response_json = getattr(response, 'body_segments', [{}])[0].get('error', {})
except ImportError:
response_json = response.body_segments[0].get('error', {})
# Handle ReplayResponse or other types with body_segments
response_json = getattr(response, 'body_segments', [{}])[0].get('error', {})

if 400 <= status_code < 500:
raise ClientError(status_code, response_json, response)
Expand Down
10 changes: 5 additions & 5 deletions google/genai/tunings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,9 +1485,9 @@ def _get_ipython_shell_name() -> Union[str, Any]:
import sys

if 'IPython' in sys.modules:
from IPython import get_ipython
from IPython import get_ipython # type: ignore[attr-defined]

return get_ipython().__class__.__name__
return get_ipython().__class__.__name__ # type: ignore[no-untyped-call]
return ''

@staticmethod
Expand Down Expand Up @@ -1603,10 +1603,10 @@ def _display_link(
</script>
"""

from IPython.display import display
from IPython.display import HTML
from IPython.display import display # type: ignore[import-untyped]
from IPython.display import HTML # type: ignore[import-untyped]

display(HTML(html))
display(HTML(html)) # type: ignore[no-untyped-call]

@staticmethod
def display_experiment_button(experiment: str, project: str) -> None:
Expand Down
40 changes: 21 additions & 19 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@
PIL_Image = PIL.Image.Image
_is_pillow_image_imported = True
else:
PIL_Image: typing.Type = Any
PIL_Image: typing.Type = Any # type: ignore[valid-type]
try:
import PIL.Image

PIL_Image = PIL.Image.Image
PIL_Image = PIL.Image.Image # type: ignore[misc]
_is_pillow_image_imported = True
except ImportError:
PIL_Image = None
PIL_Image = None # type: ignore[misc]

_is_mcp_imported = False
if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -3791,13 +3791,13 @@ class FileDict(TypedDict, total=False):


if _is_pillow_image_imported:
PartUnion = Union[str, PIL_Image, File, Part]
PartUnion = Union[str, PIL_Image, File, Part] # type: ignore[valid-type]
else:
PartUnion = Union[str, File, Part] # type: ignore[misc]


if _is_pillow_image_imported:
PartUnionDict = Union[str, PIL_Image, File, FileDict, Part, PartDict]
PartUnionDict = Union[str, PIL_Image, File, FileDict, Part, PartDict] # type: ignore[valid-type]
else:
PartUnionDict = Union[str, File, FileDict, Part, PartDict] # type: ignore[misc]

Expand Down Expand Up @@ -6178,7 +6178,7 @@ class Image(_common.BaseModel):
default=None, description="""The MIME type of the image."""
)

_loaded_image: Optional['PIL_Image'] = None
_loaded_image: Optional['PIL_Image'] = None # type: ignore[valid-type]

"""Image."""

Expand Down Expand Up @@ -6229,21 +6229,22 @@ def show(self) -> None:

This method only works in a notebook environment.
"""
IPython_display = None # type: Any
try:
from IPython import display as IPython_display
from IPython import display as IPython_display # type: ignore[import-untyped,no-redef]
except ImportError:
IPython_display = None
pass

if IPython_display:
IPython_display.display(self._pil_image)
IPython_display.display(self._pil_image) # type: ignore[no-untyped-call]

@property
def _pil_image(self) -> Optional['PIL_Image']:
PIL_Image: Optional[builtin_types.ModuleType]
def _pil_image(self) -> Optional['PIL_Image']: # type: ignore[valid-type]
PIL_Image: Optional[builtin_types.ModuleType] = None
try:
from PIL import Image as PIL_Image
from PIL import Image as PIL_Image # type: ignore[no-redef]
except ImportError:
PIL_Image = None
pass
import io

if self._loaded_image is None:
Expand Down Expand Up @@ -8120,14 +8121,15 @@ def show(self) -> None:

mime_type = self.mime_type or 'video/mp4'

IPython_display = None # type: Any
try:
from IPython import display as IPython_display
from IPython import display as IPython_display # type: ignore[import-untyped,no-redef]
except ImportError:
IPython_display = None
pass

if IPython_display:
IPython_display.display(
IPython_display.Video(
IPython_display.display( # type: ignore[no-untyped-call]
IPython_display.Video( # type: ignore[no-untyped-call]
data=self.video_bytes, mimetype=mime_type, embed=True
)
)
Expand Down Expand Up @@ -13498,13 +13500,13 @@ class LiveClientToolResponseDict(TypedDict, total=False):


if _is_pillow_image_imported:
BlobImageUnion = Union[PIL_Image, Blob]
BlobImageUnion = Union[PIL_Image, Blob] # type: ignore[valid-type]
else:
BlobImageUnion = Blob # type: ignore[misc]


if _is_pillow_image_imported:
BlobImageUnionDict = Union[PIL_Image, Blob, BlobDict]
BlobImageUnionDict = Union[PIL_Image, Blob, BlobDict] # type: ignore[valid-type]
else:
BlobImageUnionDict = Union[Blob, BlobDict] # type: ignore[misc]

Expand Down