Skip to content
Merged
92 changes: 85 additions & 7 deletions src/snowflake/connector/ssl_wrap_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#
# and added OCSP validator on the top.
import logging
import ssl
import time
import weakref
from contextvars import ContextVar
Expand Down Expand Up @@ -40,7 +41,8 @@

# Store a *weak* reference so that the context variable doesn’t prolong the
# lifetime of the SessionManager. Once all owning connections are GC-ed the
# weakref goes dead and OCSP will fall back to its local manager (but most likely won't be used ever again anyway).
# weakref goes dead and OCSP will fall back to its local manager (but most
# likely won't be used ever again anyway).
_CURRENT_SESSION_MANAGER: ContextVar[weakref.ref[SessionManager] | None] = ContextVar(
"_CURRENT_SESSION_MANAGER",
default=None,
Expand Down Expand Up @@ -71,7 +73,10 @@ def set_current_session_manager(sm: SessionManager | None) -> Any:
Called from SnowflakeConnection so that OCSP downloads
use the same proxy / header configuration as the initiating connection.

Alternative approach would be moving method inject_into_urllib3() inside connection initialization, but in case this delay (from module import time to connection initialization time) would cause some code to break we stayed with this approach, having in mind soon OCSP deprecation.
Alternative approach would be moving method inject_into_urllib3() inside
connection initialization, but in case this delay (from module import time
to connection initialization time) would cause some code to break we stayed
with this approach, having in mind soon OCSP deprecation.
"""
return _CURRENT_SESSION_MANAGER.set(weakref.ref(sm) if sm is not None else None)

Expand All @@ -85,6 +90,35 @@ def reset_current_session_manager(token) -> None:
pass


def _build_pyopenssl_context_with_ca_and_partial_chain(
cafile: str | None,
) -> PyOpenSSLContext:
ctx = PyOpenSSLContext(ssl_.PROTOCOL_TLS_CLIENT)
try:
# Ensure certificate verification is enabled
ctx.verify_mode = ssl.CERT_REQUIRED
except Exception:
pass
try:
if cafile:
ctx.load_verify_locations(cafile=cafile, capath=None)
except Exception:
pass
# Enable partial-chain verification so intermediates in trust store can
# terminate chains
try:
store = ctx._ctx.get_cert_store()
from OpenSSL import crypto as _crypto

if hasattr(_crypto, "X509StoreFlags") and hasattr(
_crypto.X509StoreFlags, "PARTIAL_CHAIN"
):
store.set_flags(_crypto.X509StoreFlags.PARTIAL_CHAIN)
except Exception:
pass
return ctx


def inject_into_urllib3() -> None:
"""Monkey-patch urllib3 with PyOpenSSL-backed SSL-support and OCSP."""
log.debug("Injecting ssl_wrap_socket_with_ocsp")
Expand Down Expand Up @@ -120,10 +154,57 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
if not ca_certs_in_args and not kwargs.get("ca_certs"):
kwargs["ca_certs"] = certifi.where()

# Ensure PyOpenSSL context with partial-chain is used if no suitable context provided
try:
provided_ctx = kwargs.get("ssl_context", None)
if not isinstance(provided_ctx, PyOpenSSLContext):
cafile_for_ctx = kwargs.get("ca_certs")
if not cafile_for_ctx:
import os as _os

cafile_for_ctx = _os.environ.get(
"REQUESTS_CA_BUNDLE"
) or _os.environ.get("SSL_CERT_FILE")
kwargs["ssl_context"] = _build_pyopenssl_context_with_ca_and_partial_chain(
cafile_for_ctx
)
except Exception:
pass

# If a PyOpenSSLContext is provided, ensure it trusts the provided CA and
# partial-chain is enabled
try:
provided_ctx = kwargs.get("ssl_context", None)
if isinstance(provided_ctx, PyOpenSSLContext):
caf = kwargs.get("ca_certs")
if not caf:
import os as _os

caf = _os.environ.get("REQUESTS_CA_BUNDLE") or _os.environ.get(
"SSL_CERT_FILE"
)
try:
if caf:
provided_ctx.load_verify_locations(cafile=caf, capath=None)
except Exception:
pass
try:
store = provided_ctx._ctx.get_cert_store()
from OpenSSL import crypto as _crypto

if hasattr(_crypto, "X509StoreFlags") and hasattr(
_crypto.X509StoreFlags, "PARTIAL_CHAIN"
):
store.set_flags(_crypto.X509StoreFlags.PARTIAL_CHAIN)
except Exception:
pass
except Exception:
pass

ret = ssl_.ssl_wrap_socket(*args, **kwargs)

log.debug(
"OCSP Mode: %s, " "OCSP response cache file name: %s",
"OCSP Mode: %s, OCSP response cache file name: %s",
FEATURE_OCSP_MODE.name,
FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
)
Expand All @@ -137,10 +218,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
).validate(server_hostname, ret.connection)
if not v:
raise OperationalError(
msg=(
"The certificate is revoked or "
"could not be validated: hostname={}".format(server_hostname)
),
msg=f"The certificate is revoked or could not be validated: hostname={server_hostname}",
errno=ER_OCSP_RESPONSE_CERT_STATUS_REVOKED,
)
else:
Expand Down
70 changes: 70 additions & 0 deletions test/unit/test_ssl_partial_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Unit tests for SSL wrapper partial-chain context injection."""

import types

import pytest

import snowflake.connector.ssl_wrap_socket as ssw # pylint: disable=import-error
from snowflake.connector.constants import OCSPMode # pylint: disable=import-error
from snowflake.connector.vendored.urllib3.contrib.pyopenssl import ( # pylint: disable=import-error
PyOpenSSLContext,
)


@pytest.fixture(autouse=True)
def disable_ocsp_checks():
"""Disable OCSP checks for offline unit testing."""
# Ensure wrapper doesn't perform OCSP to keep this unit test offline
orig = ssw.FEATURE_OCSP_MODE
ssw.FEATURE_OCSP_MODE = OCSPMode.DISABLE_OCSP_CHECKS
try:
yield
finally:
ssw.FEATURE_OCSP_MODE = orig


def test_wrapper_injects_pyopenssl_context(monkeypatch):
"""Wrapper should inject a PyOpenSSLContext when none is given."""
captured = {}

def fake_ssl_wrap_socket( # pylint: disable=unused-argument,too-many-arguments,too-many-positional-arguments
sock,
keyfile=None,
certfile=None,
cert_reqs=None,
ca_certs=None,
server_hostname=None,
ssl_version=None,
ciphers=None,
ssl_context=None,
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
tls_in_tls=False,
):
# Assert that our wrapper provided a PyOpenSSLContext
captured["ctx_is_pyopenssl"] = isinstance(ssl_context, PyOpenSSLContext)
# Return a minimal object with a 'connection' attribute expected by wrapper
return types.SimpleNamespace(connection=None)

# Patch underlying urllib3 ssl_wrap_socket used by our wrapper
monkeypatch.setattr(ssw.ssl_, "ssl_wrap_socket", fake_ssl_wrap_socket)

# Call our wrapper without providing ssl_context; it should inject one
ssw.ssl_wrap_socket_with_ocsp(
sock=None,
keyfile=None,
certfile=None,
cert_reqs=None,
ca_certs=None,
server_hostname="localhost",
ssl_version=None,
ciphers=None,
ssl_context=None,
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
tls_in_tls=False,
)

assert captured.get("ctx_is_pyopenssl") is True
176 changes: 176 additions & 0 deletions test/unit/test_ssl_partial_chain_handshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Integration-style unit test for partial-chain TLS handshake."""

import socket
import ssl
import tempfile as _tempfile
import threading

import OpenSSL
import pytest

import snowflake.connector.ssl_wrap_socket as ssw # pylint: disable=import-error
from snowflake.connector.constants import OCSPMode # pylint: disable=import-error


@pytest.fixture(autouse=True)
def disable_ocsp_checks():
"""Disable OCSP checks for offline unit testing."""
orig = ssw.FEATURE_OCSP_MODE
ssw.FEATURE_OCSP_MODE = OCSPMode.DISABLE_OCSP_CHECKS
try:
yield
finally:
ssw.FEATURE_OCSP_MODE = orig


def _create_key():
"""Create a new RSA key for certificate generation."""
k = OpenSSL.crypto.PKey()
k.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)
return k


def _create_cert(subject_cn, issuer_cert, issuer_key, is_ca, subject_key, ca=False):
"""Create a certificate signed by issuer or self-signed if issuer is None."""
cert = OpenSSL.crypto.X509()
cert.set_version(2)
cert.set_serial_number(1)
subj = cert.get_subject()
subj.CN = subject_cn
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(60 * 60)
cert.set_pubkey(subject_key)
if issuer_cert is None:
issuer = subj
else:
issuer = issuer_cert.get_subject()
cert.set_issuer(issuer)
if is_ca:
cert.add_extensions(
[
OpenSSL.crypto.X509Extension(
b"basicConstraints", True, b"CA:TRUE, pathlen:1"
),
OpenSSL.crypto.X509Extension(
b"keyUsage", True, b"keyCertSign, cRLSign"
),
OpenSSL.crypto.X509Extension(
b"subjectKeyIdentifier", False, b"hash", subject=cert
),
]
)
else:
cert.add_extensions(
[
OpenSSL.crypto.X509Extension(b"basicConstraints", True, b"CA:FALSE"),
OpenSSL.crypto.X509Extension(b"extendedKeyUsage", False, b"serverAuth"),
OpenSSL.crypto.X509Extension(
b"subjectAltName", False, b"DNS:localhost,IP:127.0.0.1"
),
]
)
if issuer_cert is not None:
cert.add_extensions(
[
OpenSSL.crypto.X509Extension(
b"authorityKeyIdentifier",
False,
b"keyid:always,issuer:always",
issuer=issuer_cert,
),
]
)
cert.sign(issuer_key if issuer_key is not None else subject_key, "sha256")
return cert


def _pem(obj, is_key=False):
"""Return PEM-encoded certificate or key."""
if is_key:
return OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, obj)
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, obj)


def _run_tls_server(server_cert_pem, server_key_pem, chain_pem, ready_evt, addr_holder):
"""Run a minimal TLS server presenting server+intermediate chain."""
# Minimal TLS server using Python ssl to present server+intermediate chain
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Combine server and intermediate into one PEM for certfile
with _tempfile.NamedTemporaryFile(delete=False) as cert_chain_file:
cert_chain_file.write(server_cert_pem)
cert_chain_file.write(chain_pem)
cert_chain_file.flush()
certfile_path = cert_chain_file.name
with _tempfile.NamedTemporaryFile(delete=False) as key_file:
key_file.write(server_key_pem)
key_file.flush()
keyfile_path = key_file.name
ctx.load_cert_chain(certfile=certfile_path, keyfile=keyfile_path)

s = socket.socket()
s.bind(("127.0.0.1", 0))
s.listen(1)
addr = s.getsockname()
addr_holder.append(addr)
ready_evt.set()
with ctx.wrap_socket(s, server_side=True) as ssock:
conn, _ = ssock.accept()
conn.close()
s.close()


def test_partial_chain_handshake_succeeds_with_intermediate_as_anchor():
"""Client should handshake trusting only the intermediate as anchor."""
# Generate Root -> Intermediate -> Server
root_key = _create_key()
root_cert = _create_cert("Root", None, None, True, root_key)

inter_key = _create_key()
inter_cert = _create_cert("Intermediate", root_cert, root_key, True, inter_key)

server_key = _create_key()
server_cert = _create_cert("Server", inter_cert, inter_key, False, server_key)

# Start TLS server presenting server + intermediate chain
ready_evt = threading.Event()
addr_holder = []
t = threading.Thread(
target=_run_tls_server,
args=(
_pem(server_cert),
_pem(server_key, True),
_pem(inter_cert),
ready_evt,
addr_holder,
),
)
t.daemon = True
t.start()
ready_evt.wait(5)
host, port = addr_holder[0]

# Build PyOpenSSL context with only intermediate as trust anchor
ctx = ssw._build_pyopenssl_context_with_ca_and_partial_chain(
None
) # pylint: disable=protected-access
# Load intermediate into store via PEM file path by reusing helper
with _tempfile.NamedTemporaryFile(delete=False) as caf:
caf.write(_pem(inter_cert))
caf.flush()
ctx.load_verify_locations(cafile=caf.name)

# Wrap a socket with our wrapper specifying our context
s = socket.socket()
s.settimeout(5)
s.connect((host, port))

# The wrapper expects kwargs similar to urllib3; use provided context
ws = ssw.ssl_wrap_socket_with_ocsp(
sock=s,
server_hostname="localhost",
ssl_context=ctx,
)
# If we reached here without SSL error, TLS handshake succeeded with
# intermediate-only trust; access attribute to assert presence
assert hasattr(ws, "connection")
s.close()
Loading