diff --git a/django_mongodb_backend/query_conversion/expression_converters.py b/django_mongodb_backend/query_conversion/expression_converters.py index b8362f3e..9972d23a 100644 --- a/django_mongodb_backend/query_conversion/expression_converters.py +++ b/django_mongodb_backend/query_conversion/expression_converters.py @@ -5,6 +5,45 @@ class BaseConverter: def convert(cls, expr): raise NotImplementedError("Subclasses must implement this method.") + @classmethod + def is_simple_path_name(cls, field_name): + return ( + isinstance(field_name, str) + and field_name != "" + and field_name.startswith("$") + # Case for catching variables + and not field_name.startswith("$$") + ) + + @classmethod + def is_simple_get_field(cls, get_field_object): + if not isinstance(get_field_object, dict): + return False + + get_field_expr = get_field_object.get("$getField") + + if ( + isinstance(get_field_expr, dict) + and "input" in get_field_expr + and "field" in get_field_expr + ): + input_expr = get_field_expr["input"] + field_name = get_field_expr["field"] + return cls.convert_path_name(input_expr) and ( + isinstance(field_name, str) and "$" not in field_name and "." not in field_name + ) + return False + + @classmethod + def convert_path_name(cls, field_name): + if cls.is_simple_path_name(field_name): + return field_name[1:] + if cls.is_simple_get_field(field_name): + get_field_input = field_name["$getField"]["input"] + get_field_field_name = field_name["$getField"]["field"] + return f"{cls.convert_path_name(get_field_input)}.{get_field_field_name}" + return None + @classmethod def is_simple_value(cls, value): """Is the value is a simple type (not a dict)?""" @@ -14,7 +53,6 @@ def is_simple_value(cls, value): return False if isinstance(value, (list, tuple, set)): return all(cls.is_simple_value(v) for v in value) - # TODO: Support `$getField` conversion. return not isinstance(value, dict) @@ -27,7 +65,7 @@ class BinaryConverter(BaseConverter): {"$gt": ["$price", 100]} } is converted to: - {"$gt": ["price", 100]} + {"price": {"$gt": 100}} """ operator: str @@ -37,15 +75,14 @@ def convert(cls, args): if isinstance(args, list) and len(args) == 2: field_expr, value = args # Check if first argument is a simple field reference. - if ( - isinstance(field_expr, str) - and field_expr.startswith("$") - and cls.is_simple_value(value) - ): - field_name = field_expr[1:] # Remove the $ prefix. + if (field_name := cls.convert_path_name(field_expr)) and cls.is_simple_value(value): if cls.operator == "$eq": - return {field_name: value} - return {field_name: {cls.operator: value}} + query = {field_name: value} + else: + query = {field_name: {cls.operator: value}} + if value is None: + query = {"$and": [{field_name: {"$exists": True}}, query]} + return query return None @@ -96,13 +133,18 @@ class InConverter(BaseConverter): def convert(cls, in_args): if isinstance(in_args, list) and len(in_args) == 2: field_expr, values = in_args - # Check if first argument is a simple field reference. - if isinstance(field_expr, str) and field_expr.startswith("$"): - field_name = field_expr[1:] # Remove the $ prefix. - if isinstance(values, (list, tuple, set)) and all( - cls.is_simple_value(v) for v in values - ): - return {field_name: {"$in": values}} + # Check if first argument is a simple field reference + # Check if second argument is a list of simple values + if (field_name := cls.convert_path_name(field_expr)) and ( + isinstance(values, list | tuple | set) + and all(cls.is_simple_value(v) for v in values) + ): + core_check = {field_name: {"$in": values}} + return ( + {"$and": [{field_name: {"$exists": True}}, core_check]} + if None in values + else core_check + ) return None diff --git a/tests/expression_converter_/test_match_conversion.py b/tests/expression_converter_/test_match_conversion.py index e78e5c0c..47cc5011 100644 --- a/tests/expression_converter_/test_match_conversion.py +++ b/tests/expression_converter_/test_match_conversion.py @@ -213,3 +213,49 @@ def test_deeply_nested_logical_operator_with_variable(self): } ] self.assertOptimizerEqual(expr, expected) + + def test_getfield_usage_on_dual_binary_operator(self): + expr = { + "$expr": { + "$gt": [ + {"$getField": {"input": "$price", "field": "value"}}, + {"$getField": {"input": "$discounted_price", "field": "value"}}, + ] + } + } + expected = [ + { + "$match": { + "$expr": { + "$gt": [ + {"$getField": {"input": "$price", "field": "value"}}, + {"$getField": {"input": "$discounted_price", "field": "value"}}, + ] + } + } + } + ] + self.assertOptimizerEqual(expr, expected) + + def test_getfield_usage_on_onesided_binary_operator(self): + expr = {"$expr": {"$gt": [{"$getField": {"input": "$price", "field": "value"}}, 100]}} + # This should create a proper match condition with no $expr + expected = [{"$match": {"price.value": {"$gt": 100}}}] + self.assertOptimizerEqual(expr, expected) + + def test_nested_getfield_usage_on_onesided_binary(self): + expr = { + "$expr": { + "$gt": [ + { + "$getField": { + "input": {"$getField": {"input": "$item", "field": "price"}}, + "field": "value", + } + }, + 100, + ] + } + } + expected = [{"$match": {"item.price.value": {"$gt": 100}}}] + self.assertOptimizerEqual(expr, expected) diff --git a/tests/expression_converter_/test_op_expressions.py b/tests/expression_converter_/test_op_expressions.py index ce4caf2d..80624834 100644 --- a/tests/expression_converter_/test_op_expressions.py +++ b/tests/expression_converter_/test_op_expressions.py @@ -7,6 +7,12 @@ from django_mongodb_backend.query_conversion.expression_converters import convert_expression +def _wrap_condition_if_null(_type, condition, path): + if _type is None: + return {"$and": [{path: {"$exists": True}}, condition]} + return condition + + class ConversionTestCase(SimpleTestCase): CONVERTIBLE_TYPES = { "int": 42, @@ -33,6 +39,50 @@ def _test_conversion_various_types(self, conversion_test): with self.subTest(_type=_type, val=val): conversion_test(val) + def _test_conversion_getfield(self, logical_op, value=10): + expr = {logical_op: [{"$getField": {"input": "$item", "field": "age"}}, value]} + self.assertConversionEqual( + expr, {"item.age": value} if logical_op == "$eq" else {"item.age": {logical_op: value}} + ) + + def _test_conversion_nested_getfield(self, logical_op, value=10): + expr = { + logical_op: [ + { + "$getField": { + "input": {"$getField": {"input": "$item", "field": "shelf_life"}}, + "field": "age", + } + }, + value, + ] + } + self.assertConversionEqual( + expr, + {"item.shelf_life.age": value} + if logical_op == "$eq" + else {"item.shelf_life.age": {logical_op: value}}, + ) + + def _test_conversion_dual_getfield_ineligible(self, logical_op): + expr = { + logical_op: [ + { + "$getField": { + "input": "$root", + "field": "age", + } + }, + { + "$getField": { + "input": "$value", + "field": "age", + } + }, + ] + } + self.assertNotOptimizable(expr) + class ExpressionTests(ConversionTestCase): def test_non_dict(self): @@ -53,10 +103,14 @@ def test_no_conversion_dict_value(self): self.assertNotOptimizable({"$eq": ["$status", {"$gt": 5}]}) def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) + self.assertConversionEqual( + {"$eq": ["$age", _type]}, _wrap_condition_if_null(_type, {"age": _type}, "age") + ) def _test_conversion_valid_array_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) + self.assertConversionEqual( + {"$eq": ["$age", _type]}, _wrap_condition_if_null(_type, {"age": _type}, "age") + ) def test_conversion_various_types(self): self._test_conversion_various_types(self._test_conversion_valid_type) @@ -64,6 +118,15 @@ def test_conversion_various_types(self): def test_conversion_various_array_types(self): self._test_conversion_various_types(self._test_conversion_valid_array_type) + def test_conversion_getfield(self): + self._test_conversion_getfield("$eq") + + def test_conversion_nested_getfield(self): + self._test_conversion_nested_getfield("$eq") + + def test_conversion_dual_getfield_ineligible(self): + self._test_conversion_dual_getfield_ineligible("$eq") + class InTests(ConversionTestCase): def test_conversion(self): @@ -78,13 +141,43 @@ def test_no_conversion_dict_value(self): self.assertNotOptimizable({"$in": ["$status", [{"bad": "val"}]]}) def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$in": ["$age", [_type]]}, {"age": {"$in": [_type]}}) + self.assertConversionEqual( + {"$in": ["$age", [_type]]}, + _wrap_condition_if_null(_type, {"age": {"$in": [_type]}}, "age"), + ) def test_conversion_various_types(self): for _type, val in self.CONVERTIBLE_TYPES.items(): with self.subTest(_type=_type, val=val): self._test_conversion_valid_type(val) + def test_conversion_getfield(self): + self._test_conversion_getfield("$in", [10]) + + def test_conversion_nested_getfield(self): + self._test_conversion_nested_getfield("$in", [10]) + + def test_conversion_dual_getfield_ineligible(self): + expr = { + "$in": [ + { + "$getField": { + "input": "$root", + "field": "age", + } + }, + [ + { + "$getField": { + "input": "$value", + "field": "age", + } + } + ], + ] + } + self.assertNotOptimizable(expr) + class LogicalTests(ConversionTestCase): def test_and(self): @@ -146,6 +239,7 @@ def test_mixed(self): {"$in": ["$category", ["electronics", "books"]]}, {"$eq": ["$verified", True]}, {"$lte": ["$price", 2000]}, + {"$eq": [{"$getField": {"input": "$root", "field": "age"}}, 10]}, ] } expected = { @@ -154,6 +248,7 @@ def test_mixed(self): {"category": {"$in": ["electronics", "books"]}}, {"verified": True}, {"price": {"$lte": 2000}}, + {"root.age": 10}, ] } self.assertConversionEqual(expr, expected) @@ -170,11 +265,23 @@ def test_no_conversion_dict_value(self): self.assertNotOptimizable({"$gt": ["$price", {}]}) def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$gt": ["$price", _type]}, {"price": {"$gt": _type}}) + self.assertConversionEqual( + {"$gt": ["$price", _type]}, + _wrap_condition_if_null(_type, {"price": {"$gt": _type}}, "price"), + ) def test_conversion_various_types(self): self._test_conversion_various_types(self._test_conversion_valid_type) + def test_conversion_getfield(self): + self._test_conversion_getfield("$gt") + + def test_conversion_nested_getfield(self): + self._test_conversion_nested_getfield("$gt") + + def test_conversion_dual_getfield_ineligible(self): + self._test_conversion_dual_getfield_ineligible("$gt") + class GteTests(ConversionTestCase): def test_conversion(self): @@ -193,11 +300,20 @@ def test_no_conversion_dict_value(self): def _test_conversion_valid_type(self, _type): expr = {"$gte": ["$price", _type]} expected = {"price": {"$gte": _type}} - self.assertConversionEqual(expr, expected) + self.assertConversionEqual(expr, _wrap_condition_if_null(_type, expected, "price")) def test_conversion_various_types(self): self._test_conversion_various_types(self._test_conversion_valid_type) + def test_conversion_getfield(self): + self._test_conversion_getfield("$gte") + + def test_conversion_nested_getfield(self): + self._test_conversion_nested_getfield("$gte") + + def test_conversion_dual_getfield_ineligible(self): + self._test_conversion_dual_getfield_ineligible("$gte") + class LtTests(ConversionTestCase): def test_conversion(self): @@ -210,11 +326,23 @@ def test_no_conversion_dict_value(self): self.assertNotOptimizable({"$lt": ["$price", {}]}) def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lt": ["$price", _type]}, {"price": {"$lt": _type}}) + self.assertConversionEqual( + {"$lt": ["$price", _type]}, + _wrap_condition_if_null(_type, {"price": {"$lt": _type}}, "price"), + ) def test_conversion_various_types(self): self._test_conversion_various_types(self._test_conversion_valid_type) + def test_conversion_getfield(self): + self._test_conversion_getfield("$lt") + + def test_conversion_nested_getfield(self): + self._test_conversion_nested_getfield("$lt") + + def test_conversion_dual_getfield_ineligible(self): + self._test_conversion_dual_getfield_ineligible("$lt") + class LteTests(ConversionTestCase): def test_conversion(self): @@ -227,7 +355,19 @@ def test_no_conversion_dict_value(self): self.assertNotOptimizable({"$lte": ["$price", {}]}) def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lte": ["$price", _type]}, {"price": {"$lte": _type}}) + self.assertConversionEqual( + {"$lte": ["$price", _type]}, + _wrap_condition_if_null(_type, {"price": {"$lte": _type}}, "price"), + ) def test_conversion_various_types(self): self._test_conversion_various_types(self._test_conversion_valid_type) + + def test_conversion_getfield(self): + self._test_conversion_getfield("$lte") + + def test_conversion_nested_getfield(self): + self._test_conversion_nested_getfield("$lte") + + def test_conversion_dual_getfield_ineligible(self): + self._test_conversion_dual_getfield_ineligible("$lte")