diff --git a/django_mongodb_backend/indexes.py b/django_mongodb_backend/indexes.py index 7c0730e4..7044357d 100644 --- a/django_mongodb_backend/indexes.py +++ b/django_mongodb_backend/indexes.py @@ -109,9 +109,22 @@ class SearchIndex(Index): suffix = "six" _error_id_prefix = "django_mongodb_backend.indexes.SearchIndex" - def __init__(self, *, fields=(), name=None): + def __init__(self, *, fields=(), name=None, field_mappings=None): + if field_mappings and not isinstance(field_mappings, dict): + raise ValueError( + "field_mappings must be a dictionary mapping field names to their " + "Atlas Search field mappings." + ) + self.field_mappings = field_mappings or {} + + fields = list({*fields, *self.field_mappings.keys()}) super().__init__(fields=fields, name=name) + def deconstruct(self): + path, args, kwargs = super().deconstruct() + kwargs["field_mappings"] = self.field_mappings + return path, args, kwargs + def check(self, model, connection): errors = [] if not connection.features.supports_atlas_search: @@ -152,23 +165,39 @@ def get_pymongo_index_model( return None fields = {} for field_name, _ in self.fields_orders: - field = model._meta.get_field(field_name) - type_ = self.search_index_data_types(field.db_type(schema_editor.connection)) field_path = column_prefix + model._meta.get_field(field_name).column - fields[field_path] = {"type": type_} + if field_name in self.field_mappings: + fields[field_path] = self.field_mappings[field_name].copy() + else: + # If no field mapping is provided, use the default search index data type. + field = model._meta.get_field(field_name) + type_ = self.search_index_data_types(field.db_type(schema_editor.connection)) + fields[field_path] = {"type": type_} return SearchIndexModel( definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name ) +class DynamicSearchIndex(SearchIndex): + suffix = "dsix" + _error_id_prefix = "django_mongodb_backend.indexes.DynamicSearchIndex" + + def get_pymongo_index_model( + self, model, schema_editor, field=None, unique=False, column_prefix="" + ): + if not schema_editor.connection.features.supports_atlas_search: + return None + return SearchIndexModel(definition={"mappings": {"dynamic": True}}, name=self.name) + + class VectorSearchIndex(SearchIndex): suffix = "vsi" _error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex" VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid")) VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean")) - def __init__(self, *, fields=(), name=None, similarities): - super().__init__(fields=fields, name=name) + def __init__(self, *, fields=(), name=None, similarities=(), fields_mappings=None): + super().__init__(fields=fields, name=name, field_mappings=fields_mappings) self.similarities = similarities self._multiple_similarities = isinstance(similarities, tuple | list) for func in similarities if self._multiple_similarities else (similarities,): diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py index 9472db96..e1189ec7 100644 --- a/django_mongodb_backend/schema.py +++ b/django_mongodb_backend/schema.py @@ -1,3 +1,5 @@ +from time import monotonic, sleep + from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint from pymongo.operations import SearchIndexModel @@ -28,6 +30,27 @@ def wrapper(self, model, *args, **kwargs): return wrapper +def wait_until_index_ready(collection, index_name, timeout: float = 60, interval: float = 0.5): + start = monotonic() + while monotonic() - start < timeout: + indexes = list(collection.list_search_indexes()) + for idx in indexes: + if idx["name"] == index_name and idx["status"] == "READY": + return True + sleep(interval) + raise TimeoutError(f"Index {index_name} not ready after {timeout} seconds") + + +def wait_until_index_delete(collection, index_name, timeout: float = 60, interval: float = 0.5): + start = monotonic() + while monotonic() - start < timeout: + indexes = list(collection.list_search_indexes()) + if all(idx["name"] != index_name for idx in indexes): + return True + sleep(interval) + raise TimeoutError(f"Index {index_name} not deleted after {timeout} seconds") + + class BaseSchemaEditor(BaseDatabaseSchemaEditor): def get_collection(self, name): if self.collect_sql: @@ -269,10 +292,12 @@ def add_index( ) if idx: model = parent_model or model + collection = self.get_collection(model._meta.db_table) if isinstance(idx, SearchIndexModel): - self.get_collection(model._meta.db_table).create_search_index(idx) + collection.create_search_index(idx) + wait_until_index_ready(collection, index.name) else: - self.get_collection(model._meta.db_table).create_indexes([idx]) + collection.create_indexes([idx]) def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None): """Add an index on the given list of field_names.""" @@ -290,12 +315,14 @@ def _add_field_index(self, model, field, *, column_prefix=""): def remove_index(self, model, index): if index.contains_expressions: return + collection = self.get_collection(model._meta.db_table) if isinstance(index, SearchIndex): # Drop the index if it's supported. if self.connection.features.supports_atlas_search: - self.get_collection(model._meta.db_table).drop_search_index(index.name) + collection.drop_search_index(index.name) + wait_until_index_delete(collection, index.name) else: - self.get_collection(model._meta.db_table).drop_index(index.name) + collection.drop_index(index.name) def _remove_composed_index( self, model, field_names, constraint_kwargs, column_prefix="", parent_model=None