Skip to content

Commit 3838fb0

Browse files
SNOW-2176203: Support intermediates in trust stores (#2520)
Our current python connector does not support a configuration with intermediate certificates in a trust store as roots of trust. This is allowed by RFCs; the root does not need to be self signed, and this is the behavior we have in our Go client. These changes enable partial chain validation to normalize behavior across clients.
1 parent 249195a commit 3838fb0

File tree

5 files changed

+340
-36
lines changed

5 files changed

+340
-36
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1111
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only
1212
- Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once
1313
- Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body
14+
- Added support for intermediate certificates as roots when they are stored in the trust store
1415

1516
- v3.17.3(September 02,2025)
1617
- Enhanced configuration file permission warning messages.

src/snowflake/connector/connection_diagnostic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ def __test_socket_get_cert(
240240

241241
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
242242
context.load_verify_locations(certifi.where())
243+
# Best-effort: enable partial-chain when supported
244+
_partial_flag = getattr(ssl, "VERIFY_X509_PARTIAL_CHAIN", 0)
245+
if _partial_flag and hasattr(context, "verify_flags"):
246+
context.verify_flags |= _partial_flag
243247
sock = context.wrap_socket(conn, server_hostname=host)
244248
certificate = ssl.DER_cert_to_PEM_cert(sock.getpeercert(True))
245249
http_request = f"""GET / {host}:{port} HTTP/1.1\r\n

src/snowflake/connector/ssl_wrap_socket.py

Lines changed: 78 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
#
99
# and added OCSP validator on the top.
1010
import logging
11+
import os
12+
import ssl
1113
import time
1214
import weakref
1315
from contextvars import ContextVar
1416
from functools import wraps
15-
from inspect import getfullargspec as get_args
17+
from inspect import signature as _sig
1618
from socket import socket
1719
from typing import Any
1820

@@ -38,9 +40,57 @@
3840
log = logging.getLogger(__name__)
3941

4042

43+
# Helper utilities (private)
44+
def _resolve_cafile(kwargs: dict[str, Any]) -> str | None:
45+
"""Resolve CA bundle path from kwargs or standard environment variables.
46+
47+
Precedence:
48+
1) kwargs['ca_certs'] if provided by caller
49+
2) REQUESTS_CA_BUNDLE
50+
3) SSL_CERT_FILE
51+
"""
52+
caf = kwargs.get("ca_certs")
53+
if caf:
54+
return caf
55+
return os.environ.get("REQUESTS_CA_BUNDLE") or os.environ.get("SSL_CERT_FILE")
56+
57+
58+
def _ensure_partial_chain_on_context(ctx: PyOpenSSLContext, cafile: str | None) -> None:
59+
"""Load CA bundle (when provided) and enable OpenSSL partial-chain support on ctx."""
60+
if cafile:
61+
try:
62+
ctx.load_verify_locations(cafile=cafile, capath=None)
63+
except (ssl.SSLError, OSError, ValueError):
64+
# Leave context unchanged; handshake/validation surfaces failures
65+
pass
66+
try:
67+
store = ctx._ctx.get_cert_store()
68+
from OpenSSL import crypto as _crypto
69+
70+
if hasattr(_crypto, "X509StoreFlags") and hasattr(
71+
_crypto.X509StoreFlags, "PARTIAL_CHAIN"
72+
):
73+
store.set_flags(_crypto.X509StoreFlags.PARTIAL_CHAIN)
74+
except (AttributeError, ImportError, OpenSSL.SSL.Error, OSError, ValueError):
75+
# Best-effort; if not available, default chain building applies
76+
pass
77+
78+
79+
def _build_context_with_partial_chain(cafile: str | None) -> PyOpenSSLContext:
80+
"""Create PyOpenSSL context configured for CERT_REQUIRED and partial-chain trust."""
81+
ctx = PyOpenSSLContext(ssl_.PROTOCOL_TLS_CLIENT)
82+
try:
83+
ctx.verify_mode = ssl.CERT_REQUIRED
84+
except Exception:
85+
pass
86+
_ensure_partial_chain_on_context(ctx, cafile)
87+
return ctx
88+
89+
4190
# Store a *weak* reference so that the context variable doesn’t prolong the
4291
# lifetime of the SessionManager. Once all owning connections are GC-ed the
43-
# weakref goes dead and OCSP will fall back to its local manager (but most likely won't be used ever again anyway).
92+
# weakref goes dead and OCSP will fall back to its local manager (but most
93+
# likely won't be used ever again anyway).
4494
_CURRENT_SESSION_MANAGER: ContextVar[weakref.ref[SessionManager] | None] = ContextVar(
4595
"_CURRENT_SESSION_MANAGER",
4696
default=None,
@@ -71,7 +121,10 @@ def set_current_session_manager(sm: SessionManager | None) -> Any:
71121
Called from SnowflakeConnection so that OCSP downloads
72122
use the same proxy / header configuration as the initiating connection.
73123
74-
Alternative approach would be moving method inject_into_urllib3() inside connection initialization, but in case this delay (from module import time to connection initialization time) would cause some code to break we stayed with this approach, having in mind soon OCSP deprecation.
124+
Alternative approach would be moving method inject_into_urllib3() inside
125+
connection initialization, but in case this delay (from module import time
126+
to connection initialization time) would cause some code to break we stayed
127+
with this approach, having in mind soon OCSP deprecation.
75128
"""
76129
return _CURRENT_SESSION_MANAGER.set(weakref.ref(sm) if sm is not None else None)
77130

@@ -93,37 +146,29 @@ def inject_into_urllib3() -> None:
93146

94147
@wraps(ssl_.ssl_wrap_socket)
95148
def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
96-
# Extract host_name
97-
hostname_index = get_args(ssl_.ssl_wrap_socket).args.index("server_hostname")
98-
server_hostname = (
99-
args[hostname_index]
100-
if len(args) > hostname_index
101-
else kwargs.get("server_hostname", None)
102-
)
103-
# Remove context if present
104-
ssl_context_index = get_args(ssl_.ssl_wrap_socket).args.index("ssl_context")
105-
context_in_args = len(args) > ssl_context_index
106-
ssl_context = (
107-
args[hostname_index] if context_in_args else kwargs.get("ssl_context", None)
108-
)
109-
if not isinstance(ssl_context, PyOpenSSLContext):
110-
# Create new default context
111-
if context_in_args:
112-
new_args = list(args)
113-
new_args[ssl_context_index] = None
114-
args = tuple(new_args)
115-
else:
116-
del kwargs["ssl_context"]
117-
# Fix ca certs location
118-
ca_certs_index = get_args(ssl_.ssl_wrap_socket).args.index("ca_certs")
119-
ca_certs_in_args = len(args) > ca_certs_index
120-
if not ca_certs_in_args and not kwargs.get("ca_certs"):
121-
kwargs["ca_certs"] = certifi.where()
122-
123-
ret = ssl_.ssl_wrap_socket(*args, **kwargs)
149+
# Bind passed args/kwargs to the underlying signature to support both positional and keyword calls
150+
bound = _sig(ssl_.ssl_wrap_socket).bind_partial(*args, **kwargs)
151+
params = bound.arguments
152+
153+
server_hostname = params.get("server_hostname")
154+
155+
# Ensure CA bundle default if not provided
156+
if not params.get("ca_certs"):
157+
params["ca_certs"] = certifi.where()
158+
159+
# Ensure PyOpenSSL context with partial-chain is used if none or wrong type provided
160+
provided_ctx = params.get("ssl_context")
161+
if not isinstance(provided_ctx, PyOpenSSLContext):
162+
cafile_for_ctx = _resolve_cafile(params)
163+
params["ssl_context"] = _build_context_with_partial_chain(cafile_for_ctx)
164+
else:
165+
# If a PyOpenSSLContext is provided, ensure it trusts the provided CA and partial-chain is enabled
166+
_ensure_partial_chain_on_context(provided_ctx, _resolve_cafile(params))
167+
168+
ret = ssl_.ssl_wrap_socket(**params)
124169

125170
log.debug(
126-
"OCSP Mode: %s, " "OCSP response cache file name: %s",
171+
"OCSP Mode: %s, OCSP response cache file name: %s",
127172
FEATURE_OCSP_MODE.name,
128173
FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
129174
)
@@ -137,10 +182,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
137182
).validate(server_hostname, ret.connection)
138183
if not v:
139184
raise OperationalError(
140-
msg=(
141-
"The certificate is revoked or "
142-
"could not be validated: hostname={}".format(server_hostname)
143-
),
185+
msg=f"The certificate is revoked or could not be validated: hostname={server_hostname}",
144186
errno=ER_OCSP_RESPONSE_CERT_STATUS_REVOKED,
145187
)
146188
else:
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Unit tests for SSL wrapper partial-chain context injection."""
2+
3+
import types
4+
5+
import pytest
6+
7+
import snowflake.connector.ssl_wrap_socket as ssw # pylint: disable=import-error
8+
from snowflake.connector.constants import OCSPMode # pylint: disable=import-error
9+
from snowflake.connector.vendored.urllib3.contrib.pyopenssl import ( # pylint: disable=import-error
10+
PyOpenSSLContext,
11+
)
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def disable_ocsp_checks():
16+
"""Disable OCSP checks for offline unit testing."""
17+
# Ensure wrapper doesn't perform OCSP to keep this unit test offline
18+
orig = ssw.FEATURE_OCSP_MODE
19+
ssw.FEATURE_OCSP_MODE = OCSPMode.DISABLE_OCSP_CHECKS
20+
try:
21+
yield
22+
finally:
23+
ssw.FEATURE_OCSP_MODE = orig
24+
25+
26+
def test_wrapper_injects_pyopenssl_context(monkeypatch):
27+
"""Wrapper should inject a PyOpenSSLContext when none is given."""
28+
captured = {}
29+
30+
def fake_ssl_wrap_socket( # pylint: disable=unused-argument,too-many-arguments,too-many-positional-arguments
31+
sock, ssl_context=None, **kwargs
32+
):
33+
# Assert that our wrapper provided a PyOpenSSLContext
34+
captured["ctx_is_pyopenssl"] = isinstance(ssl_context, PyOpenSSLContext)
35+
# Return a minimal object with a 'connection' attribute expected by wrapper
36+
return types.SimpleNamespace(connection=None)
37+
38+
# Patch underlying urllib3 ssl_wrap_socket used by our wrapper
39+
monkeypatch.setattr(ssw.ssl_, "ssl_wrap_socket", fake_ssl_wrap_socket)
40+
41+
# Call our wrapper without providing ssl_context; it should inject one
42+
ssw.ssl_wrap_socket_with_ocsp(
43+
sock=None,
44+
keyfile=None,
45+
certfile=None,
46+
cert_reqs=None,
47+
ca_certs=None,
48+
server_hostname="localhost",
49+
ssl_version=None,
50+
ciphers=None,
51+
ssl_context=None,
52+
ca_cert_dir=None,
53+
key_password=None,
54+
ca_cert_data=None,
55+
tls_in_tls=False,
56+
)
57+
58+
assert captured.get("ctx_is_pyopenssl") is True

0 commit comments

Comments
 (0)