Skip to content

Commit 6405454

Browse files
committed
Add mypy type checks
1 parent 667cdfb commit 6405454

File tree

5 files changed: +132 -84 lines changed

5 files changed: +132 -84 lines changed

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ jobs:
4545
run: uv run ruff format --check
4646
- name: Run ruff check
4747
run: uv run ruff check
48-
# - name: Run mypy
49-
# run: uv run mypy .
48+
- name: Run mypy
49+
run: uv run mypy .
5050
- name: Minimize uv cache
5151
run: uv cache prune --ci

langchain/langchain_vectorize/retrievers.py

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,20 +4,18 @@
44

55
from typing import TYPE_CHECKING, Any, Literal, Optional
66

7-
import vectorize_client
87
from langchain_core.documents import Document
98
from langchain_core.retrievers import BaseRetriever
109
from typing_extensions import override
11-
from vectorize_client import (
12-
ApiClient,
13-
Configuration,
14-
PipelinesApi,
15-
RetrieveDocumentsRequest,
16-
)
10+
from vectorize_client.api.pipelines_api import PipelinesApi
11+
from vectorize_client.api_client import ApiClient
12+
from vectorize_client.configuration import Configuration
13+
from vectorize_client.models.retrieve_documents_request import RetrieveDocumentsRequest
1714

1815
if TYPE_CHECKING:
1916
from langchain_core.callbacks import CallbackManagerForRetrieverRun
2017
from langchain_core.runnables import RunnableConfig
18+
from vectorize_client.models.document import Document as VectorizeDocument
2119

2220
_METADATA_FIELDS = {
2321
"relevancy",
@@ -122,7 +120,7 @@ def format_docs(docs):
122120
metadata_filters: list[dict[str, Any]] = []
123121
"""The metadata filters to apply when retrieving the documents."""
124122

125-
_pipelines: PipelinesApi | None = None
123+
_pipelines: PipelinesApi = _NOT_SET # type: ignore[assignment]
126124

127125
@override
128126
def model_post_init(self, /, context: Any) -> None:
@@ -146,7 +144,7 @@ def model_post_init(self, /, context: Any) -> None:
146144
self._pipelines = PipelinesApi(api)
147145

148146
@staticmethod
149-
def _convert_document(document: vectorize_client.models.Document) -> Document:
147+
def _convert_document(document: VectorizeDocument) -> Document:
150148
metadata = {field: getattr(document, field) for field in _METADATA_FIELDS}
151149
return Document(id=document.id, page_content=document.text, metadata=metadata)
152150

@@ -162,14 +160,29 @@ def _get_relevant_documents(
162160
rerank: bool | None = None,
163161
metadata_filters: list[dict[str, Any]] | None = None,
164162
) -> list[Document]:
165-
request = RetrieveDocumentsRequest(
163+
request = RetrieveDocumentsRequest( # type: ignore[call-arg]
166164
question=query,
167165
num_results=num_results or self.num_results,
168166
rerank=rerank or self.rerank,
169167
metadata_filters=metadata_filters or self.metadata_filters,
170168
)
169+
organization_ = organization or self.organization
170+
if not organization_:
171+
msg = (
172+
"Organization must be set either at initialization "
173+
"or in the invoke method."
174+
)
175+
raise ValueError(msg)
176+
pipeline_id_ = pipeline_id or self.pipeline_id
177+
if not pipeline_id_:
178+
msg = (
179+
"Pipeline ID must be set either at initialization "
180+
"or in the invoke method."
181+
)
182+
raise ValueError(msg)
183+
171184
response = self._pipelines.retrieve_documents(
172-
organization or self.organization, pipeline_id or self.pipeline_id, request
185+
organization_, pipeline_id_, request
173186
)
174187
return [self._convert_document(doc) for doc in response.documents]
175188

@@ -181,9 +194,10 @@ def invoke(
181194
*,
182195
organization: str = "",
183196
pipeline_id: str = "",
184-
num_results: int = _NOT_SET,
185-
rerank: bool = _NOT_SET,
186-
metadata_filters: list[dict[str, Any]] = _NOT_SET,
197+
num_results: int = _NOT_SET, # type: ignore[assignment]
198+
rerank: bool = _NOT_SET, # type: ignore[assignment]
199+
metadata_filters: list[dict[str, Any]] = _NOT_SET, # type: ignore[assignment]
200+
**_kwargs: Any,
187201
) -> list[Document]:
188202
"""Invoke the retriever to get relevant documents.
189203
@@ -218,16 +232,15 @@ def invoke(
218232
query = "what year was breath of the wild released?"
219233
docs = retriever.invoke(query, num_results=2)
220234
"""
221-
kwargs = {}
222235
if organization:
223-
kwargs["organization"] = organization
236+
_kwargs["organization"] = organization
224237
if pipeline_id:
225-
kwargs["pipeline_id"] = pipeline_id
238+
_kwargs["pipeline_id"] = pipeline_id
226239
if num_results is not _NOT_SET:
227-
kwargs["num_results"] = num_results
240+
_kwargs["num_results"] = num_results
228241
if rerank is not _NOT_SET:
229-
kwargs["rerank"] = rerank
242+
_kwargs["rerank"] = rerank
230243
if metadata_filters is not _NOT_SET:
231-
kwargs["metadata_filters"] = metadata_filters
244+
_kwargs["metadata_filters"] = metadata_filters
232245

233-
return super().invoke(input, config, **kwargs)
246+
return super().invoke(input, config, **_kwargs)

langchain/pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ Issues = "https://github.com/vectorize-io/integrations-python/issues"
3131

3232
[dependency-groups]
3333
dev = [
34-
"mypy>=1.13.0",
34+
"mypy>=1.17.1,<1.18",
3535
"pytest>=8.3.3",
3636
"ruff>=0.9.0,<0.10",
3737
]
@@ -59,6 +59,8 @@ flake8-annotations.mypy-init-return = true
5959

6060
[tool.mypy]
6161
strict = true
62+
strict_bytes = true
63+
enable_error_code = "deprecated"
6264
warn_unreachable = true
6365
pretty = true
6466
show_error_codes = true

langchain/tests/test_retrievers.py

Lines changed: 41 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,29 @@
88

99
import pytest
1010
import urllib3
11-
import vectorize_client as v
12-
from vectorize_client import ApiClient
11+
from vectorize_client.api.connectors_api import ConnectorsApi
12+
from vectorize_client.api.pipelines_api import PipelinesApi
13+
from vectorize_client.api.uploads_api import UploadsApi
14+
from vectorize_client.api_client import ApiClient
15+
from vectorize_client.configuration import Configuration
16+
from vectorize_client.models.ai_platform_config_schema import AIPlatformConfigSchema
17+
from vectorize_client.models.ai_platform_schema import AIPlatformSchema
18+
from vectorize_client.models.ai_platform_type import AIPlatformType
19+
from vectorize_client.models.create_source_connector import CreateSourceConnector
20+
from vectorize_client.models.destination_connector_schema import (
21+
DestinationConnectorSchema,
22+
)
23+
from vectorize_client.models.destination_connector_type import DestinationConnectorType
24+
from vectorize_client.models.pipeline_configuration_schema import (
25+
PipelineConfigurationSchema,
26+
)
27+
from vectorize_client.models.schedule_schema import ScheduleSchema
28+
from vectorize_client.models.schedule_schema_type import ScheduleSchemaType
29+
from vectorize_client.models.source_connector_schema import SourceConnectorSchema
30+
from vectorize_client.models.source_connector_type import SourceConnectorType
31+
from vectorize_client.models.start_file_upload_to_connector_request import (
32+
StartFileUploadToConnectorRequest,
33+
)
1334

1435
from langchain_vectorize.retrievers import VectorizeRetriever
1536

@@ -38,7 +59,7 @@ def environment() -> Literal["prod", "dev", "local", "staging"]:
3859
if env not in ["prod", "dev", "local", "staging"]:
3960
msg = "Invalid VECTORIZE_ENV environment variable."
4061
raise ValueError(msg)
41-
return env
62+
return env # type: ignore[return-value]
4263

4364

4465
@pytest.fixture(scope="session")
@@ -56,35 +77,31 @@ def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
5677
else:
5778
host = "https://api-staging.vectorize.io/v1"
5879

59-
with v.ApiClient(
60-
v.Configuration(host=host, access_token=api_token, debug=True),
80+
with ApiClient(
81+
Configuration(host=host, access_token=api_token, debug=True),
6182
header_name,
6283
header_value,
6384
) as api:
6485
yield api
6586

6687

6788
@pytest.fixture(scope="session")
68-
def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
69-
pipelines = v.PipelinesApi(api_client)
89+
def pipeline_id(api_client: ApiClient, org_id: str) -> Iterator[str]:
90+
pipelines = PipelinesApi(api_client)
7091

71-
connectors_api = v.ConnectorsApi(api_client)
92+
connectors_api = ConnectorsApi(api_client)
7293
response = connectors_api.create_source_connector(
7394
org_id,
74-
[
75-
v.CreateSourceConnector(
76-
name="from api", type=v.SourceConnectorType.FILE_UPLOAD
77-
)
78-
],
95+
[CreateSourceConnector(name="from api", type=SourceConnectorType.FILE_UPLOAD)],
7996
)
8097
source_connector_id = response.connectors[0].id
8198
logging.info("Created source connector %s", source_connector_id)
8299

83-
uploads_api = v.UploadsApi(api_client)
100+
uploads_api = UploadsApi(api_client)
84101
upload_response = uploads_api.start_file_upload_to_connector(
85102
org_id,
86103
source_connector_id,
87-
v.StartFileUploadToConnectorRequest(
104+
StartFileUploadToConnectorRequest( # type: ignore[call-arg]
88105
name="research.pdf",
89106
content_type="application/pdf",
90107
metadata=json.dumps({"created-from-api": True}),
@@ -125,26 +142,26 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
125142

126143
pipeline_response = pipelines.create_pipeline(
127144
org_id,
128-
v.PipelineConfigurationSchema(
145+
PipelineConfigurationSchema( # type: ignore[call-arg]
129146
source_connectors=[
130-
v.SourceConnectorSchema(
147+
SourceConnectorSchema(
131148
id=source_connector_id,
132-
type=v.SourceConnectorType.FILE_UPLOAD,
149+
type=SourceConnectorType.FILE_UPLOAD,
133150
config={},
134151
)
135152
],
136-
destination_connector=v.DestinationConnectorSchema(
153+
destination_connector=DestinationConnectorSchema(
137154
id=builtin_vector_db,
138-
type=v.DestinationConnectorType.VECTORIZE,
155+
type=DestinationConnectorType.VECTORIZE,
139156
config={},
140157
),
141-
ai_platform=v.AIPlatformSchema(
158+
ai_platform=AIPlatformSchema(
142159
id=builtin_ai_platform,
143-
type=v.AIPlatformType.VECTORIZE,
144-
config=v.AIPlatformConfigSchema(),
160+
type=AIPlatformType.VECTORIZE,
161+
config=AIPlatformConfigSchema(),
145162
),
146163
pipeline_name="Test pipeline",
147-
schedule=v.ScheduleSchema(type=v.ScheduleSchemaType.MANUAL),
164+
schedule=ScheduleSchema(type=ScheduleSchemaType.MANUAL),
148165
),
149166
)
150167
pipeline_id = pipeline_response.data.id

0 commit comments

Comments
 (0)