Skip to content

Add mypy type checks #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
run: uv run ruff format --check
- name: Run ruff check
run: uv run ruff check
# - name: Run mypy
# run: uv run mypy .
- name: Run mypy
run: uv run mypy .
- name: Minimize uv cache
run: uv cache prune --ci
55 changes: 34 additions & 21 deletions langchain/langchain_vectorize/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@

from typing import TYPE_CHECKING, Any, Literal, Optional

import vectorize_client
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
from vectorize_client import (
ApiClient,
Configuration,
PipelinesApi,
RetrieveDocumentsRequest,
)
from vectorize_client.api.pipelines_api import PipelinesApi
from vectorize_client.api_client import ApiClient
from vectorize_client.configuration import Configuration
from vectorize_client.models.retrieve_documents_request import RetrieveDocumentsRequest

if TYPE_CHECKING:
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.runnables import RunnableConfig
from vectorize_client.models.document import Document as VectorizeDocument

_METADATA_FIELDS = {
"relevancy",
Expand Down Expand Up @@ -122,7 +120,7 @@ def format_docs(docs):
metadata_filters: list[dict[str, Any]] = []
"""The metadata filters to apply when retrieving the documents."""

_pipelines: PipelinesApi | None = None
_pipelines: PipelinesApi = _NOT_SET # type: ignore[assignment]

@override
def model_post_init(self, /, context: Any) -> None:
Expand All @@ -146,7 +144,7 @@ def model_post_init(self, /, context: Any) -> None:
self._pipelines = PipelinesApi(api)

# Convert a Vectorize API document into a langchain_core `Document`.
# The source document's `id` and `text` become the result's `id` and
# `page_content`; every field named in `_METADATA_FIELDS` is copied into
# `metadata`.
# NOTE(review): two `def` lines follow because this is a rendered diff hunk;
# they appear to be the before/after signatures and differ only in the
# annotation of the `document` parameter — confirm against the real file.
@staticmethod
def _convert_document(document: vectorize_client.models.Document) -> Document:
def _convert_document(document: VectorizeDocument) -> Document:
# getattr is called without a default, so a document that lacks any of the
# listed fields raises AttributeError instead of being silently skipped.
metadata = {field: getattr(document, field) for field in _METADATA_FIELDS}
return Document(id=document.id, page_content=document.text, metadata=metadata)

Expand All @@ -162,14 +160,29 @@ def _get_relevant_documents(
rerank: bool | None = None,
metadata_filters: list[dict[str, Any]] | None = None,
) -> list[Document]:
request = RetrieveDocumentsRequest(
request = RetrieveDocumentsRequest( # type: ignore[call-arg]
question=query,
num_results=num_results or self.num_results,
rerank=rerank or self.rerank,
metadata_filters=metadata_filters or self.metadata_filters,
)
organization_ = organization or self.organization
if not organization_:
msg = (
"Organization must be set either at initialization "
"or in the invoke method."
)
raise ValueError(msg)
pipeline_id_ = pipeline_id or self.pipeline_id
if not pipeline_id_:
msg = (
"Pipeline ID must be set either at initialization "
"or in the invoke method."
)
raise ValueError(msg)

response = self._pipelines.retrieve_documents(
organization or self.organization, pipeline_id or self.pipeline_id, request
organization_, pipeline_id_, request
)
return [self._convert_document(doc) for doc in response.documents]

Expand All @@ -181,9 +194,10 @@ def invoke(
*,
organization: str = "",
pipeline_id: str = "",
num_results: int = _NOT_SET,
rerank: bool = _NOT_SET,
metadata_filters: list[dict[str, Any]] = _NOT_SET,
num_results: int = _NOT_SET, # type: ignore[assignment]
rerank: bool = _NOT_SET, # type: ignore[assignment]
metadata_filters: list[dict[str, Any]] = _NOT_SET, # type: ignore[assignment]
**_kwargs: Any,
) -> list[Document]:
"""Invoke the retriever to get relevant documents.

Expand Down Expand Up @@ -218,16 +232,15 @@ def invoke(
query = "what year was breath of the wild released?"
docs = retriever.invoke(query, num_results=2)
"""
kwargs = {}
if organization:
kwargs["organization"] = organization
_kwargs["organization"] = organization
if pipeline_id:
kwargs["pipeline_id"] = pipeline_id
_kwargs["pipeline_id"] = pipeline_id
if num_results is not _NOT_SET:
kwargs["num_results"] = num_results
_kwargs["num_results"] = num_results
if rerank is not _NOT_SET:
kwargs["rerank"] = rerank
_kwargs["rerank"] = rerank
if metadata_filters is not _NOT_SET:
kwargs["metadata_filters"] = metadata_filters
_kwargs["metadata_filters"] = metadata_filters

return super().invoke(input, config, **kwargs)
return super().invoke(input, config, **_kwargs)
4 changes: 3 additions & 1 deletion langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Issues = "https://github.com/vectorize-io/integrations-python/issues"

[dependency-groups]
dev = [
"mypy>=1.13.0",
"mypy>=1.17.1,<1.18",
"pytest>=8.3.3",
"ruff>=0.9.0,<0.10",
]
Expand Down Expand Up @@ -59,6 +59,8 @@ flake8-annotations.mypy-init-return = true

[tool.mypy]
strict = true
strict_bytes = true
enable_error_code = "deprecated"
warn_unreachable = true
pretty = true
show_error_codes = true
Expand Down
84 changes: 58 additions & 26 deletions langchain/tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,42 @@

import pytest
import urllib3
import vectorize_client as v
from vectorize_client import ApiClient
from vectorize_client.api.ai_platform_connectors_api import AIPlatformConnectorsApi
from vectorize_client.api.destination_connectors_api import DestinationConnectorsApi
from vectorize_client.api.pipelines_api import PipelinesApi
from vectorize_client.api.source_connectors_api import SourceConnectorsApi
from vectorize_client.api.uploads_api import UploadsApi
from vectorize_client.api_client import ApiClient
from vectorize_client.configuration import Configuration
from vectorize_client.models.ai_platform_config_schema import AIPlatformConfigSchema
from vectorize_client.models.ai_platform_type_for_pipeline import (
AIPlatformTypeForPipeline,
)
from vectorize_client.models.create_source_connector_request import (
CreateSourceConnectorRequest,
)
from vectorize_client.models.destination_connector_type_for_pipeline import (
DestinationConnectorTypeForPipeline,
)
from vectorize_client.models.file_upload import FileUpload
from vectorize_client.models.pipeline_ai_platform_connector_schema import (
PipelineAIPlatformConnectorSchema,
)
from vectorize_client.models.pipeline_configuration_schema import (
PipelineConfigurationSchema,
)
from vectorize_client.models.pipeline_destination_connector_schema import (
PipelineDestinationConnectorSchema,
)
from vectorize_client.models.pipeline_source_connector_schema import (
PipelineSourceConnectorSchema,
)
from vectorize_client.models.schedule_schema import ScheduleSchema
from vectorize_client.models.schedule_schema_type import ScheduleSchemaType
from vectorize_client.models.source_connector_type import SourceConnectorType
from vectorize_client.models.start_file_upload_to_connector_request import (
StartFileUploadToConnectorRequest,
)

from langchain_vectorize.retrievers import VectorizeRetriever

Expand Down Expand Up @@ -38,7 +72,7 @@ def environment() -> Literal["prod", "dev", "local", "staging"]:
if env not in ["prod", "dev", "local", "staging"]:
msg = "Invalid VECTORIZE_ENV environment variable."
raise ValueError(msg)
return env
return env # type: ignore[return-value]


@pytest.fixture(scope="session")
Expand All @@ -56,33 +90,31 @@ def api_client(api_token: str, environment: str) -> Iterator[ApiClient]:
else:
host = "https://api-staging.vectorize.io/v1"

with v.ApiClient(
v.Configuration(host=host, access_token=api_token, debug=True),
with ApiClient(
Configuration(host=host, access_token=api_token, debug=True),
header_name,
header_value,
) as api:
yield api


@pytest.fixture(scope="session")
def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
pipelines = v.PipelinesApi(api_client)
def pipeline_id(api_client: ApiClient, org_id: str) -> Iterator[str]:
pipelines = PipelinesApi(api_client)

connectors_api = v.SourceConnectorsApi(api_client)
connectors_api = SourceConnectorsApi(api_client)
response = connectors_api.create_source_connector(
org_id,
v.CreateSourceConnectorRequest(
v.FileUpload(name="from api", type="FILE_UPLOAD")
),
CreateSourceConnectorRequest(FileUpload(name="from api", type="FILE_UPLOAD")),
)
source_connector_id = response.connector.id
logging.info("Created source connector %s", source_connector_id)

uploads_api = v.UploadsApi(api_client)
uploads_api = UploadsApi(api_client)
upload_response = uploads_api.start_file_upload_to_connector(
org_id,
source_connector_id,
v.StartFileUploadToConnectorRequest(
StartFileUploadToConnectorRequest( # type: ignore[call-arg]
name="research.pdf",
content_type="application/pdf",
metadata=json.dumps({"created-from-api": True}),
Expand All @@ -109,44 +141,44 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
else:
logging.info("Upload successful")

ai_platforms = v.AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
ai_platforms = AIPlatformConnectorsApi(api_client).get_ai_platform_connectors(
org_id
)
builtin_ai_platform = next(
c.id for c in ai_platforms.ai_platform_connectors if c.type == "VECTORIZE"
)
logging.info("Using AI platform %s", builtin_ai_platform)

vector_databases = v.DestinationConnectorsApi(
api_client
).get_destination_connectors(org_id)
vector_databases = DestinationConnectorsApi(api_client).get_destination_connectors(
org_id
)
builtin_vector_db = next(
c.id for c in vector_databases.destination_connectors if c.type == "VECTORIZE"
)
logging.info("Using destination connector %s", builtin_vector_db)

pipeline_response = pipelines.create_pipeline(
org_id,
v.PipelineConfigurationSchema(
PipelineConfigurationSchema( # type: ignore[call-arg]
source_connectors=[
v.PipelineSourceConnectorSchema(
PipelineSourceConnectorSchema(
id=source_connector_id,
type=v.SourceConnectorType.FILE_UPLOAD,
type=SourceConnectorType.FILE_UPLOAD,
config={},
)
],
destination_connector=v.PipelineDestinationConnectorSchema(
destination_connector=PipelineDestinationConnectorSchema(
id=builtin_vector_db,
type="VECTORIZE",
type=DestinationConnectorTypeForPipeline.VECTORIZE,
config={},
),
ai_platform_connector=v.PipelineAIPlatformConnectorSchema(
ai_platform_connector=PipelineAIPlatformConnectorSchema(
id=builtin_ai_platform,
type="VECTORIZE",
config={},
type=AIPlatformTypeForPipeline.VECTORIZE,
config=AIPlatformConfigSchema(),
),
pipeline_name="Test pipeline",
schedule=v.ScheduleSchema(type="manual"),
schedule=ScheduleSchema(type=ScheduleSchemaType.MANUAL),
),
)
pipeline_id = pipeline_response.data.id
Expand Down
Loading
Loading