diff --git a/CHANGELOG.md b/CHANGELOG.md index fd948e20c..84f58dbc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## dbt-databricks 1.11.3 (TBD) +### Fixes + +- Truncate (128 characters max) and escape special characters for default query tag values + ## dbt-databricks 1.11.2 (Nov 18, 2025) ### Fixes diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 2abaf104b..1eb5b918d 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -142,6 +142,20 @@ def parse_query_tags(query_tags_str: Optional[str]) -> dict[str, str]: except json.JSONDecodeError as e: raise DbtValidationError(f"Invalid JSON in query_tags: {e}") + @staticmethod + def escape_tag_value(key: str, value: str, source: str = "") -> str: + """Escape special characters in tag values (backslash, comma, colon).""" + source_prefix = f"{source}: " if source else "" + + if re.search(r"[\\,:]", value): + logger.warning( + f"{source_prefix}Query tag value for key '{key}' contains unescaped " + f"character(s): {value}. Escaping..." + ) + value = value.replace("\\", r"\\").replace(",", r"\,").replace(":", r"\:") + + return value + @staticmethod def validate_query_tags(tags: dict[str, str], source: str = "") -> None: """Validate query tags for reserved keys and limits.""" @@ -156,17 +170,11 @@ def validate_query_tags(tags: dict[str, str], source: str = "") -> None: f"Reserved keys are: {', '.join(sorted(QueryTagsUtils.RESERVED_KEYS))}" ) - # Escape commas, colons, and backslashes in tag values + # Escape values (modifies dict in place) for key in tags.keys(): - value = tags[key] - if re.search(r"[\\,:]", value): - logger.warning( - f"{source_prefix}Query tag value for key '{key}' contains unescaped " - f"character(s): {value}. Escaping..." 
- ) - tags[key] = value.replace("\\", "\\\\").replace(",", "\\,").replace(":", "\\:") + tags[key] = QueryTagsUtils.escape_tag_value(key, tags[key], source) - # Validate that no tag value exceeds 128 characters + # Validate that no tag value exceeds 128 characters (after escaping) long_values = {k: v for k, v in tags.items() if len(v) > 128} if long_values: raise DbtValidationError( @@ -181,6 +189,28 @@ def validate_query_tags(tags: dict[str, str], source: str = "") -> None: f"Maximum allowed is {QueryTagsUtils.MAX_TAGS}" ) + @staticmethod + def process_default_tags(tags: dict[str, str]) -> dict[str, str]: + """ + Process default tags: truncate long values, then escape special characters. + + Note: We truncate BEFORE escaping to avoid cutting escape sequences in half, + which would create invalid sequences that can't be deserialized. + """ + processed = {} + for key, value in tags.items(): + if len(value) > 128: + logger.debug( + f"Default tags: Query tag value for key '{key}' exceeds 128 characters " + f"({len(value)} chars). Truncating to 128 characters." + ) + value = value[:128] + + escaped_value = QueryTagsUtils.escape_tag_value(key, value, "Default tags") + processed[key] = escaped_value + + return processed + @staticmethod def merge_query_tags( connection_tags: dict[str, str], @@ -191,23 +221,20 @@ def merge_query_tags( Merge query tags with precedence: model > connection > default. Validates that no reserved keys are used and tag limits are respected. 
""" - # All sources are now already parsed dicts - conn_tags = connection_tags - model_tags_dict = model_tags - default_tags_dict = default_tags + # Process default tags (escape and truncate, don't validate reserved keys) + processed_default_tags = QueryTagsUtils.process_default_tags(default_tags) - # Validate each source (user-provided tags cannot use reserved keys) - QueryTagsUtils.validate_query_tags(conn_tags, "Connection config") - QueryTagsUtils.validate_query_tags(model_tags_dict, "Model config") + # Validate user-provided tags (cannot use reserved keys) + QueryTagsUtils.validate_query_tags(connection_tags, "Connection config") + QueryTagsUtils.validate_query_tags(model_tags, "Model config") # Merge with precedence: model > connection > default merged = {} - merged.update(default_tags_dict) - merged.update(conn_tags) - merged.update(model_tags_dict) + merged.update(processed_default_tags) + merged.update(connection_tags) + merged.update(model_tags) - # Final validation of merged tags (only check total count, not reserved keys - # since default tags are allowed to use reserved keys) + # Final validation of merged tags (only check total count) if len(merged) > QueryTagsUtils.MAX_TAGS: raise DbtValidationError( f"Too many total query tags ({len(merged)}). 
" diff --git a/tests/unit/test_query_tags.py b/tests/unit/test_query_tags.py index 58e4ae3d9..a74f04fbd 100644 --- a/tests/unit/test_query_tags.py +++ b/tests/unit/test_query_tags.py @@ -88,25 +88,25 @@ def test_validate_query_tags_escapes_comma(self): """Test that commas in tag values are escaped.""" tags = {"team": "marketing,sales"} QueryTagsUtils.validate_query_tags(tags) - assert tags["team"] == "marketing\\,sales" + assert tags["team"] == r"marketing\,sales" def test_validate_query_tags_escapes_colon(self): """Test that colons in tag values are escaped.""" tags = {"description": "project:alpha"} QueryTagsUtils.validate_query_tags(tags) - assert tags["description"] == "project\\:alpha" + assert tags["description"] == r"project\:alpha" def test_validate_query_tags_escapes_backslash(self): """Test that backslashes in tag values are escaped.""" - tags = {"path": "folder\\subfolder"} + tags = {"path": r"folder\subfolder"} QueryTagsUtils.validate_query_tags(tags) - assert tags["path"] == "folder\\\\subfolder" + assert tags["path"] == r"folder\\subfolder" def test_validate_query_tags_escapes_multiple_special_chars(self): """Test that multiple special characters are all escaped.""" - tags = {"complex": "value:with,comma\\and\\backslash"} + tags = {"complex": r"value:with,comma\and\backslash"} QueryTagsUtils.validate_query_tags(tags) - assert tags["complex"] == "value\\:with\\,comma\\\\and\\\\backslash" + assert tags["complex"] == r"value\:with\,comma\\and\\backslash" def test_validate_query_tags_multiple_values_too_long(self): tags = { @@ -138,6 +138,79 @@ def test_validate_query_tags_value_after_escaping_too_long(self): with pytest.raises(DbtValidationError, match=expected_msg): QueryTagsUtils.validate_query_tags(tags) + def test_process_default_tags_escapes_special_chars(self): + """Test that process_default_tags escapes special characters.""" + tags = { + "key1": "value:with:colons", + "key2": "value,with,commas", + "key3": r"value\with\backslashes", + "key4": 
r"path\to:file,v1", + "key5": r"a\b:c,d\e:f,g", + "key6": r"start\,middle:,end", + } + result = QueryTagsUtils.process_default_tags(tags) + + assert result["key1"] == r"value\:with\:colons" + assert result["key2"] == r"value\,with\,commas" + assert result["key3"] == r"value\\with\\backslashes" + assert result["key4"] == r"path\\to\:file\,v1" + assert result["key5"] == r"a\\b\:c\,d\\e\:f\,g" + assert result["key6"] == r"start\\\,middle\:\,end" + + def test_process_default_tags_truncates_long_values(self): + """Test that process_default_tags truncates values exceeding 128 characters.""" + long_value = "x" * 150 + tags = {"long_key": long_value} + + result = QueryTagsUtils.process_default_tags(tags) + + # Should be truncated to 128 characters + assert len(result["long_key"]) == 128 + assert result["long_key"] == "x" * 128 + + def test_process_default_tags_truncates_before_escaping(self): + """Test that truncation happens before escaping to avoid cutting escape sequences.""" + # Create a value longer than 128 chars that contains special characters + # 126 x's + 3 colons = 129 chars (exceeds limit) + value = "x" * 126 + ":::" + tags = {"key": value} + + result = QueryTagsUtils.process_default_tags(tags) + + # Should truncate to 128 first (removing 1 char): "xxx...xxx::" + # Then escape the remaining colons: "xxx...xxx\:\:" + # Result: 126 x's + 4 chars from escaped colons = 130 chars (longer than 128, but safe) + assert len(result["key"]) == 130 + assert result["key"] == ("x" * 126 + r"\:\:") + + def test_process_default_tags_truncation_avoids_broken_escapes(self): + """Test that truncating before escaping avoids creating invalid escape sequences.""" + # If we truncated after escaping, we could cut "value\," to "value\" + # which would be an invalid/incomplete escape sequence + value = "x" * 127 + "," # 128 chars exactly + tags = {"key": value} + + result = QueryTagsUtils.process_default_tags(tags) + + # Should keep all 128 chars, then escape the comma: 127 x's + r"\," 
(2 chars) = 129 chars + assert len(result["key"]) == 129 + assert result["key"] == ("x" * 127 + r"\,") + + def test_process_default_tags_allows_reserved_keys(self): + """Test that process_default_tags allows reserved keys (unlike validate_query_tags).""" + tags = { + "@@dbt_model_name": "test_model", + "@@dbt_core_version": "1.5.0", + "custom_tag": "value", + } + + # Should not raise error even with reserved keys + result = QueryTagsUtils.process_default_tags(tags) + + assert result["@@dbt_model_name"] == "test_model" + assert result["@@dbt_core_version"] == "1.5.0" + assert result["custom_tag"] == "value" + def test_merge_query_tags_precedence(self): """Test that model tags override connection tags.""" connection_tags = {"team": "marketing", "cost_center": "3000"}