Skip to content

Commit 189381e

Browse files
Merge #927
927: Make embedders deseralize to correct type r=curquiza a=sanders41 # Pull Request ## Related issue Fixes #926 ## What does this PR do? - Checks the model before deserializing to make sure the correct model is used. - Updates test to catch any potential regressions. ## PR checklist Please check if your PR fulfills the following requirements: - [x] Does this PR fix an existing issue, or have you listed the changes applied in the PR description (and why they are needed)? - [x] Have you read the contributing guidelines? - [x] Have you made sure that the title is accurate and descriptive of the changes? Thank you so much for contributing to Meilisearch! Co-authored-by: Paul Sanders <[email protected]>
2 parents 28b16ce + f2f6929 commit 189381e

File tree

4 files changed

+83
-25
lines changed

4 files changed

+83
-25
lines changed

meilisearch/index.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,16 @@
1010
from meilisearch.config import Config
1111
from meilisearch.errors import version_error_hint_message
1212
from meilisearch.models.document import Document, DocumentsResults
13-
from meilisearch.models.index import Embedders, Faceting, IndexStats, Pagination, TypoTolerance
13+
from meilisearch.models.index import (
14+
Embedders,
15+
Faceting,
16+
HuggingFaceEmbedder,
17+
IndexStats,
18+
OpenAiEmbedder,
19+
Pagination,
20+
TypoTolerance,
21+
UserProvidedEmbedder,
22+
)
1423
from meilisearch.models.task import Task, TaskInfo, TaskResults
1524
from meilisearch.task import TaskHandler
1625

@@ -865,7 +874,23 @@ def get_settings(self) -> Dict[str, Any]:
865874
MeilisearchApiError
866875
An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors
867876
"""
868-
return self.http.get(f"{self.config.paths.index}/{self.uid}/{self.config.paths.setting}")
877+
settings = self.http.get(
878+
f"{self.config.paths.index}/{self.uid}/{self.config.paths.setting}"
879+
)
880+
881+
if settings.get("embedders"):
882+
embedders: dict[str, OpenAiEmbedder | HuggingFaceEmbedder | UserProvidedEmbedder] = {}
883+
for k, v in settings["embedders"].items():
884+
if v.get("source") == "openAi":
885+
embedders[k] = OpenAiEmbedder(**v)
886+
elif v.get("source") == "huggingFace":
887+
embedders[k] = HuggingFaceEmbedder(**v)
888+
else:
889+
embedders[k] = UserProvidedEmbedder(**v)
890+
891+
settings["embedders"] = embedders
892+
893+
return settings
869894

870895
def update_settings(self, body: Mapping[str, Any]) -> TaskInfo:
871896
"""Update settings of the index.
@@ -1777,7 +1802,17 @@ def get_embedders(self) -> Embedders | None:
17771802
if not response:
17781803
return None
17791804

1780-
return Embedders(embedders=response)
1805+
embedders: dict[str, OpenAiEmbedder | HuggingFaceEmbedder | UserProvidedEmbedder] = {}
1806+
for k, v in response.items():
1807+
print(v.get("source"))
1808+
if v.get("source") == "openAi":
1809+
embedders[k] = OpenAiEmbedder(**v)
1810+
elif v.get("source") == "huggingFace":
1811+
embedders[k] = HuggingFaceEmbedder(**v)
1812+
else:
1813+
embedders[k] = UserProvidedEmbedder(**v)
1814+
1815+
return Embedders(embedders=embedders)
17811816

17821817
def update_embedders(self, body: Union[Mapping[str, Any], None]) -> TaskInfo:
17831818
"""Update embedders of the index.

tests/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import meilisearch
99
from meilisearch.errors import MeilisearchApiError
10+
from meilisearch.models.index import HuggingFaceEmbedder, OpenAiEmbedder, UserProvidedEmbedder
1011
from tests import common
1112

1213

@@ -230,8 +231,7 @@ def enable_vector_search():
230231
@fixture
231232
def new_embedders():
232233
return {
233-
"default": {
234-
"source": "userProvided",
235-
"dimensions": 1,
236-
}
234+
"default": UserProvidedEmbedder(dimensions=1).model_dump(by_alias=True),
235+
"open_ai": OpenAiEmbedder().model_dump(by_alias=True),
236+
"hugging_face": HuggingFaceEmbedder().model_dump(by_alias=True),
237237
}

tests/settings/test_settings.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
1-
NEW_SETTINGS = {
2-
"rankingRules": ["typo", "words"],
3-
"searchableAttributes": ["title", "overview"],
4-
}
1+
# pylint: disable=redefined-outer-name
2+
import pytest
3+
4+
from meilisearch.models.index import HuggingFaceEmbedder, OpenAiEmbedder, UserProvidedEmbedder
5+
6+
7+
@pytest.fixture
8+
def new_settings(new_embedders):
9+
return {
10+
"rankingRules": ["typo", "words"],
11+
"searchableAttributes": ["title", "overview"],
12+
"embedders": new_embedders,
13+
}
14+
515

616
DEFAULT_RANKING_RULES = ["words", "typo", "proximity", "attribute", "sort", "exactness"]
717

@@ -31,36 +41,41 @@ def test_get_settings_default(empty_index):
3141
assert response["synonyms"] == {}
3242

3343

34-
def test_update_settings(empty_index):
44+
@pytest.mark.usefixtures("enable_vector_search")
45+
def test_update_settings(new_settings, empty_index):
3546
"""Tests updating some settings."""
3647
index = empty_index()
37-
response = index.update_settings(NEW_SETTINGS)
48+
response = index.update_settings(new_settings)
3849
update = index.wait_for_task(response.task_uid)
3950
assert update.status == "succeeded"
4051
response = index.get_settings()
41-
for rule in NEW_SETTINGS["rankingRules"]:
52+
for rule in new_settings["rankingRules"]:
4253
assert rule in response["rankingRules"]
4354
assert response["distinctAttribute"] is None
44-
for attribute in NEW_SETTINGS["searchableAttributes"]:
55+
for attribute in new_settings["searchableAttributes"]:
4556
assert attribute in response["searchableAttributes"]
4657
assert response["displayedAttributes"] == ["*"]
4758
assert response["stopWords"] == []
4859
assert response["synonyms"] == {}
60+
assert isinstance(response["embedders"]["default"], UserProvidedEmbedder)
61+
assert isinstance(response["embedders"]["open_ai"], OpenAiEmbedder)
62+
assert isinstance(response["embedders"]["hugging_face"], HuggingFaceEmbedder)
4963

5064

51-
def test_reset_settings(empty_index):
65+
@pytest.mark.usefixtures("enable_vector_search")
66+
def test_reset_settings(new_settings, empty_index):
5267
"""Tests resetting all the settings to their default value."""
5368
index = empty_index()
5469
# Update settings first
55-
response = index.update_settings(NEW_SETTINGS)
70+
response = index.update_settings(new_settings)
5671
update = index.wait_for_task(response.task_uid)
5772
assert update.status == "succeeded"
5873
# Check the settings have been correctly updated
5974
response = index.get_settings()
60-
for rule in NEW_SETTINGS["rankingRules"]:
75+
for rule in new_settings["rankingRules"]:
6176
assert rule in response["rankingRules"]
6277
assert response["distinctAttribute"] is None
63-
for attribute in NEW_SETTINGS["searchableAttributes"]:
78+
for attribute in new_settings["searchableAttributes"]:
6479
assert attribute in response["searchableAttributes"]
6580
assert response["displayedAttributes"] == ["*"]
6681
assert response["stopWords"] == []
@@ -80,3 +95,4 @@ def test_reset_settings(empty_index):
8095
assert response["searchableAttributes"] == ["*"]
8196
assert response["stopWords"] == []
8297
assert response["synonyms"] == {}
98+
assert response.get("embedders") is None

tests/settings/test_settings_embedders.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
# pylint: disable=redefined-outer-name
12
import pytest
23

3-
from meilisearch.models.index import Embedders
4+
from meilisearch.models.index import HuggingFaceEmbedder, OpenAiEmbedder, UserProvidedEmbedder
45

56

67
@pytest.mark.usefixtures("enable_vector_search")
@@ -19,7 +20,9 @@ def test_update_embedders_with_user_provided_source(new_embedders, empty_index):
1920
update = index.wait_for_task(response_update.task_uid)
2021
response_get = index.get_embedders()
2122
assert update.status == "succeeded"
22-
assert response_get == Embedders(embedders=new_embedders)
23+
assert isinstance(response_get.embedders["default"], UserProvidedEmbedder)
24+
assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder)
25+
assert isinstance(response_get.embedders["hugging_face"], HuggingFaceEmbedder)
2326

2427

2528
@pytest.mark.usefixtures("enable_vector_search")
@@ -30,15 +33,19 @@ def test_reset_embedders(new_embedders, empty_index):
3033
# Update the settings
3134
response_update = index.update_embedders(new_embedders)
3235
update1 = index.wait_for_task(response_update.task_uid)
36+
assert update1.status == "succeeded"
3337
# Get the setting after update
3438
response_get = index.get_embedders()
39+
assert isinstance(response_get.embedders["default"], UserProvidedEmbedder)
40+
assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder)
41+
assert isinstance(response_get.embedders["hugging_face"], HuggingFaceEmbedder)
3542
# Reset the setting
3643
response_reset = index.reset_embedders()
3744
update2 = index.wait_for_task(response_reset.task_uid)
3845
# Get the setting after reset
39-
response_last = index.get_embedders()
40-
41-
assert update1.status == "succeeded"
42-
assert response_get == Embedders(embedders=new_embedders)
4346
assert update2.status == "succeeded"
47+
assert isinstance(response_get.embedders["default"], UserProvidedEmbedder)
48+
assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder)
49+
assert isinstance(response_get.embedders["hugging_face"], HuggingFaceEmbedder)
50+
response_last = index.get_embedders()
4451
assert response_last is None

0 commit comments

Comments
 (0)