Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
## dbt-databricks 1.11.3 (TBD)

### Fixes

- Truncate (128 characters max) and escape special characters for default query tag values

## dbt-databricks 1.11.2 (Nov 18, 2025)

### Fixes
Expand Down
69 changes: 48 additions & 21 deletions dbt/adapters/databricks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ def parse_query_tags(query_tags_str: Optional[str]) -> dict[str, str]:
except json.JSONDecodeError as e:
raise DbtValidationError(f"Invalid JSON in query_tags: {e}")

@staticmethod
def escape_tag_value(key: str, value: str, source: str = "") -> str:
    """Return ``value`` with backslashes, commas, and colons backslash-escaped.

    A warning naming ``key`` (prefixed with ``source`` when one is given) is
    logged whenever the value contains at least one of those characters.
    Values without them are returned unchanged and nothing is logged.
    """
    if not re.search(r"[\\,:]", value):
        return value

    source_prefix = f"{source}: " if source else ""
    logger.warning(
        f"{source_prefix}Query tag value for key '{key}' contains unescaped "
        f"character(s): {value}. Escaping..."
    )

    # Backslashes are handled first so the ones inserted for commas and
    # colons are not escaped a second time.
    escaped = value
    for char, replacement in (("\\", r"\\"), (",", r"\,"), (":", r"\:")):
        escaped = escaped.replace(char, replacement)
    return escaped

@staticmethod
def validate_query_tags(tags: dict[str, str], source: str = "") -> None:
"""Validate query tags for reserved keys and limits."""
Expand All @@ -156,17 +170,11 @@ def validate_query_tags(tags: dict[str, str], source: str = "") -> None:
f"Reserved keys are: {', '.join(sorted(QueryTagsUtils.RESERVED_KEYS))}"
)

# Escape commas, colons, and backslashes in tag values
# Escape values (modifies dict in place)
for key in tags.keys():
value = tags[key]
if re.search(r"[\\,:]", value):
logger.warning(
f"{source_prefix}Query tag value for key '{key}' contains unescaped "
f"character(s): {value}. Escaping..."
)
tags[key] = value.replace("\\", "\\\\").replace(",", "\\,").replace(":", "\\:")
tags[key] = QueryTagsUtils.escape_tag_value(key, tags[key], source)

# Validate that no tag value exceeds 128 characters
# Validate that no tag value exceeds 128 characters (after escaping)
long_values = {k: v for k, v in tags.items() if len(v) > 128}
if long_values:
raise DbtValidationError(
Expand All @@ -181,6 +189,28 @@ def validate_query_tags(tags: dict[str, str], source: str = "") -> None:
f"Maximum allowed is {QueryTagsUtils.MAX_TAGS}"
)

@staticmethod
def process_default_tags(tags: dict[str, str]) -> dict[str, str]:
    """
    Build a new dict of default tags with over-long values cut to 128
    characters and special characters escaped. ``tags`` is only read.

    Note: truncation is applied to the raw value BEFORE escaping so that an
    escape sequence is never cut in half (which would produce a sequence that
    can't be deserialized). An escaped value may therefore end up longer than
    128 characters.
    """
    processed = {}
    for key in tags:
        value = tags[key]
        if len(value) > 128:
            logger.debug(
                f"Default tags: Query tag value for key '{key}' exceeds 128 characters "
                f"({len(value)} chars). Truncating to 128 characters."
            )
        # Slicing is a no-op for values that are already within the limit.
        processed[key] = QueryTagsUtils.escape_tag_value(key, value[:128], "Default tags")
    return processed

@staticmethod
def merge_query_tags(
connection_tags: dict[str, str],
Expand All @@ -191,23 +221,20 @@ def merge_query_tags(
Merge query tags with precedence: model > connection > default.
Validates that no reserved keys are used and tag limits are respected.
"""
# All sources are now already parsed dicts
conn_tags = connection_tags
model_tags_dict = model_tags
default_tags_dict = default_tags
# Process default tags (escape and truncate, don't validate reserved keys)
processed_default_tags = QueryTagsUtils.process_default_tags(default_tags)

# Validate each source (user-provided tags cannot use reserved keys)
QueryTagsUtils.validate_query_tags(conn_tags, "Connection config")
QueryTagsUtils.validate_query_tags(model_tags_dict, "Model config")
# Validate user-provided tags (cannot use reserved keys)
QueryTagsUtils.validate_query_tags(connection_tags, "Connection config")
QueryTagsUtils.validate_query_tags(model_tags, "Model config")

# Merge with precedence: model > connection > default
merged = {}
merged.update(default_tags_dict)
merged.update(conn_tags)
merged.update(model_tags_dict)
merged.update(processed_default_tags)
merged.update(connection_tags)
merged.update(model_tags)

# Final validation of merged tags (only check total count, not reserved keys
# since default tags are allowed to use reserved keys)
# Final validation of merged tags (only check total count)
if len(merged) > QueryTagsUtils.MAX_TAGS:
raise DbtValidationError(
f"Too many total query tags ({len(merged)}). "
Expand Down
85 changes: 79 additions & 6 deletions tests/unit/test_query_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,25 @@ def test_validate_query_tags_escapes_comma(self):
"""Test that commas in tag values are escaped."""
tags = {"team": "marketing,sales"}
QueryTagsUtils.validate_query_tags(tags)
assert tags["team"] == "marketing\\,sales"
assert tags["team"] == r"marketing\,sales"

def test_validate_query_tags_escapes_colon(self):
"""Test that colons in tag values are escaped."""
tags = {"description": "project:alpha"}
QueryTagsUtils.validate_query_tags(tags)
assert tags["description"] == "project\\:alpha"
assert tags["description"] == r"project\:alpha"

def test_validate_query_tags_escapes_backslash(self):
"""Test that backslashes in tag values are escaped."""
tags = {"path": "folder\\subfolder"}
tags = {"path": r"folder\subfolder"}
QueryTagsUtils.validate_query_tags(tags)
assert tags["path"] == "folder\\\\subfolder"
assert tags["path"] == r"folder\\subfolder"

def test_validate_query_tags_escapes_multiple_special_chars(self):
"""Test that multiple special characters are all escaped."""
tags = {"complex": "value:with,comma\\and\\backslash"}
tags = {"complex": r"value:with,comma\and\backslash"}
QueryTagsUtils.validate_query_tags(tags)
assert tags["complex"] == "value\\:with\\,comma\\\\and\\\\backslash"
assert tags["complex"] == r"value\:with\,comma\\and\\backslash"

def test_validate_query_tags_multiple_values_too_long(self):
tags = {
Expand Down Expand Up @@ -138,6 +138,79 @@ def test_validate_query_tags_value_after_escaping_too_long(self):
with pytest.raises(DbtValidationError, match=expected_msg):
QueryTagsUtils.validate_query_tags(tags)

def test_process_default_tags_escapes_special_chars(self):
    """Colons, commas, and backslashes in default tag values are all escaped."""
    # tag name -> (raw value, expected escaped value)
    cases = {
        "key1": ("value:with:colons", r"value\:with\:colons"),
        "key2": ("value,with,commas", r"value\,with\,commas"),
        "key3": (r"value\with\backslashes", r"value\\with\\backslashes"),
        "key4": (r"path\to:file,v1", r"path\\to\:file\,v1"),
        "key5": (r"a\b:c,d\e:f,g", r"a\\b\:c\,d\\e\:f\,g"),
        "key6": (r"start\,middle:,end", r"start\\\,middle\:\,end"),
    }
    tags = {name: raw for name, (raw, _) in cases.items()}

    result = QueryTagsUtils.process_default_tags(tags)

    for name, (_, expected) in cases.items():
        assert result[name] == expected

def test_process_default_tags_truncates_long_values(self):
    """A value longer than 128 characters comes back cut to exactly 128."""
    result = QueryTagsUtils.process_default_tags({"long_key": "x" * 150})

    truncated = result["long_key"]
    # 150 plain characters in, the first 128 out
    assert len(truncated) == 128
    assert truncated == "x" * 128

def test_process_default_tags_truncates_before_escaping(self):
    """Truncation is applied to the raw value, and escaping happens afterwards."""
    # 126 x's + 3 colons = 129 chars, one over the limit
    padding = "x" * 126

    result = QueryTagsUtils.process_default_tags({"key": padding + ":::"})

    # Truncating to 128 drops one colon; escaping the remaining two adds a
    # backslash each, so the result is 126 + 4 = 130 chars (over 128, but safe).
    escaped = result["key"]
    assert len(escaped) == 130
    assert escaped == (padding + r"\:\:")

def test_process_default_tags_truncation_avoids_broken_escapes(self):
    """A special character sitting at the 128-char boundary keeps a complete escape."""
    # Truncating after escaping could cut "value\," down to "value\",
    # leaving an incomplete escape sequence.
    padding = "x" * 127  # plus the trailing comma: exactly 128 chars

    result = QueryTagsUtils.process_default_tags({"key": padding + ","})

    # Nothing is truncated; the comma is escaped, giving 127 + 2 = 129 chars.
    escaped = result["key"]
    assert len(escaped) == 129
    assert escaped == (padding + r"\,")

def test_process_default_tags_allows_reserved_keys(self):
    """Reserved keys pass through process_default_tags without raising."""
    expected = {
        "@@dbt_model_name": "test_model",
        "@@dbt_core_version": "1.5.0",
        "custom_tag": "value",
    }

    # Hand over a copy so the expectations stay independent of the input dict.
    result = QueryTagsUtils.process_default_tags(dict(expected))

    for name, value in expected.items():
        assert result[name] == value

def test_merge_query_tags_precedence(self):
"""Test that model tags override connection tags."""
connection_tags = {"team": "marketing", "cost_center": "3000"}
Expand Down
Loading