diff --git a/DESCRIPTION.md b/DESCRIPTION.md index e071f3da23..58fc1bee3d 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -11,6 +11,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only - Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once - Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body + - Added support for intermediate certificates as roots when they are stored in the trust store - v3.17.3(September 02,2025) - Enhanced configuration file permission warning messages. diff --git a/src/snowflake/connector/connection_diagnostic.py b/src/snowflake/connector/connection_diagnostic.py index ba81a4ecb9..9bafc0e780 100644 --- a/src/snowflake/connector/connection_diagnostic.py +++ b/src/snowflake/connector/connection_diagnostic.py @@ -240,6 +240,10 @@ def __test_socket_get_cert( context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(certifi.where()) + # Best-effort: enable partial-chain when supported + _partial_flag = getattr(ssl, "VERIFY_X509_PARTIAL_CHAIN", 0) + if _partial_flag and hasattr(context, "verify_flags"): + context.verify_flags |= _partial_flag sock = context.wrap_socket(conn, server_hostname=host) certificate = ssl.DER_cert_to_PEM_cert(sock.getpeercert(True)) http_request = f"""GET / {host}:{port} HTTP/1.1\r\n diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index 2cebb66262..f1a14e5c89 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -8,11 +8,13 @@ # # and added OCSP validator on the top. import logging +import os +import ssl import time import weakref from contextvars import ContextVar from functools import wraps -from inspect import getfullargspec as get_args +from inspect import signature as _sig from socket import socket from typing import Any @@ -38,9 +40,57 @@ log = logging.getLogger(__name__) +# Helper utilities (private) +def _resolve_cafile(kwargs: dict[str, Any]) -> str | None: + """Resolve CA bundle path from kwargs or standard environment variables. + + Precedence: + 1) kwargs['ca_certs'] if provided by caller + 2) REQUESTS_CA_BUNDLE + 3) SSL_CERT_FILE + """ + caf = kwargs.get("ca_certs") + if caf: + return caf + return os.environ.get("REQUESTS_CA_BUNDLE") or os.environ.get("SSL_CERT_FILE") + + +def _ensure_partial_chain_on_context(ctx: PyOpenSSLContext, cafile: str | None) -> None: + """Load CA bundle (when provided) and enable OpenSSL partial-chain support on ctx.""" + if cafile: + try: + ctx.load_verify_locations(cafile=cafile, capath=None) + except (ssl.SSLError, OSError, ValueError): + # Leave context unchanged; handshake/validation surfaces failures + pass + try: + store = ctx._ctx.get_cert_store() + from OpenSSL import crypto as _crypto + + if hasattr(_crypto, "X509StoreFlags") and hasattr( + _crypto.X509StoreFlags, "PARTIAL_CHAIN" + ): + store.set_flags(_crypto.X509StoreFlags.PARTIAL_CHAIN) + except (AttributeError, ImportError, OpenSSL.SSL.Error, OSError, ValueError): + # Best-effort; if not available, default chain building applies + pass + + +def _build_context_with_partial_chain(cafile: str | None) -> PyOpenSSLContext: + """Create PyOpenSSL context configured for CERT_REQUIRED and partial-chain trust.""" + ctx = PyOpenSSLContext(ssl_.PROTOCOL_TLS_CLIENT) + try: + ctx.verify_mode = ssl.CERT_REQUIRED + except Exception: + pass + _ensure_partial_chain_on_context(ctx, cafile) + return ctx + + # Store a *weak* reference so that the context variable doesn’t prolong the # lifetime of the SessionManager. Once all owning connections are GC-ed the -# weakref goes dead and OCSP will fall back to its local manager (but most likely won't be used ever again anyway). +# weakref goes dead and OCSP will fall back to its local manager (but most +# likely won't be used ever again anyway). _CURRENT_SESSION_MANAGER: ContextVar[weakref.ref[SessionManager] | None] = ContextVar( "_CURRENT_SESSION_MANAGER", default=None, @@ -71,7 +121,10 @@ def set_current_session_manager(sm: SessionManager | None) -> Any: Called from SnowflakeConnection so that OCSP downloads use the same proxy / header configuration as the initiating connection. - Alternative approach would be moving method inject_into_urllib3() inside connection initialization, but in case this delay (from module import time to connection initialization time) would cause some code to break we stayed with this approach, having in mind soon OCSP deprecation. + Alternative approach would be moving method inject_into_urllib3() inside + connection initialization, but in case this delay (from module import time + to connection initialization time) would cause some code to break we stayed + with this approach, having in mind soon OCSP deprecation. """ return _CURRENT_SESSION_MANAGER.set(weakref.ref(sm) if sm is not None else None) @@ -93,37 +146,29 @@ def inject_into_urllib3() -> None: @wraps(ssl_.ssl_wrap_socket) def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket: - # Extract host_name - hostname_index = get_args(ssl_.ssl_wrap_socket).args.index("server_hostname") - server_hostname = ( - args[hostname_index] - if len(args) > hostname_index - else kwargs.get("server_hostname", None) - ) - # Remove context if present - ssl_context_index = get_args(ssl_.ssl_wrap_socket).args.index("ssl_context") - context_in_args = len(args) > ssl_context_index - ssl_context = ( - args[hostname_index] if context_in_args else kwargs.get("ssl_context", None) - ) - if not isinstance(ssl_context, PyOpenSSLContext): - # Create new default context - if context_in_args: - new_args = list(args) - new_args[ssl_context_index] = None - args = tuple(new_args) - else: - del kwargs["ssl_context"] - # Fix ca certs location - ca_certs_index = get_args(ssl_.ssl_wrap_socket).args.index("ca_certs") - ca_certs_in_args = len(args) > ca_certs_index - if not ca_certs_in_args and not kwargs.get("ca_certs"): - kwargs["ca_certs"] = certifi.where() - - ret = ssl_.ssl_wrap_socket(*args, **kwargs) + # Bind passed args/kwargs to the underlying signature to support both positional and keyword calls + bound = _sig(ssl_.ssl_wrap_socket).bind_partial(*args, **kwargs) + params = bound.arguments + + server_hostname = params.get("server_hostname") + + # Ensure CA bundle default if not provided + if not params.get("ca_certs"): + params["ca_certs"] = certifi.where() + + # Ensure PyOpenSSL context with partial-chain is used if none or wrong type provided + provided_ctx = params.get("ssl_context") + if not isinstance(provided_ctx, PyOpenSSLContext): + cafile_for_ctx = _resolve_cafile(params) + params["ssl_context"] = _build_context_with_partial_chain(cafile_for_ctx) + else: + # If a PyOpenSSLContext is provided, ensure it trusts the provided CA and partial-chain is enabled + _ensure_partial_chain_on_context(provided_ctx, _resolve_cafile(params)) + + ret = ssl_.ssl_wrap_socket(**params) log.debug( - "OCSP Mode: %s, " "OCSP response cache file name: %s", + "OCSP Mode: %s, OCSP response cache file name: %s", FEATURE_OCSP_MODE.name, FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME, ) @@ -137,10 +182,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket: ).validate(server_hostname, ret.connection) if not v: raise OperationalError( - msg=( - "The certificate is revoked or " - "could not be validated: hostname={}".format(server_hostname) - ), + msg=f"The certificate is revoked or could not be validated: hostname={server_hostname}", errno=ER_OCSP_RESPONSE_CERT_STATUS_REVOKED, ) else: diff --git a/test/unit/test_ssl_partial_chain.py b/test/unit/test_ssl_partial_chain.py new file mode 100644 index 0000000000..4f74e63f95 --- /dev/null +++ b/test/unit/test_ssl_partial_chain.py @@ -0,0 +1,58 @@ +"""Unit tests for SSL wrapper partial-chain context injection.""" + +import types + +import pytest + +import snowflake.connector.ssl_wrap_socket as ssw # pylint: disable=import-error +from snowflake.connector.constants import OCSPMode # pylint: disable=import-error +from snowflake.connector.vendored.urllib3.contrib.pyopenssl import ( # pylint: disable=import-error + PyOpenSSLContext, +) + + +@pytest.fixture(autouse=True) +def disable_ocsp_checks(): + """Disable OCSP checks for offline unit testing.""" + # Ensure wrapper doesn't perform OCSP to keep this unit test offline + orig = ssw.FEATURE_OCSP_MODE + ssw.FEATURE_OCSP_MODE = OCSPMode.DISABLE_OCSP_CHECKS + try: + yield + finally: + ssw.FEATURE_OCSP_MODE = orig + + +def test_wrapper_injects_pyopenssl_context(monkeypatch): + """Wrapper should inject a PyOpenSSLContext when none is given.""" + captured = {} + + def fake_ssl_wrap_socket( # pylint: disable=unused-argument,too-many-arguments,too-many-positional-arguments + sock, ssl_context=None, **kwargs + ): + # Assert that our wrapper provided a PyOpenSSLContext + captured["ctx_is_pyopenssl"] = isinstance(ssl_context, PyOpenSSLContext) + # Return a minimal object with a 'connection' attribute expected by wrapper + return types.SimpleNamespace(connection=None) + + # Patch underlying urllib3 ssl_wrap_socket used by our wrapper + monkeypatch.setattr(ssw.ssl_, "ssl_wrap_socket", fake_ssl_wrap_socket) + + # Call our wrapper without providing ssl_context; it should inject one + ssw.ssl_wrap_socket_with_ocsp( + sock=None, + keyfile=None, + certfile=None, + cert_reqs=None, + ca_certs=None, + server_hostname="localhost", + ssl_version=None, + ciphers=None, + ssl_context=None, + ca_cert_dir=None, + key_password=None, + ca_cert_data=None, + tls_in_tls=False, + ) + + assert captured.get("ctx_is_pyopenssl") is True diff --git a/test/unit/test_ssl_partial_chain_handshake.py b/test/unit/test_ssl_partial_chain_handshake.py new file mode 100644 index 0000000000..1e51ab4ff5 --- /dev/null +++ b/test/unit/test_ssl_partial_chain_handshake.py @@ -0,0 +1,199 @@ +"""Integration-style unit test for partial-chain TLS handshake.""" + +import ipaddress as _ip +import socket +import ssl +import tempfile as _tempfile +import threading +from datetime import datetime, timedelta, timezone + +import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + +import snowflake.connector.ssl_wrap_socket as ssw # pylint: disable=import-error +from snowflake.connector.constants import OCSPMode # pylint: disable=import-error + + +@pytest.fixture(autouse=True) +def disable_ocsp_checks(): + """Disable OCSP checks for offline unit testing.""" + orig = ssw.FEATURE_OCSP_MODE + ssw.FEATURE_OCSP_MODE = OCSPMode.DISABLE_OCSP_CHECKS + try: + yield + finally: + ssw.FEATURE_OCSP_MODE = orig + + +def _create_key(): + """Create a new RSA key for certificate generation.""" + return rsa.generate_private_key(public_exponent=65537, key_size=2048) + + +def _create_cert(subject_cn, issuer_cert, issuer_key, is_ca, subject_key, ca=False): + """Create a certificate signed by issuer or self-signed if issuer is None.""" + subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, subject_cn)]) + issuer_name = subject if issuer_cert is None else issuer_cert.subject + + builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer_name) + .public_key(subject_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc) - timedelta(minutes=1)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(hours=1)) + ) + + if is_ca: + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=1), critical=True + ).add_extension( + x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + else: + builder = ( + builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False + ) + .add_extension( + x509.SubjectAlternativeName( + [ + x509.DNSName("localhost"), + x509.IPAddress(_ip.ip_address("127.0.0.1")), + ] + ), + critical=False, + ) + ) + + # Subject Key Identifier + builder = builder.add_extension( + x509.SubjectKeyIdentifier.from_public_key(subject_key.public_key()), + critical=False, + ) + # Authority Key Identifier (referencing issuer public key) + authority_pubkey = ( + subject_key.public_key() if issuer_key is None else issuer_key.public_key() + ) + builder = builder.add_extension( + x509.AuthorityKeyIdentifier.from_issuer_public_key(authority_pubkey), + critical=False, + ) + + signer_key = issuer_key if issuer_key is not None else subject_key + cert = builder.sign(private_key=signer_key, algorithm=hashes.SHA256()) + return cert + + +def _pem(obj, is_key=False): + """Return PEM-encoded certificate or key.""" + if is_key: + return obj.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + return obj.public_bytes(encoding=serialization.Encoding.PEM) + + +def _run_tls_server(server_cert_pem, server_key_pem, chain_pem, ready_evt, addr_holder): + """Run a minimal TLS server presenting server+intermediate chain.""" + # Minimal TLS server using Python ssl to present server+intermediate chain + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + # Combine server and intermediate into one PEM for certfile + with _tempfile.NamedTemporaryFile(delete=False) as cert_chain_file: + cert_chain_file.write(server_cert_pem) + cert_chain_file.write(chain_pem) + cert_chain_file.flush() + certfile_path = cert_chain_file.name + with _tempfile.NamedTemporaryFile(delete=False) as key_file: + key_file.write(server_key_pem) + key_file.flush() + keyfile_path = key_file.name + ctx.load_cert_chain(certfile=certfile_path, keyfile=keyfile_path) + + s = socket.socket() + s.bind(("127.0.0.1", 0)) + s.listen(1) + addr = s.getsockname() + addr_holder.append(addr) + ready_evt.set() + with ctx.wrap_socket(s, server_side=True) as ssock: + conn, _ = ssock.accept() + conn.close() + s.close() + + +def test_partial_chain_handshake_succeeds_with_intermediate_as_anchor(): + """Client should handshake trusting only the intermediate as anchor.""" + # Generate Root -> Intermediate -> Server + root_key = _create_key() + root_cert = _create_cert("Root", None, None, True, root_key) + + inter_key = _create_key() + inter_cert = _create_cert("Intermediate", root_cert, root_key, True, inter_key) + + server_key = _create_key() + server_cert = _create_cert("Server", inter_cert, inter_key, False, server_key) + + # Start TLS server presenting server + intermediate chain + ready_evt = threading.Event() + addr_holder = [] + t = threading.Thread( + target=_run_tls_server, + args=( + _pem(server_cert), + _pem(server_key, True), + _pem(inter_cert), + ready_evt, + addr_holder, + ), + ) + t.daemon = True + t.start() + ready_evt.wait(5) + host, port = addr_holder[0] + + # Build PyOpenSSL context with only intermediate as trust anchor + ctx = ssw._build_context_with_partial_chain( + None + ) # pylint: disable=protected-access + # Load intermediate into store via PEM file path by reusing helper + with _tempfile.NamedTemporaryFile(delete=False) as caf: + caf.write(_pem(inter_cert)) + caf.flush() + ctx.load_verify_locations(cafile=caf.name) + + # Wrap a socket with our wrapper specifying our context + s = socket.socket() + s.settimeout(5) + s.connect((host, port)) + + # The wrapper expects kwargs similar to urllib3; use provided context + ws = ssw.ssl_wrap_socket_with_ocsp( + sock=s, + server_hostname="localhost", + ssl_context=ctx, + ) + # If we reached here without SSL error, TLS handshake succeeded with + # intermediate-only trust; access attribute to assert presence + assert hasattr(ws, "connection") + s.close()