From e8212d764e62792b90db3e5a0ed199a1dc1bf79c Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 9 Sep 2025 14:15:26 -0300 Subject: [PATCH 1/3] WIP. --- django_mongodb_backend/base.py | 40 ++++++++++++++++++- django_mongodb_backend/compiler.py | 2 +- .../expressions/builtins.py | 12 +++--- django_mongodb_backend/fields/json.py | 26 ++++++++---- django_mongodb_backend/functions.py | 10 +++-- django_mongodb_backend/lookups.py | 24 ++++++++--- django_mongodb_backend/query.py | 7 ++-- django_mongodb_backend/query_utils.py | 17 +++++--- 8 files changed, 104 insertions(+), 34 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index f751c27f..4699cf94 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -20,7 +20,7 @@ from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations -from .query_utils import regex_match +from .query_utils import regex_expr, regex_match from .schema import DatabaseSchemaEditor from .utils import OperationDebugWrapper from .validation import DatabaseValidation @@ -108,7 +108,12 @@ def _isnull_operator(a, b): } return is_null if b else {"$not": is_null} - mongo_operators = { + def _isnull_operator_match(a, b): + if b: + return {"$or": [{a: {"$exists": False}}, {a: None}]} + return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]} + + mongo_operators_expr = { "exact": lambda a, b: {"$eq": [a, b]}, "gt": lambda a, b: {"$gt": [a, b]}, "gte": lambda a, b: {"$gte": [a, b]}, @@ -126,6 +131,37 @@ def _isnull_operator(a, b): {"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]}, ] }, + "iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True), + "startswith": lambda a, b: regex_expr(a, ("^", b)), + "istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True), + "endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})), + "iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True), + "contains": lambda a, b: regex_expr(a, b), + "icontains": lambda a, b: regex_expr(a, b, insensitive=True), + "regex": lambda a, b: regex_expr(a, b), + "iregex": lambda a, b: regex_expr(a, b, insensitive=True), + } + + mongo_operators_match = { + "exact": lambda a, b: {a: b}, + "gt": lambda a, b: {a: {"$gt": b}}, + "gte": lambda a, b: {a: {"$gte": b}}, + # MongoDB considers null less than zero. Exclude null values to match + # SQL behavior. + "lt": lambda a, b: { + "$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "lte": lambda a, b: { + "$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "in": lambda a, b: {a: {"$in": list(b)}}, + "isnull": _isnull_operator_match, + "range": lambda a, b: { + "$and": [ + {"$or": [DatabaseWrapper._isnull_operator_match(b[0], True), {a: {"$gte": b[0]}}]}, + {"$or": [DatabaseWrapper._isnull_operator_match(b[1], True), {a: {"$lte": b[1]}}]}, + ] + }, "iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True), "startswith": lambda a, b: regex_match(a, ("^", b)), "istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True), diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 628a91e8..bd259694 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -485,7 +485,7 @@ def build_query(self, columns=None): except FullResultSet: query.match_mql = {} else: - query.match_mql = {"$expr": expr} + query.match_mql = expr if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) query.subqueries = self.subqueries diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 0bc93935..6514ca3a 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -33,7 +33,7 @@ def case(self, compiler, connection): for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection) + case_mql["case"] = case.as_mql(compiler, connection, as_expr=True) except EmptyResultSet: continue except FullResultSet: @@ -152,7 +152,7 @@ def raw_sql(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("RawSQL is not supported on MongoDB.") -def ref(self, compiler, connection): # noqa: ARG001 +def ref(self, compiler, connection, as_path=False): # noqa: ARG001 prefix = ( f"{self.source.alias}." if isinstance(self.source, Col) and self.source.alias != compiler.collection_name @@ -162,7 +162,9 @@ def ref(self, compiler, connection): # noqa: ARG001 refs, _ = compiler.columns[self.ordinal - 1] else: refs = self.refs - return f"${prefix}{refs}" + if not as_path: + prefix = f"${prefix}" + return f"{prefix}{refs}" def star(self, compiler, connection): # noqa: ARG001 @@ -181,8 +183,8 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None): return connection.mongo_operators["isnull"](lhs_mql, False) -def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) +def when(self, compiler, connection, **extra): + return self.condition.as_mql(compiler, connection, **extra) def value(self, compiler, connection): # noqa: ARG001 diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 1a7ecb61..c68c9ef8 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -1,3 +1,5 @@ +from itertools import chain + from django.db import NotSupportedError from django.db.models.fields.json import ( ContainedBy, @@ -13,12 +15,14 @@ KeyTransformNumericLookupMixin, ) -from ..lookups import builtin_lookup +from ..lookups import builtin_lookup, is_constant_value from ..query_utils import process_lhs, process_rhs -def build_json_mql_path(lhs, key_transforms): +def build_json_mql_path(lhs, key_transforms, as_path=False): # Build the MQL path using the collected key transforms. + if as_path: + return ".".join(chain([lhs], key_transforms)) result = lhs for key in key_transforms: get_field = {"$getField": {"input": result, "field": key}} @@ -45,8 +49,12 @@ def data_contains(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("contains lookup is not supported on this database backend.") -def _has_key_predicate(path, root_column, negated=False): +def _has_key_predicate(path, root_column, negated=False, as_path=False): """Return MQL to check for the existence of `path`.""" + if as_path: + if not negated: + return {"$and": [{path: {"$exists": True}}, {path: {"$ne": None}}]} + return {"$or": [{path: {"$exists": False}}, {path: None}]} result = { "$and": [ # The path must exist (i.e. not be "missing"). @@ -64,18 +72,20 @@ def _has_key_predicate(path, root_column, negated=False): def has_key_lookup(self, compiler, connection): """Return MQL to check for the existence of a key.""" rhs = self.rhs + as_path = is_constant_value(rhs) lhs = process_lhs(self, compiler, connection) if not isinstance(rhs, (list, tuple)): rhs = [rhs] paths = [] # Transform any "raw" keys into KeyTransforms to allow consistent handling # in the code that follows. + for key in rhs: rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs) - paths.append(rhs_json_path.as_mql(compiler, connection)) + paths.append(rhs_json_path.as_mql(compiler, connection, as_path=as_path)) keys = [] for path in paths: - keys.append(_has_key_predicate(path, lhs)) + keys.append(_has_key_predicate(path, lhs, as_path=as_path)) if self.mongo_operator is None: return keys[0] return {self.mongo_operator: keys} @@ -93,7 +103,7 @@ def json_exact_process_rhs(self, compiler, connection): ) -def key_transform(self, compiler, connection): +def key_transform(self, compiler, connection, **extra): """ Return MQL for this KeyTransform (JSON path). @@ -108,8 +118,8 @@ def key_transform(self, compiler, connection): while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - lhs_mql = previous.as_mql(compiler, connection) - return build_json_mql_path(lhs_mql, key_transforms) + lhs_mql = previous.as_mql(compiler, connection, **extra) + return build_json_mql_path(lhs_mql, key_transforms, **extra) def key_transform_in(self, compiler, connection): diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index c45800a0..3009ef5f 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -146,10 +146,12 @@ def preserve_null(operator): def wrapped(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) return { - "$cond": { - "if": connection.mongo_operators["isnull"](lhs_mql, True), - "then": None, - "else": {f"${operator}": lhs_mql}, + "$expr": { + "$cond": { + "if": connection.mongo_operators_expr["isnull"](lhs_mql, True), + "then": None, + "else": {f"${operator}": lhs_mql}, + } } } diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 8dda2bab..168a30fa 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -1,4 +1,5 @@ from django.db import NotSupportedError +from django.db.models.expressions import Value from django.db.models.fields.related_lookups import In, RelatedIn from django.db.models.lookups import ( BuiltinLookup, @@ -8,13 +9,21 @@ UUIDTextMixin, ) -from .query_utils import process_lhs, process_rhs +from .query_utils import is_direct_value, process_lhs, process_rhs -def builtin_lookup(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def is_constant_value(value): + return is_direct_value(value) or isinstance(value, Value) + + +def builtin_lookup(self, compiler, connection, as_expr=False): value = process_rhs(self, compiler, connection) - return connection.mongo_operators[self.lookup_name](lhs_mql, value) + if is_constant_value(self.rhs) and not as_expr: + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return connection.mongo_operators_match[self.lookup_name](lhs_mql, value) + + lhs_mql = process_lhs(self, compiler, connection) + return {"$expr": connection.mongo_operators_expr[self.lookup_name](lhs_mql, value)} _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -75,11 +84,14 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) ] -def is_null(self, compiler, connection): +def is_null(self, compiler, connection, as_expr=False): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") + if is_constant_value(self.rhs) and not as_expr: + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs) lhs_mql = process_lhs(self, compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, self.rhs) + return {"$expr": connection.mongo_operators_expr["isnull"](lhs_mql, self.rhs)} # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index c86b8721..bed64032 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -211,6 +211,7 @@ def _get_reroot_replacements(expression): compiler, connection ) ) + extra_conditions = {"$and": extra_conditions} if extra_conditions else {} lookup_pipeline = [ { "$lookup": { @@ -236,8 +237,8 @@ def _get_reroot_replacements(expression): {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) ] - + extra_conditions - } + }, + **extra_conditions, } } ], @@ -331,7 +332,7 @@ def where_node(self, compiler, connection): raise FullResultSet if self.negated and mql: - mql = {"$not": mql} + mql = {"$nor": mql} return mql diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 4b744241..391c7a41 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -7,7 +7,7 @@ def is_direct_value(node): return not hasattr(node, "as_sql") -def process_lhs(node, compiler, connection): +def process_lhs(node, compiler, connection, **extra): if not hasattr(node, "lhs"): # node is a Func or Expression, possibly with multiple source expressions. result = [] @@ -15,16 +15,16 @@ def process_lhs(node, compiler, connection): if expr is None: continue try: - result.append(expr.as_mql(compiler, connection)) + result.append(expr.as_mql(compiler, connection, **extra)) except FullResultSet: - result.append(Value(True).as_mql(compiler, connection)) + result.append(Value(True).as_mql(compiler, connection, **extra)) if isinstance(node, Aggregate): return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs): return node - return node.lhs.as_mql(compiler, connection) + return node.lhs.as_mql(compiler, connection, **extra) def process_rhs(node, compiler, connection): @@ -47,7 +47,14 @@ def process_rhs(node, compiler, connection): return value -def regex_match(field, regex_vals, insensitive=False): +def regex_expr(field, regex_vals, insensitive=False): regex = {"$concat": regex_vals} if isinstance(regex_vals, tuple) else regex_vals options = "i" if insensitive else "" return {"$regexMatch": {"input": field, "regex": regex, "options": options}} + + +def regex_match(field, regex_vals, insensitive=False): + regex = {"$concat": regex_vals} if isinstance(regex_vals, tuple) else regex_vals + options = "i" if insensitive else "" + # return {"$regexMatch": {"input": field, "regex": regex, "options": options}} + return {field: {"$regex": regex, "$options": options}} From 97eb697558257ff6d5df1a955f4979e6ae8e6432 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 10 Sep 2025 20:47:04 -0300 Subject: [PATCH 2/3] WIP. --- django_mongodb_backend/base.py | 10 +++---- .../expressions/builtins.py | 27 ++++++++++++++----- django_mongodb_backend/lookups.py | 8 ++++-- django_mongodb_backend/query.py | 12 ++++----- django_mongodb_backend/query_utils.py | 3 +-- 5 files changed, 37 insertions(+), 23 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 4699cf94..08afb71f 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -162,11 +162,11 @@ def _isnull_operator_match(a, b): {"$or": [DatabaseWrapper._isnull_operator_match(b[1], True), {a: {"$lte": b[1]}}]}, ] }, - "iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True), - "startswith": lambda a, b: regex_match(a, ("^", b)), - "istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True), - "endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})), - "iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True), + "iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True), + "startswith": lambda a, b: regex_match(a, f"^{b}"), + "istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True), + "endswith": lambda a, b: regex_match(a, f"{b}$"), + "iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True), "contains": lambda a, b: regex_match(a, b), "icontains": lambda a, b: regex_match(a, b, insensitive=True), "regex": lambda a, b: regex_match(a, b), diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 6514ca3a..40c1d628 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -103,7 +103,7 @@ def order_by(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def query(self, compiler, connection, get_wrapping_pipeline=None): +def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -145,6 +145,8 @@ def query(self, compiler, connection, get_wrapping_pipeline=None): # Erase project_fields since the required value is projected above. subquery.project_fields = None compiler.subqueries.append(subquery) + if as_path: + return f"{table_output}.{field_name}" return f"${table_output}.{field_name}" @@ -167,20 +169,31 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001 return f"{prefix}{refs}" -def star(self, compiler, connection): # noqa: ARG001 +def star(self, compiler, connection, **extra): # noqa: ARG001 return {"$literal": True} -def subquery(self, compiler, connection, get_wrapping_pipeline=None): - return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) +def subquery(self, compiler, connection, get_wrapping_pipeline=None, **extra): + return self.query.as_mql( + compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, **extra + ) -def exists(self, compiler, connection, get_wrapping_pipeline=None): +def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, **extra): try: - lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) + lhs_mql = subquery( + self, + compiler, + connection, + get_wrapping_pipeline=get_wrapping_pipeline, + as_path=as_path, + **extra, + ) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) + if as_path: + return connection.mongo_operators_match["isnull"](lhs_mql, False) + return connection.mongo_operators_expr["isnull"](lhs_mql, False) def when(self, compiler, connection, **extra): diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 168a30fa..86c57961 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -1,5 +1,5 @@ from django.db import NotSupportedError -from django.db.models.expressions import Value +from django.db.models.expressions import Col, Ref, Value from django.db.models.fields.related_lookups import In, RelatedIn from django.db.models.lookups import ( BuiltinLookup, @@ -16,9 +16,13 @@ def is_constant_value(value): return is_direct_value(value) or isinstance(value, Value) +def is_simple_column(lhs): + return isinstance(lhs, Col | Ref) + + def builtin_lookup(self, compiler, connection, as_expr=False): value = process_rhs(self, compiler, connection) - if is_constant_value(self.rhs) and not as_expr: + if is_simple_column(self.lhs) and is_constant_value(self.rhs) and not as_expr: lhs_mql = process_lhs(self, compiler, connection, as_path=True) return connection.mongo_operators_match[self.lookup_name](lhs_mql, value) diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index bed64032..a6f233fe 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -11,8 +11,6 @@ from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError -from .query_conversion.query_optimizer import convert_expr_to_match - def wrap_database_errors(func): @wraps(func) @@ -89,7 +87,7 @@ def get_pipeline(self): for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) if self.match_mql: - pipeline.extend(convert_expr_to_match(self.match_mql)) + pipeline.append({"$match": self.match_mql}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) if self.project_fields: @@ -275,7 +273,7 @@ def _get_reroot_replacements(expression): return lookup_pipeline -def where_node(self, compiler, connection): +def where_node(self, compiler, connection, **extra): if self.connector == AND: full_needed, empty_needed = len(self.children), 1 else: @@ -298,14 +296,14 @@ def where_node(self, compiler, connection): if len(self.children) > 2: rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) - return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) + return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection, **extra) else: operator = "$or" children_mql = [] for child in self.children: try: - mql = child.as_mql(compiler, connection) + mql = child.as_mql(compiler, connection, **extra) except EmptyResultSet: empty_needed -= 1 except FullResultSet: @@ -332,7 +330,7 @@ def where_node(self, compiler, connection): raise FullResultSet if self.negated and mql: - mql = {"$nor": mql} + mql = {"$nor": [mql]} return mql diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 391c7a41..77f2bb8a 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -53,8 +53,7 @@ def regex_expr(field, regex_vals, insensitive=False): return {"$regexMatch": {"input": field, "regex": regex, "options": options}} -def regex_match(field, regex_vals, insensitive=False): - regex = {"$concat": regex_vals} if isinstance(regex_vals, tuple) else regex_vals +def regex_match(field, regex, insensitive=False): options = "i" if insensitive else "" # return {"$regexMatch": {"input": field, "regex": regex, "options": options}} return {field: {"$regex": regex, "$options": options}} From 35a2abd96a9b4140977e980a03e243721b050caa Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Thu, 11 Sep 2025 19:34:50 -0300 Subject: [PATCH 3/3] wip --- django_mongodb_backend/compiler.py | 4 ++-- .../expressions/builtins.py | 23 ++++++++++--------- django_mongodb_backend/functions.py | 14 +++++++---- django_mongodb_backend/lookups.py | 6 ++--- django_mongodb_backend/query.py | 3 ++- 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index bd259694..d6436341 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -334,7 +334,7 @@ def pre_sql_setup(self, with_col_aliases=False): pipeline.extend(query.get_pipeline()) # Remove the added subqueries. self.subqueries = [] - pipeline.append({"$match": {"$expr": having}}) + pipeline.append({"$match": having}) self.aggregation_pipeline = pipeline self.annotations = { target: expr.replace_expressions(all_replacements) @@ -707,7 +707,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False # For brevity/simplicity, project {"field_name": 1} # instead of {"field_name": "$field_name"}. if isinstance(expr, Col) and name == expr.target.column and not force_expression - else expr.as_mql(self, self.connection) + else expr.as_mql(self, self.connection, as_expr=force_expression) ) except EmptyResultSet: empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 40c1d628..68356248 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -28,7 +28,8 @@ from ..query_utils import process_lhs -def case(self, compiler, connection): +# EXTRA IS TOTALLY IGNORED +def case(self, compiler, connection, **extra): # noqa: ARG001 case_parts = [] for case in self.cases: case_mql = {} @@ -53,7 +54,7 @@ def case(self, compiler, connection): } -def col(self, compiler, connection, as_path=False): # noqa: ARG001 +def col(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001 # If the column is part of a subquery and belongs to one of the parent # queries, it will be stored for reference using $let in a $lookup stage. # If the query is built with `alias_cols=False`, treat the column as @@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001 # Add the column's collection's alias for columns in joined collections. has_alias = self.alias and self.alias != compiler.collection_name prefix = f"{self.alias}." if has_alias else "" - if not as_path: + if not as_path or as_expr: prefix = f"${prefix}" return f"{prefix}{self.target.column}" @@ -83,16 +84,16 @@ def col_pairs(self, compiler, connection): return cols[0].as_mql(compiler, connection) -def combined_expression(self, compiler, connection): +def combined_expression(self, compiler, connection, **extra): expressions = [ - self.lhs.as_mql(compiler, connection), - self.rhs.as_mql(compiler, connection), + self.lhs.as_mql(compiler, connection, **extra), + self.rhs.as_mql(compiler, connection, **extra), ] return connection.ops.combine_expression(self.connector, expressions) -def expression_wrapper(self, compiler, connection): - return self.expression.as_mql(compiler, connection) +def expression_wrapper(self, compiler, connection, **extra): + return self.expression.as_mql(compiler, connection, **extra) def negated_expression(self, compiler, connection): @@ -103,7 +104,7 @@ def order_by(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): +def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, as_expr=None): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -145,7 +146,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False) # Erase project_fields since the required value is projected above. subquery.project_fields = None compiler.subqueries.append(subquery) - if as_path: + if as_path and not as_expr: return f"{table_output}.{field_name}" return f"${table_output}.{field_name}" @@ -200,7 +201,7 @@ def when(self, compiler, connection, **extra): return self.condition.as_mql(compiler, connection, **extra) -def value(self, compiler, connection): # noqa: ARG001 +def value(self, compiler, connection, **extra): # noqa: ARG001 value = self.value if isinstance(value, (list, int)): # Wrap lists & numbers in $literal to prevent ambiguity when Value diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 3009ef5f..b7009e70 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -65,7 +65,11 @@ } -def cast(self, compiler, connection): +# TODO: ALL THOSE FUNCTION MAY CHECK AS_EXPR OR AS_PATH=FALSE. JUST NEED TO REVIEW ALL THE +# TEST THAT HAVE THOSE OPERATOR. + + +def cast(self, compiler, connection, **extra): # noqa: ARG001 output_type = connection.data_types[self.output_field.get_internal_type()] lhs_mql = process_lhs(self, compiler, connection)[0] if max_length := self.output_field.max_length: @@ -95,7 +99,7 @@ def cot(self, compiler, connection): return {"$divide": [1, {"$tan": lhs_mql}]} -def extract(self, compiler, connection): +def extract(self, compiler, connection, **extra): # noqa: ARG001 lhs_mql = process_lhs(self, compiler, connection) operator = EXTRACT_OPERATORS.get(self.lookup_name) if operator is None: @@ -105,7 +109,7 @@ def extract(self, compiler, connection): return {f"${operator}": lhs_mql} -def func(self, compiler, connection): +def func(self, compiler, connection, **extra): # noqa: ARG001 lhs_mql = process_lhs(self, compiler, connection) if self.function is None: raise NotSupportedError(f"{self} may need an as_mql() method.") @@ -117,7 +121,7 @@ def left(self, compiler, connection): return self.get_substr().as_mql(compiler, connection) -def length(self, compiler, connection): +def length(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001 # Check for null first since $strLenCP only accepts strings. lhs_mql = process_lhs(self, compiler, connection) return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} @@ -194,7 +198,7 @@ def wrapped(self, compiler, connection): return wrapped -def trunc(self, compiler, connection): +def trunc(self, compiler, connection, **extra): # noqa: ARG001 lhs_mql = process_lhs(self, compiler, connection) lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"} if timezone := self.get_tzname(): diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 86c57961..165cde33 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -46,14 +46,14 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection): +def in_(self, compiler, connection, **extra): db_rhs = getattr(self.rhs, "_db", None) if db_rhs is not None and db_rhs != connection.alias: raise ValueError( "Subqueries aren't allowed across different databases. Force " "the inner query to be evaluated using `list(inner_query)`." ) - return builtin_lookup(self, compiler, connection) + return builtin_lookup(self, compiler, connection, **extra) def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 @@ -91,7 +91,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) def is_null(self, compiler, connection, as_expr=False): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - if is_constant_value(self.rhs) and not as_expr: + if is_constant_value(self.rhs) and not as_expr and is_simple_column(self.lhs): lhs_mql = process_lhs(self, compiler, connection, as_path=True) return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs) lhs_mql = process_lhs(self, compiler, connection) diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index a6f233fe..94789e00 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -329,8 +329,9 @@ def where_node(self, compiler, connection, **extra): if not mql: raise FullResultSet + as_expr = extra.get("as_expr") if self.negated and mql: - mql = {"$nor": [mql]} + mql = {"$nor": [mql]} if not as_expr else {"$not": [mql]} return mql