Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions django_mongodb_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .query_utils import regex_match
from .query_utils import regex_expr, regex_match
from .schema import DatabaseSchemaEditor
from .utils import OperationDebugWrapper
from .validation import DatabaseValidation
Expand Down Expand Up @@ -108,7 +108,12 @@ def _isnull_operator(a, b):
}
return is_null if b else {"$not": is_null}

mongo_operators = {
def _isnull_operator_match(a, b):
if b:
return {"$or": [{a: {"$exists": False}}, {a: None}]}
return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]}

mongo_operators_expr = {
"exact": lambda a, b: {"$eq": [a, b]},
"gt": lambda a, b: {"$gt": [a, b]},
"gte": lambda a, b: {"$gte": [a, b]},
Expand All @@ -126,11 +131,42 @@ def _isnull_operator(a, b):
{"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]},
]
},
"iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True),
"startswith": lambda a, b: regex_match(a, ("^", b)),
"istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True),
"endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})),
"iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True),
"iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True),
"startswith": lambda a, b: regex_expr(a, ("^", b)),
"istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True),
"endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})),
"iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True),
"contains": lambda a, b: regex_expr(a, b),
"icontains": lambda a, b: regex_expr(a, b, insensitive=True),
"regex": lambda a, b: regex_expr(a, b),
"iregex": lambda a, b: regex_expr(a, b, insensitive=True),
}

mongo_operators_match = {
"exact": lambda a, b: {a: b},
"gt": lambda a, b: {a: {"$gt": b}},
"gte": lambda a, b: {a: {"$gte": b}},
# MongoDB considers null less than zero. Exclude null values to match
# SQL behavior.
"lt": lambda a, b: {
"$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
},
"lte": lambda a, b: {
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
},
"in": lambda a, b: {a: {"$in": list(b)}},
"isnull": _isnull_operator_match,
"range": lambda a, b: {
"$and": [
{"$or": [DatabaseWrapper._isnull_operator_match(b[0], True), {a: {"$gte": b[0]}}]},
{"$or": [DatabaseWrapper._isnull_operator_match(b[1], True), {a: {"$lte": b[1]}}]},
]
},
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),
"startswith": lambda a, b: regex_match(a, f"^{b}"),
"istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True),
"endswith": lambda a, b: regex_match(a, f"{b}$"),
"iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True),
"contains": lambda a, b: regex_match(a, b),
"icontains": lambda a, b: regex_match(a, b, insensitive=True),
"regex": lambda a, b: regex_match(a, b),
Expand Down
6 changes: 3 additions & 3 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def pre_sql_setup(self, with_col_aliases=False):
pipeline.extend(query.get_pipeline())
# Remove the added subqueries.
self.subqueries = []
pipeline.append({"$match": {"$expr": having}})
pipeline.append({"$match": having})
self.aggregation_pipeline = pipeline
self.annotations = {
target: expr.replace_expressions(all_replacements)
Expand Down Expand Up @@ -485,7 +485,7 @@ def build_query(self, columns=None):
except FullResultSet:
query.match_mql = {}
else:
query.match_mql = {"$expr": expr}
query.match_mql = expr
if extra_fields:
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
query.subqueries = self.subqueries
Expand Down Expand Up @@ -707,7 +707,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
# For brevity/simplicity, project {"field_name": 1}
# instead of {"field_name": "$field_name"}.
if isinstance(expr, Col) and name == expr.target.column and not force_expression
else expr.as_mql(self, self.connection)
else expr.as_mql(self, self.connection, as_expr=force_expression)
)
except EmptyResultSet:
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
Expand Down
58 changes: 37 additions & 21 deletions django_mongodb_backend/expressions/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
from ..query_utils import process_lhs


def case(self, compiler, connection):
# EXTRA IS TOTALLY IGNORED
def case(self, compiler, connection, **extra): # noqa: ARG001
case_parts = []
for case in self.cases:
case_mql = {}
try:
case_mql["case"] = case.as_mql(compiler, connection)
case_mql["case"] = case.as_mql(compiler, connection, as_expr=True)
except EmptyResultSet:
continue
except FullResultSet:
Expand All @@ -53,7 +54,7 @@ def case(self, compiler, connection):
}


def col(self, compiler, connection, as_path=False): # noqa: ARG001
def col(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
# If the query is built with `alias_cols=False`, treat the column as
Expand All @@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
# Add the column's collection's alias for columns in joined collections.
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
if not as_path:
if not as_path or as_expr:
prefix = f"${prefix}"
return f"{prefix}{self.target.column}"

Expand All @@ -83,16 +84,16 @@ def col_pairs(self, compiler, connection):
return cols[0].as_mql(compiler, connection)


def combined_expression(self, compiler, connection):
def combined_expression(self, compiler, connection, **extra):
expressions = [
self.lhs.as_mql(compiler, connection),
self.rhs.as_mql(compiler, connection),
self.lhs.as_mql(compiler, connection, **extra),
self.rhs.as_mql(compiler, connection, **extra),
]
return connection.ops.combine_expression(self.connector, expressions)


def expression_wrapper(self, compiler, connection):
return self.expression.as_mql(compiler, connection)
def expression_wrapper(self, compiler, connection, **extra):
return self.expression.as_mql(compiler, connection, **extra)


def negated_expression(self, compiler, connection):
Expand All @@ -103,7 +104,7 @@ def order_by(self, compiler, connection):
return self.expression.as_mql(compiler, connection)


def query(self, compiler, connection, get_wrapping_pipeline=None):
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, as_expr=None):
subquery_compiler = self.get_compiler(connection=connection)
subquery_compiler.pre_sql_setup(with_col_aliases=False)
field_name, expr = subquery_compiler.columns[0]
Expand Down Expand Up @@ -145,14 +146,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
# Erase project_fields since the required value is projected above.
subquery.project_fields = None
compiler.subqueries.append(subquery)
if as_path and not as_expr:
return f"{table_output}.{field_name}"
return f"${table_output}.{field_name}"


def raw_sql(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("RawSQL is not supported on MongoDB.")


def ref(self, compiler, connection): # noqa: ARG001
def ref(self, compiler, connection, as_path=False): # noqa: ARG001
prefix = (
f"{self.source.alias}."
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
Expand All @@ -162,30 +165,43 @@ def ref(self, compiler, connection): # noqa: ARG001
refs, _ = compiler.columns[self.ordinal - 1]
else:
refs = self.refs
return f"${prefix}{refs}"
if not as_path:
prefix = f"${prefix}"
return f"{prefix}{refs}"


def star(self, compiler, connection): # noqa: ARG001
def star(self, compiler, connection, **extra): # noqa: ARG001
return {"$literal": True}


def subquery(self, compiler, connection, get_wrapping_pipeline=None):
return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
def subquery(self, compiler, connection, get_wrapping_pipeline=None, **extra):
return self.query.as_mql(
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, **extra
)


def exists(self, compiler, connection, get_wrapping_pipeline=None):
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, **extra):
try:
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
lhs_mql = subquery(
self,
compiler,
connection,
get_wrapping_pipeline=get_wrapping_pipeline,
as_path=as_path,
**extra,
)
except EmptyResultSet:
return Value(False).as_mql(compiler, connection)
return connection.mongo_operators["isnull"](lhs_mql, False)
if as_path:
return connection.mongo_operators_match["isnull"](lhs_mql, False)
return connection.mongo_operators_expr["isnull"](lhs_mql, False)


def when(self, compiler, connection):
return self.condition.as_mql(compiler, connection)
def when(self, compiler, connection, **extra):
return self.condition.as_mql(compiler, connection, **extra)


def value(self, compiler, connection): # noqa: ARG001
def value(self, compiler, connection, **extra): # noqa: ARG001
value = self.value
if isinstance(value, (list, int)):
# Wrap lists & numbers in $literal to prevent ambiguity when Value
Expand Down
26 changes: 18 additions & 8 deletions django_mongodb_backend/fields/json.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import chain

from django.db import NotSupportedError
from django.db.models.fields.json import (
ContainedBy,
Expand All @@ -13,12 +15,14 @@
KeyTransformNumericLookupMixin,
)

from ..lookups import builtin_lookup
from ..lookups import builtin_lookup, is_constant_value
from ..query_utils import process_lhs, process_rhs


def build_json_mql_path(lhs, key_transforms):
def build_json_mql_path(lhs, key_transforms, as_path=False):
# Build the MQL path using the collected key transforms.
if as_path:
return ".".join(chain([lhs], key_transforms))
result = lhs
for key in key_transforms:
get_field = {"$getField": {"input": result, "field": key}}
Expand All @@ -45,8 +49,12 @@ def data_contains(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("contains lookup is not supported on this database backend.")


def _has_key_predicate(path, root_column, negated=False):
def _has_key_predicate(path, root_column, negated=False, as_path=False):
"""Return MQL to check for the existence of `path`."""
if as_path:
if not negated:
return {"$and": [{path: {"$exists": True}}, {path: {"$ne": None}}]}
return {"$or": [{path: {"$exists": False}}, {path: None}]}
result = {
"$and": [
# The path must exist (i.e. not be "missing").
Expand All @@ -64,18 +72,20 @@ def _has_key_predicate(path, root_column, negated=False):
def has_key_lookup(self, compiler, connection):
"""Return MQL to check for the existence of a key."""
rhs = self.rhs
as_path = is_constant_value(rhs)
lhs = process_lhs(self, compiler, connection)
if not isinstance(rhs, (list, tuple)):
rhs = [rhs]
paths = []
# Transform any "raw" keys into KeyTransforms to allow consistent handling
# in the code that follows.

for key in rhs:
rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs)
paths.append(rhs_json_path.as_mql(compiler, connection))
paths.append(rhs_json_path.as_mql(compiler, connection, as_path=as_path))
keys = []
for path in paths:
keys.append(_has_key_predicate(path, lhs))
keys.append(_has_key_predicate(path, lhs, as_path=as_path))
if self.mongo_operator is None:
return keys[0]
return {self.mongo_operator: keys}
Expand All @@ -93,7 +103,7 @@ def json_exact_process_rhs(self, compiler, connection):
)


def key_transform(self, compiler, connection):
def key_transform(self, compiler, connection, **extra):
"""
Return MQL for this KeyTransform (JSON path).

Expand All @@ -108,8 +118,8 @@ def key_transform(self, compiler, connection):
while isinstance(previous, KeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs_mql = previous.as_mql(compiler, connection)
return build_json_mql_path(lhs_mql, key_transforms)
lhs_mql = previous.as_mql(compiler, connection, **extra)
return build_json_mql_path(lhs_mql, key_transforms, **extra)


def key_transform_in(self, compiler, connection):
Expand Down
24 changes: 15 additions & 9 deletions django_mongodb_backend/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@
}


def cast(self, compiler, connection):
# TODO: ALL THOSE FUNCTION MAY CHECK AS_EXPR OR AS_PATH=FALSE. JUST NEED TO REVIEW ALL THE
# TEST THAT HAVE THOSE OPERATOR.


def cast(self, compiler, connection, **extra): # noqa: ARG001
output_type = connection.data_types[self.output_field.get_internal_type()]
lhs_mql = process_lhs(self, compiler, connection)[0]
if max_length := self.output_field.max_length:
Expand Down Expand Up @@ -95,7 +99,7 @@ def cot(self, compiler, connection):
return {"$divide": [1, {"$tan": lhs_mql}]}


def extract(self, compiler, connection):
def extract(self, compiler, connection, **extra): # noqa: ARG001
lhs_mql = process_lhs(self, compiler, connection)
operator = EXTRACT_OPERATORS.get(self.lookup_name)
if operator is None:
Expand All @@ -105,7 +109,7 @@ def extract(self, compiler, connection):
return {f"${operator}": lhs_mql}


def func(self, compiler, connection):
def func(self, compiler, connection, **extra): # noqa: ARG001
lhs_mql = process_lhs(self, compiler, connection)
if self.function is None:
raise NotSupportedError(f"{self} may need an as_mql() method.")
Expand All @@ -117,7 +121,7 @@ def left(self, compiler, connection):
return self.get_substr().as_mql(compiler, connection)


def length(self, compiler, connection):
def length(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001
# Check for null first since $strLenCP only accepts strings.
lhs_mql = process_lhs(self, compiler, connection)
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}}
Expand Down Expand Up @@ -146,10 +150,12 @@ def preserve_null(operator):
def wrapped(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
return {
"$cond": {
"if": connection.mongo_operators["isnull"](lhs_mql, True),
"then": None,
"else": {f"${operator}": lhs_mql},
"$expr": {
"$cond": {
"if": connection.mongo_operators_expr["isnull"](lhs_mql, True),
"then": None,
"else": {f"${operator}": lhs_mql},
}
}
}

Expand Down Expand Up @@ -192,7 +198,7 @@ def wrapped(self, compiler, connection):
return wrapped


def trunc(self, compiler, connection):
def trunc(self, compiler, connection, **extra): # noqa: ARG001
lhs_mql = process_lhs(self, compiler, connection)
lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"}
if timezone := self.get_tzname():
Expand Down
Loading
Loading