diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs index 171980d0a7654a..a2553947286839 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs @@ -24,7 +24,42 @@ internal static partial class OpenSsl private const string TlsCacheSizeCtxName = "System.Net.Security.TlsCacheSize"; private const string TlsCacheSizeEnvironmentVariable = "DOTNET_SYSTEM_NET_SECURITY_TLSCACHESIZE"; private const SslProtocols FakeAlpnSslProtocol = (SslProtocols)1; // used to distinguish server sessions with ALPN - private static readonly ConcurrentDictionary s_clientSslContexts = new ConcurrentDictionary(); + + private sealed class SafeSslContextCache : SafeHandleCache { } + + private static readonly SafeSslContextCache s_clientSslContexts = new(); + + internal readonly struct SslContextCacheKey : IEquatable + { + public readonly byte[]? CertificateThumbprint; + public readonly SslProtocols SslProtocols; + + public SslContextCacheKey(SslProtocols sslProtocols, byte[]? certificateThumbprint) + { + SslProtocols = sslProtocols; + CertificateThumbprint = certificateThumbprint; + } + + public override bool Equals(object? obj) => obj is SslContextCacheKey key && Equals(key); + + public bool Equals(SslContextCacheKey other) => + SslProtocols == other.SslProtocols && + (CertificateThumbprint == null && other.CertificateThumbprint == null || + CertificateThumbprint != null && other.CertificateThumbprint != null && CertificateThumbprint.AsSpan().SequenceEqual(other.CertificateThumbprint)); + + public override int GetHashCode() + { + HashCode hash = default; + + hash.Add(SslProtocols); + if (CertificateThumbprint != null) + { + hash.AddBytes(CertificateThumbprint); + } + + return hash.ToHashCode(); + } + } #region internal methods internal static SafeChannelBindingHandle? QueryChannelBinding(SafeSslHandle context, ChannelBindingKind bindingType) @@ -113,6 +148,54 @@ private static SslProtocols CalculateEffectiveProtocols(SslAuthenticationOptions return protocols; } + internal static SafeSslContextHandle GetOrCreateSslContextHandle(SslAuthenticationOptions sslAuthenticationOptions, bool allowCached) + { + SslProtocols protocols = CalculateEffectiveProtocols(sslAuthenticationOptions); + + if (!allowCached) + { + return AllocateSslContext(sslAuthenticationOptions, protocols, allowCached); + } + + if (sslAuthenticationOptions.IsClient) + { + var key = new SslContextCacheKey(protocols, sslAuthenticationOptions.CertificateContext?.TargetCertificate.GetCertHash(HashAlgorithmName.SHA256)); + + return s_clientSslContexts.GetOrCreate(key, static (args) => + { + var (sslAuthOptions, protocols, allowCached) = args; + return AllocateSslContext(sslAuthOptions, protocols, allowCached); + }, (sslAuthenticationOptions, protocols, allowCached)); + } + + // cache in SslStreamCertificateContext is bounded and there is no eviction + // so the handle should always be valid, + + bool hasAlpn = sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0; + + SafeSslContextHandle? handle = AllocateSslContext(sslAuthenticationOptions, protocols, allowCached); + + if (!sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), out handle)) + { + // not found in cache, create and insert + handle = AllocateSslContext(sslAuthenticationOptions, protocols, allowCached); + + SafeSslContextHandle cached = sslAuthenticationOptions.CertificateContext!.SslContexts!.GetOrAdd(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), handle); + + if (handle != cached) + { + // lost the race, another thread created the SSL_CTX meanwhile, prefer the cached one + handle.Dispose(); + Debug.Assert(handle.IsClosed); + handle = cached; + } + } + + Debug.Assert(!handle.IsClosed); + handle.TryAddRentCount(); + return handle; + } + // This essentially wraps SSL_CTX* aka SSL_CTX_new + setting internal static unsafe SafeSslContextHandle AllocateSslContext(SslAuthenticationOptions sslAuthenticationOptions, SslProtocols protocols, bool enableResume) { @@ -188,7 +271,7 @@ internal static unsafe SafeSslContextHandle AllocateSslContext(SslAuthentication Interop.Ssl.SslCtxSetAlpnSelectCb(sslCtx, &AlpnServerSelectCallback, IntPtr.Zero); } - if (sslAuthenticationOptions.CertificateContext != null) + if (sslAuthenticationOptions.CertificateContext != null && sslAuthenticationOptions.IsServer) { SetSslCertificate(sslCtx, sslAuthenticationOptions.CertificateContext.CertificateHandle, sslAuthenticationOptions.CertificateContext.KeyHandle); @@ -257,10 +340,6 @@ internal static void UpdateClientCertificate(SafeSslHandle ssl, SslAuthenticatio internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuthenticationOptions) { SafeSslHandle? sslHandle = null; - SafeSslContextHandle? sslCtxHandle = null; - SafeSslContextHandle? newCtxHandle = null; - SslProtocols protocols = CalculateEffectiveProtocols(sslAuthenticationOptions); - bool hasAlpn = sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0; bool cacheSslContext = sslAuthenticationOptions.AllowTlsResume && !SslStream.DisableTlsResume && sslAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.RequireEncryption && sslAuthenticationOptions.CipherSuitesPolicy == null; if (cacheSslContext) @@ -269,13 +348,12 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth { // We don't support client resume on old OpenSSL versions. // We don't want to try on empty TargetName since that is our key. - // And we don't want to mess up with client authentication. It may be possible - // but it seems safe to get full new session. + // If we already have CertificateContext, then we know which cert the user wants to use and we can cache. + // The only client auth scenario where we can't cache is when user provides a cert callback and we don't know + // beforehand which cert will be used. and wan't to avoid resuming session created with different certificate. if (!Interop.Ssl.Capabilities.Tls13Supported || string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) || - sslAuthenticationOptions.CertificateContext != null || - sslAuthenticationOptions.ClientCertificates?.Count > 0 || - sslAuthenticationOptions.CertSelectionDelegate != null) + (sslAuthenticationOptions.CertificateContext == null && sslAuthenticationOptions.CertSelectionDelegate != null)) { cacheSslContext = false; } @@ -292,35 +370,14 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth } } - if (cacheSslContext) - { - if (sslAuthenticationOptions.IsServer) - { - sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), out sslCtxHandle); - } - else - { - - s_clientSslContexts.TryGetValue(protocols, out sslCtxHandle); - } - } - - if (sslCtxHandle == null) - { - // We did not get SslContext from cache - sslCtxHandle = newCtxHandle = AllocateSslContext(sslAuthenticationOptions, protocols, cacheSslContext); - - if (cacheSslContext) - { - bool added = sslAuthenticationOptions.IsServer ? - sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle) : - s_clientSslContexts.TryAdd(protocols, newCtxHandle); - if (added) - { - newCtxHandle = null; - } - } - } + // We do not touch the SSL_CTX after we create and configure SSL + // objects, and SSL object created later in this function will keep an + // outstanding up-ref on SSL_CTX. + // + // For uncached SafeSslContextHandles, the handle will be disposed and closed. + // Cached SafeSslContextHandles are returned with increaset rent count so that + // Dispose() here will not close the handle. + using SafeSslContextHandle sslCtxHandle = GetOrCreateSslContextHandle(sslAuthenticationOptions, cacheSslContext); GCHandle alpnHandle = default; try @@ -361,19 +418,25 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth Crypto.ErrClearError(); } - if (cacheSslContext) { sslCtxHandle.TrySetSession(sslHandle, sslAuthenticationOptions.TargetHost); - bool ignored = false; - sslCtxHandle.DangerousAddRef(ref ignored); + + // Maintain additional rent count for the context so + // that it is not evicted from the cache and future + // SSL objects can reuse it. This call should always + // succeed because already have increased rent count + // when getting the context from the cache + bool success = sslCtxHandle.TryAddRentCount(); + Debug.Assert(success); sslHandle.SslContextHandle = sslCtxHandle; } } // relevant to TLS 1.3 only: if user supplied a client cert or cert callback, // advertise that we are willing to send the certificate post-handshake. - if (sslAuthenticationOptions.ClientCertificates?.Count > 0 || + if (sslAuthenticationOptions.CertificateContext != null || + sslAuthenticationOptions.ClientCertificates?.Count > 0 || sslAuthenticationOptions.CertSelectionDelegate != null) { Ssl.SslSetPostHandshakeAuth(sslHandle, 1); @@ -434,10 +497,6 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth throw; } - finally - { - newCtxHandle?.Dispose(); - } return sslHandle; } @@ -708,6 +767,12 @@ private static unsafe int NewSessionCallback(IntPtr ssl, IntPtr session) Debug.Assert(ssl != IntPtr.Zero); Debug.Assert(session != IntPtr.Zero); + // remember if the session used a certificate, this information is used after + // session resumption, the pointer is not being dereferenced and the refcount + // is not going to be manipulated. + IntPtr cert = Interop.Ssl.SslGetCertificate(ssl); + Interop.Ssl.SslSessionSetData(session, cert); + IntPtr ptr = Ssl.SslGetData(ssl); if (ptr != IntPtr.Zero) { diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs index e1f2dfdc1f23e4..9f3d05e43e96af 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs @@ -116,6 +116,12 @@ internal static unsafe ReadOnlySpan SslGetAlpnSelected(SafeSslHandle ssl) [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertificate")] internal static partial IntPtr SslGetPeerCertificate(SafeSslHandle ssl); + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetCertificate")] + internal static partial IntPtr SslGetCertificate(SafeSslHandle ssl); + + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetCertificate")] + internal static partial IntPtr SslGetCertificate(IntPtr ssl); + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertChain")] internal static partial SafeSharedX509StackHandle SslGetPeerCertChain(SafeSslHandle ssl); @@ -129,6 +135,9 @@ internal static unsafe ReadOnlySpan SslGetAlpnSelected(SafeSslHandle ssl) [return: MarshalAs(UnmanagedType.Bool)] internal static partial bool SslSessionReused(SafeSslHandle ssl); + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetSession")] + internal static partial IntPtr SslGetSession(SafeSslHandle ssl); + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetClientCAList")] private static partial SafeSharedX509NameStackHandle SslGetClientCAList_private(SafeSslHandle ssl); @@ -182,6 +191,12 @@ internal static unsafe ReadOnlySpan SslGetAlpnSelected(SafeSslHandle ssl) [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetHostname")] internal static partial int SessionSetHostname(IntPtr session, IntPtr name); + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionGetData")] + internal static partial IntPtr SslSessionGetData(IntPtr session); + + [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetData")] + internal static partial void SslSessionSetData(IntPtr session, IntPtr val); + internal static class Capabilities { // needs separate type (separate static cctor) to be sure OpenSSL is initialized. @@ -430,7 +445,9 @@ protected override bool ReleaseHandle() Disconnect(); } - SslContextHandle?.DangerousRelease(); + // drop reference to any SSL_CTX handle, any handle present here is being + // rented from (client) SSL_CTX cache. + SslContextHandle?.Dispose(); if (AlpnHandle.IsAllocated) { diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs index d92e15e940e65e..bad664886dac07 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.SslCtx.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Net; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Diagnostics; @@ -9,6 +10,7 @@ using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Text; +using System.Threading; using Microsoft.Win32.SafeHandles; internal static partial class Interop @@ -65,12 +67,17 @@ internal static bool AddExtraChainCertificates(SafeSslContextHandle ctx, ReadOnl namespace Microsoft.Win32.SafeHandles { - internal sealed class SafeSslContextHandle : SafeHandle + internal sealed class SafeSslContextHandle : SafeHandle, ISafeHandleCachable { // This is session cache keyed by SNI e.g. TargetHost private Dictionary? _sslSessions; private GCHandle _gch; + // SSL_CTX handles are cached, so we need to keep track of the + // number of times a handle is being used. Once we decide to dispose the handle, + // we set the _rentCount to -1. + private volatile int _rentCount; + public SafeSslContextHandle() : base(IntPtr.Zero, true) { @@ -86,6 +93,38 @@ public override bool IsInvalid get { return handle == IntPtr.Zero; } } + public bool TryAddRentCount() + { + int oldCount; + + do + { + oldCount = _rentCount; + if (oldCount < 0) + { + // The handle is already disposed. + return false; + } + } while (Interlocked.CompareExchange(ref _rentCount, oldCount + 1, oldCount) != oldCount); + + return true; + } + + public bool TryMarkForDispose() + { + return Interlocked.CompareExchange(ref _rentCount, -1, 0) == 0; + } + + protected override void Dispose(bool disposing) + { + if (Interlocked.Decrement(ref _rentCount) < 0) + { + // _rentCount is 0 if the handle was never rented (e.g. failure during creation), + // and is -1 when evicted from cache. + base.Dispose(disposing); + } + } + protected override bool ReleaseHandle() { if (_sslSessions != null) diff --git a/src/libraries/Common/src/System/Net/SafeHandleCache.cs b/src/libraries/Common/src/System/Net/SafeHandleCache.cs new file mode 100644 index 00000000000000..5fd9b6855a4101 --- /dev/null +++ b/src/libraries/Common/src/System/Net/SafeHandleCache.cs @@ -0,0 +1,157 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Runtime.InteropServices; + +namespace System.Net +{ + internal interface ISafeHandleCachable + { + // Attempts to resever the handle for use. If the handle is already + // disposed (or scheduled to be disposed), this will return false. + // + // each successful call to TryAddRentCount() must be paired with a Dispose() call. + bool TryAddRentCount(); + + // Marks the handle as scheduled for disposal if it is not being used. + // Returns false if the handle is currently being used. + // once marked, no new renters are allowed. + bool TryMarkForDispose(); + } + + /// + /// Helper class for implementing a cache for types deriving from . The purpose of the cache is to allow reuse of + /// resources which may enable additional features (such as TLS resumption). + /// The cache handles insertion and eviction in a thread-safe manner and + /// implements simple mechanism for preventing unbounded growth and memory + /// leaks. + /// + internal class SafeHandleCache where TKey : IEquatable where THandle : SafeHandle, ISafeHandleCachable + { + private const int CheckExpiredModulo = 32; + + private readonly ConcurrentDictionary _cache = new(); + + /// + /// Gets the handle from the cache if it exists, otherwise creates a new one using the + /// provided factory function and context. + /// + /// In case of two racing inserts with the same key, the handle returned by the factory may + /// end up being discarded in favor of the one that was inserted first. In such case, the + /// factory handle is disposed and the cached handle is returned. + /// + /// The handle returned from this function should be disposed exactly once when it is no + /// longer needed. + /// + internal THandle GetOrCreate(TKey key, Func factory, TContext factoryContext) + { + if (_cache.TryGetValue(key, out THandle? handle) && handle.TryAddRentCount()) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"Found cached {handle}."); + } + return handle; + } + + // if we get here, the handle is either not in the cache, or we lost + // the race between TryAddRentCount on this thread and + // MarkForDispose on another thread doing cache cleanup. In either + // case, we need to create a new handle. + + handle = factory(factoryContext); + handle.TryAddRentCount(); // The caler is the first renter + + THandle cached; + do + { + cached = _cache.GetOrAdd(key, handle); + } + // If we get the same handle back, we successfully added it to the cache and we are done. + // If we get a different handle back, we need to increase the rent count. + // If we fail to add the rent count, then the existing/cached handle is in process of + // being removed from the cache and we can try again, eventually either succeeding to + // add our new handle or getting a fresh handle inserted by another thread meanwhile. + while (cached != handle && !cached.TryAddRentCount()); + + if (cached != handle) + { + // we lost a race with another thread to insert new handle into the cache + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"Discarding {handle} (preferring cached {cached})."); + } + + // First dispose decrements the rent count we added before attempting the cache insertion + // and second closes the handle + handle.Dispose(); + handle.Dispose(); + Debug.Assert(handle.IsClosed); + + return cached; + } + + CheckForCleanup(); + + return handle; + } + + private void CheckForCleanup() + { + // We check the cache size after every couple of insertions, and + // discard all handles which are not being actively rented. This + // should still be flexible enough to allow "stable set" of + // arbitrary size, while still preventing unbounded growth. + + var count = _cache.Count; + if (count % CheckExpiredModulo == 0) + { + // let only one thread perform cleanup at a time + lock (_cache) + { + // check again, if another thread just cleaned up (and cached count went down) we are unlikely + // to clean anything + if (_cache.Count >= count) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"Current size: {_cache.Count}."); + } + + foreach ((TKey key, THandle handle) in _cache) + { + if (!handle.TryMarkForDispose()) + { + // handle in use + continue; + } + + // the handle is not in use and has been marked such that no new rents can be added. + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"Evicting cached {handle}."); + } + + bool removed = _cache.TryRemove(key, out _); + Debug.Assert(removed); + handle.Dispose(); + + // Since the handle is not used anywhere, this should close the handle + Debug.Assert(handle.IsClosed); + } + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"New size: {_cache.Count}."); + } + } + } + } + } + } +} diff --git a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj index 56f1d2837ac672..3dc678ce8442cf 100644 --- a/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj +++ b/src/libraries/System.Net.Quic/src/System.Net.Quic.csproj @@ -29,6 +29,7 @@ + diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs index 38a02cad2328b5..4db0f88bf71a48 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.Cache.cs @@ -15,8 +15,6 @@ namespace System.Net.Quic; internal static partial class MsQuicConfiguration { - private const int CheckExpiredModulo = 32; - private const string DisableCacheEnvironmentVariable = "DOTNET_SYSTEM_NET_QUIC_DISABLE_CONFIGURATION_CACHE"; private const string DisableCacheCtxSwitch = "System.Net.Quic.DisableConfigurationCache"; @@ -38,7 +36,12 @@ private static bool GetConfigurationCacheEnabled() // enabled by default return true; } - private static readonly ConcurrentDictionary s_configurationCache = new(); + + private static readonly MsQuicConfigurationCache s_configurationCache = new MsQuicConfigurationCache(); + + private sealed class MsQuicConfigurationCache : SafeHandleCache + { + } private readonly struct CacheKey : IEquatable { @@ -130,107 +133,10 @@ private static MsQuicConfigurationSafeHandle GetCachedCredentialOrCreate(QUIC_SE { CacheKey key = new CacheKey(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); - MsQuicConfigurationSafeHandle? handle; - - if (s_configurationCache.TryGetValue(key, out handle) && handle.TryAddRentCount()) - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"Found cached MsQuicConfiguration: {handle}."); - } - return handle; - } - - // if we get here, the handle is either not in the cache, or we lost the race between - // TryAddRentCount on this thread and MarkForDispose on another thread doing cache cleanup. - // In either case, we need to create a new handle. - - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"MsQuicConfiguration not found in cache, creating new."); - } - - handle = CreateInternal(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); - handle.TryAddRentCount(); // we are the first renter - - MsQuicConfigurationSafeHandle cached; - do - { - cached = s_configurationCache.GetOrAdd(key, handle); - } - // If we get the same handle back, we successfully added it to the cache and we are done. - // If we get a different handle back, we need to increase the rent count. - // If we fail to add the rent count, then the existing/cached handle is in process of - // being removed from the cache and we can try again, eventually either succeeding to add our - // new handle or getting a fresh handle inserted by another thread meanwhile. - while (cached != handle && !cached.TryAddRentCount()); - - if (cached != handle) - { - // we lost a race with another thread to insert new handle into the cache - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"Discarding MsQuicConfiguration {handle} (preferring cached {cached})."); - } - - // First dispose decrements the rent count we added before attempting the cache insertion - // and second closes the handle - handle.Dispose(); - handle.Dispose(); - Debug.Assert(handle.IsClosed); - - return cached; - } - - // we added a new handle, check if we need to cleanup - var count = s_configurationCache.Count; - if (count % CheckExpiredModulo == 0) - { - // let only one thread perform cleanup at a time - lock (s_configurationCache) - { - // check again, if another thread just cleaned up (and cached count went down) we are unlikely - // to clean anything - if (s_configurationCache.Count >= count) - { - CleanupCache(); - } - } - } - - return handle; - } - - private static void CleanupCache() - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"Cleaning up MsQuicConfiguration cache, current size: {s_configurationCache.Count}."); - } - - foreach ((CacheKey key, MsQuicConfigurationSafeHandle handle) in s_configurationCache) - { - if (!handle.TryMarkForDispose()) - { - // handle in use - continue; - } - - // the handle is not in use and has been marked such that no new rents can be added. - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Info(null, $"Removing cached MsQuicConfiguration {handle}."); - } - - bool removed = s_configurationCache.TryRemove(key, out _); - Debug.Assert(removed); - handle.Dispose(); - Debug.Assert(handle.IsClosed); - } - - if (NetEventSource.Log.IsEnabled()) + return s_configurationCache.GetOrCreate(key, static (args) => { - NetEventSource.Info(null, $"Cleaning up MsQuicConfiguration cache, new size: {s_configurationCache.Count}."); - } + var (settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites) = args; + return CreateInternal(settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites); + }, (settings, flags, certificate, intermediates, alpnProtocols, allowedCipherSuites)); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs index 8e70ec20572454..3015eff1767be0 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs @@ -145,7 +145,7 @@ protected override unsafe bool ReleaseHandle() } } -internal sealed class MsQuicConfigurationSafeHandle : MsQuicSafeHandle +internal sealed class MsQuicConfigurationSafeHandle : MsQuicSafeHandle, ISafeHandleCachable { // MsQuicConfiguration handles are cached, so we need to keep track of the // number of times a handle is rented. Once we decide to dispose the handle, diff --git a/src/libraries/System.Net.Security/src/System.Net.Security.csproj b/src/libraries/System.Net.Security/src/System.Net.Security.csproj index e6650f80670e17..2e263173ad81de 100644 --- a/src/libraries/System.Net.Security/src/System.Net.Security.csproj +++ b/src/libraries/System.Net.Security/src/System.Net.Security.csproj @@ -353,8 +353,10 @@ Link="Common\Interop\Unix\System.Security.Cryptography.Native\Interop.OCSP.cs" /> + + Link="Common\Interop\Unix\System.Security.Cryptography.Native\Interop.OpenSslVersion.cs" /> true; + internal static bool IsLocalCertificateUsed(SafeFreeCredentials? _1, SafeDeleteContext? ctx) + { + if (ctx is not SafeSslHandle ssl) + { + return false; + } + + if (!Interop.Ssl.SslSessionReused(ssl)) + { + // Fresh session, we set the certificate on the SSL object only + // if the peer explicitly requested it. + return Interop.Ssl.SslGetCertificate(ssl) != IntPtr.Zero; + } + + // resumed session, we keep the information about cert being used in the SSL_SESSION + // object's ex_data + bool addref = false; + try + { + // make sure the ssl is not freed while we accessing its SSL_SESSION + // this makes sure the `session` pointer is valid during this call + // despite not being a SafeHandle. + ssl.DangerousAddRef(ref addref); + + // the information about certificate usage is stored in the session ex data + IntPtr session = Interop.Ssl.SslGetSession(ssl); + Debug.Assert(session != IntPtr.Zero); + return Interop.Ssl.SslSessionGetData(session) != IntPtr.Zero; + } + finally + { + if (addref) + { + ssl.DangerousRelease(); + } + } + } // // Used only by client SSL code, never returns null. diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs index 1ebe6467d0f63d..0aa6238c847518 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs @@ -17,7 +17,7 @@ public static Exception GetException(SecurityStatusPal status) return status.Exception ?? new Interop.OpenSsl.SslException((int)status.ErrorCode); } - internal const bool StartMutualAuthAsAnonymous = true; + internal const bool StartMutualAuthAsAnonymous = false; internal const bool CanEncryptEmptyMessage = false; public static void VerifyPackageInfo() @@ -168,8 +168,8 @@ public static bool TryUpdateClintCertificate( return true; } - private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context, - ReadOnlySpan inputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context, + ReadOnlySpan inputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { ProtocolToken token = default; token.RentBuffer = true; @@ -186,8 +186,20 @@ private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context { // this should happen only for clients Debug.Assert(sslAuthenticationOptions.IsClient); - token.Status = new SecurityStatusPal(errorCode); - return token; + + // if we don't have a client certificate ready, bubble up so + // that the certificate selection routine runs again. This + // happens if the first call to LocalCertificateSelectionCallback + // returns null. + if (sslAuthenticationOptions.CertificateContext == null) + { + token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.CredentialsNeeded); + return token; + } + + // set the cert and continue + TryUpdateClintCertificate(null, context, sslAuthenticationOptions); + errorCode = Interop.OpenSsl.DoSslHandshake((SafeSslHandle)context, ReadOnlySpan.Empty, ref token); } // sometimes during renegotiation processing message does not yield new output. diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamAllowTlsResumeTests.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamAllowTlsResumeTests.cs index 494250e92e3b67..bc0df20697b38b 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamAllowTlsResumeTests.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamAllowTlsResumeTests.cs @@ -1,18 +1,23 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Reflection; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; +using System.Linq; using Xunit; using Microsoft.DotNet.XUnitExtensions; +using System.Net.Test.Common; + #if DEBUG namespace System.Net.Security.Tests { using Configuration = System.Net.Test.Common.Configuration; + [PlatformSpecific(TestPlatforms.Windows | TestPlatforms.Linux)] public class SslStreamTlsResumeTests { private static FieldInfo connectionInfo = typeof(SslStream).GetField( @@ -29,8 +34,7 @@ private bool CheckResumeFlag(SslStream ssl) [ConditionalTheory] [InlineData(true)] [InlineData(false)] - [PlatformSpecific(TestPlatforms.Windows | TestPlatforms.Linux)] - public async Task SslStream_ClientDisableTlsResume_Succeeds(bool testClient) + public async Task ClientDisableTlsResume_Succeeds(bool testClient) { SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions { @@ -128,6 +132,285 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout( client.Dispose(); server.Dispose(); } + + [Theory] + [MemberData(nameof(SslProtocolsData))] + public Task NoClientCert_DefaultValue_ResumeSucceeds(SslProtocols sslProtocol) + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + EnabledSslProtocols = sslProtocol, + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false) + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + EnabledSslProtocols = sslProtocol, + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + }; + + return ResumeSucceedsInternal(serverOptions, clientOptions); + } + + public static TheoryData SslProtocolsData() + { + var data = new TheoryData(); + + data.Add(SslProtocols.None); + + if (PlatformDetection.SupportsTls12) + { + data.Add(SslProtocols.Tls12); + } + + if (PlatformDetection.SupportsTls13) + { + data.Add(SslProtocols.Tls13); + } + + return data; + } + + public enum ClientCertSource + { + ClientCertificate, + SelectionCallback, + CertificateContext + } + + public static TheoryData ClientCertTestData() + { + var data = new TheoryData(); + + foreach (SslProtocols protocol in SslProtocolsData().Select(x => x[0])) + foreach (bool certRequired in new[] { true, false }) + foreach (ClientCertSource source in Enum.GetValues(typeof(ClientCertSource))) + { + data.Add(protocol, certRequired, source); + } + + return data; + } + + [Theory] + [MemberData(nameof(ClientCertTestData))] + public Task ClientCert_DefaultValue_ResumeSucceeds(SslProtocols sslProtocol, bool certificateRequired, ClientCertSource certSource) + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + EnabledSslProtocols = sslProtocol, + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false), + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + ClientCertificateRequired = certificateRequired, + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + EnabledSslProtocols = sslProtocol, + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + }; + + X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate(); + + switch (certSource) + { + case ClientCertSource.ClientCertificate: + clientOptions.ClientCertificates = new X509CertificateCollection() { clientCertificate }; + break; + case ClientCertSource.SelectionCallback: + clientOptions.LocalCertificateSelectionCallback = delegate { return clientCertificate; }; + break; + case ClientCertSource.CertificateContext: + clientOptions.ClientCertificateContext = SslStreamCertificateContext.Create(clientCertificate, new()); + break; + } + + return ResumeSucceedsInternal(serverOptions, clientOptions); + } + + private async Task ResumeSucceedsInternal(SslServerAuthenticationOptions serverOptions, SslClientAuthenticationOptions clientOptions) + { + // no resume on the first run + await RunConnectionAsync(serverOptions, clientOptions, false); + + for (int i = 0; i < 3; i++) + { + // create new TLS to the same server. This should resume TLS. + await RunConnectionAsync(serverOptions, clientOptions, true); + } + } + + private async Task RunConnectionAsync(SslServerAuthenticationOptions serverOptions, SslClientAuthenticationOptions clientOptions, bool? expectResume) + { + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(); + using (client) + using (server) + { + await TestConfiguration.WhenAllOrAnyFailedWithTimeout( + client.AuthenticateAsClientAsync(clientOptions), + server.AuthenticateAsServerAsync(serverOptions)); + + if (expectResume.HasValue) + { + Assert.Equal(expectResume.Value, CheckResumeFlag(client)); + Assert.Equal(expectResume.Value, CheckResumeFlag(server)); + } + + await TestHelper.PingPong(client, server); + + await client.ShutdownAsync(); + await server.ShutdownAsync(); + } + } + + [Theory] + [MemberData(nameof(SslProtocolsData))] + public Task ClientChangeCert_NoResume(SslProtocols sslProtocol) + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + EnabledSslProtocols = sslProtocol, + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false), + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + ClientCertificateRequired = true, + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + EnabledSslProtocols = sslProtocol, + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + ClientCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetClientCertificate(), null, false) + }; + + return TestNoResumeAfterChange(serverOptions, clientOptions, + (clientOps, _) => clientOps.ClientCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetSelfSignedClientCertificate(), null, false), + (clientOps, _) => clientOps.ClientCertificateContext = null); + } + + [Theory] + [MemberData(nameof(SslProtocolsData))] + public Task DifferentHost_NoResume(SslProtocols sslProtocol) + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + EnabledSslProtocols = sslProtocol, + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false) + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + EnabledSslProtocols = sslProtocol, + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + ClientCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetClientCertificate(), null, false) + }; + + return TestNoResumeAfterChange(serverOptions, clientOptions, + (clientOps, _) => clientOps.TargetHost = Guid.NewGuid().ToString("N")); + } + + [Fact] + public Task DifferentProtocol_NoResume() + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false) + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + EnabledSslProtocols = SslProtocols.Tls12, + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + }; + + return TestNoResumeAfterChange(serverOptions, clientOptions, + (clientOps, _) => clientOps.EnabledSslProtocols = SslProtocols.None); + } + + [Fact] + [PlatformSpecific(TestPlatforms.Windows)] + public Task DifferentRevocationCheckMode_NoResume() + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false) + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + CertificateRevocationCheckMode = X509RevocationMode.Offline, + }; + + return TestNoResumeAfterChange(serverOptions, clientOptions, + (clientOps, _) => clientOps.CertificateRevocationCheckMode = X509RevocationMode.NoCheck); + } + + [Fact] + public Task DifferentEncryptionPolicy_NoResume() + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + EnabledSslProtocols = SslProtocols.Tls12, + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false) + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + }; +#pragma warning disable SYSLIB0040 // 'AllowNoEncryption' is obsolete + return TestNoResumeAfterChange(serverOptions, clientOptions, + (clientOps, _) => clientOps.EncryptionPolicy = EncryptionPolicy.AllowNoEncryption); +#pragma warning restore SYSLIB0040 // 'AllowNoEncryption' is obsolete + } + + [Fact] + [PlatformSpecific(TestPlatforms.Linux)] // CipherSuitesPolicy is suppoted only on Linux + public Task DifferentCipherSuitesPolicy_NoResume() + { + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions + { + ServerCertificateContext = SslStreamCertificateContext.Create(Configuration.Certificates.GetServerCertificate(), null, false) + }; + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions + { + TargetHost = Guid.NewGuid().ToString("N"), + CertificateRevocationCheckMode = X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = (sender, cert, chain, errors) => true, + }; + + return TestNoResumeAfterChange(serverOptions, clientOptions, + (clientOps, _) => clientOps.CipherSuitesPolicy = new CipherSuitesPolicy(new[] { TlsCipherSuite.TLS_AES_128_GCM_SHA256 })); + } + + private async Task TestNoResumeAfterChange(SslServerAuthenticationOptions serverOptions, SslClientAuthenticationOptions clientOptions, params Action[] updateOptions) + { + // confirm sessions are resumable and prime for resumption + await RunConnectionAsync(serverOptions, clientOptions, false); + await RunConnectionAsync(serverOptions, clientOptions, true); + + foreach (Action update in updateOptions) + { + update(clientOptions, serverOptions); + + // after changing options, the session should not be resumed + await RunConnectionAsync(serverOptions, clientOptions, false); + } + } } } #endif diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamMutualAuthenticationTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamMutualAuthenticationTest.cs index 7a858a644bab3e..e8ddb07f82df76 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamMutualAuthenticationTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamMutualAuthenticationTest.cs @@ -266,7 +266,7 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout( } else { - Assert.Null(server.RemoteCertificate); + Assert.Null(server.RemoteCertificate); } }; } @@ -320,7 +320,7 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout( } else { - Assert.Null(server.RemoteCertificate); + Assert.Null(server.RemoteCertificate); } }; } @@ -357,7 +357,7 @@ public async Task SslStream_ResumedSessionsCallbackMaybeSet_IsMutuallyAuthentica if (expectMutualAuthentication) { - clientOptions.LocalCertificateSelectionCallback = (s, t, l, r, a) => _clientCertificate; + clientOptions.LocalCertificateSelectionCallback = (s, t, l, r, a) => _clientCertificate; } else { @@ -378,7 +378,7 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout( } else { - Assert.Null(server.RemoteCertificate); + Assert.Null(server.RemoteCertificate); } }; } diff --git a/src/libraries/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj b/src/libraries/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj index 5602fc0cab3562..46e18e22b86ae4 100644 --- a/src/libraries/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj +++ b/src/libraries/System.Net.Security/tests/UnitTests/System.Net.Security.Unit.Tests.csproj @@ -87,8 +87,10 @@ Link="Common\Interop\Unix\System.Security.Cryptography.Native\Interop.OCSP.cs" /> + + Link="Common\Interop\Unix\System.Security.Cryptography.Native\Interop.OpenSslVersion.cs" /> = OPENSSL_VERSION_1_1_0_RTM const CRYPTO_EX_DATA* from, @@ -1260,6 +1261,52 @@ static int ExDataDup( return 1; } +static void ExDataFreeNoOp( + void* parent, + void* ptr, + CRYPTO_EX_DATA* ad, + int idx, + long argl, + void* argp) +{ + (void)parent; + (void)ptr; + (void)ad; + (void)idx; + (void)argl; + (void)argp; + + // do nothing. +} + +static int ExDataDupNoOp( + CRYPTO_EX_DATA* to, +#if OPENSSL_VERSION_NUMBER >= OPENSSL_VERSION_1_1_0_RTM + const CRYPTO_EX_DATA* from, +#else + CRYPTO_EX_DATA* from, +#endif +#if OPENSSL_VERSION_NUMBER >= OPENSSL_VERSION_3_0_RTM + void** from_d, +#else + void* from_d, +#endif + int idx, + long argl, + void* argp) +{ + (void)to; + (void)from; + (void)from_d; + (void)idx; + (void)argl; + (void)argp; + + // do nothing, this should lead to copy of the pointer being stored in the + // destination, we treat the ptr as an opaque blob. + return 1; +} + void CryptoNative_RegisterLegacyAlgorithms(void) { #ifdef NEED_OPENSSL_3_0 @@ -1393,7 +1440,9 @@ static int32_t EnsureOpenSsl10Initialized(void) ERR_load_crypto_strings(); // In OpenSSL 1.0.2-, CRYPTO_EX_INDEX_X509 is 10. - g_x509_ocsp_index = CRYPTO_get_ex_new_index(10, 0, NULL, NULL, ExDataDup, ExDataFree); + g_x509_ocsp_index = CRYPTO_get_ex_new_index(10, 0, NULL, NULL, ExDataDupOcspResponse, ExDataFreeOcspResponse); + // In OpenSSL 1.0.2-, CRYPTO_EX_INDEX_SSL_SESSION is 3. + g_ssl_sess_cert_index = CRYPTO_get_ex_new_index(3, 0, NULL, NULL, ExDataDupNoOp, ExDataFreeNoOp); done: if (ret != 0) @@ -1461,7 +1510,9 @@ static int32_t EnsureOpenSsl11Initialized(void) atexit(HandleShutdown); // In OpenSSL 1.1.0+, CRYPTO_EX_INDEX_X509 is 3. - g_x509_ocsp_index = CRYPTO_get_ex_new_index(3, 0, NULL, NULL, ExDataDup, ExDataFree); + g_x509_ocsp_index = CRYPTO_get_ex_new_index(3, 0, NULL, NULL, ExDataDupOcspResponse, ExDataFreeOcspResponse); + // In OpenSSL 1.1.0+, CRYPTO_EX_INDEX_SSL_SESSION is 2. + g_ssl_sess_cert_index = CRYPTO_get_ex_new_index(2, 0, NULL, NULL, ExDataDupNoOp, ExDataFreeNoOp); return 0; } @@ -1480,6 +1531,7 @@ int32_t CryptoNative_OpenSslAvailable(void) static int32_t g_initStatus = 1; int g_x509_ocsp_index = -1; +int g_ssl_sess_cert_index = -1; static int32_t EnsureOpenSslInitializedCore(void) { @@ -1510,6 +1562,7 @@ static int32_t EnsureOpenSslInitializedCore(void) // On OpenSSL 1.0.2 our expected index is 0. // On OpenSSL 1.1.0+ 0 is a reserved value and we expect 1. assert(g_x509_ocsp_index != -1); + assert(g_ssl_sess_cert_index != -1); } return ret; diff --git a/src/native/libs/System.Security.Cryptography.Native/opensslshim.h b/src/native/libs/System.Security.Cryptography.Native/opensslshim.h index c4aa47d18cfae2..cf24f810bb6e8e 100644 --- a/src/native/libs/System.Security.Cryptography.Native/opensslshim.h +++ b/src/native/libs/System.Security.Cryptography.Native/opensslshim.h @@ -589,6 +589,7 @@ extern bool g_libSslUses32BitTime; REQUIRED_FUNCTION(SSL_get_version) \ LIGHTUP_FUNCTION(SSL_get0_alpn_selected) \ RENAMED_FUNCTION(SSL_get1_peer_certificate, SSL_get_peer_certificate) \ + REQUIRED_FUNCTION(SSL_get_certificate) \ LEGACY_FUNCTION(SSL_library_init) \ LEGACY_FUNCTION(SSL_load_error_strings) \ REQUIRED_FUNCTION(SSL_new) \ @@ -597,6 +598,8 @@ extern bool g_libSslUses32BitTime; REQUIRED_FUNCTION(SSL_renegotiate) \ REQUIRED_FUNCTION(SSL_renegotiate_pending) \ REQUIRED_FUNCTION(SSL_SESSION_free) \ + REQUIRED_FUNCTION(SSL_SESSION_get_ex_data) \ + REQUIRED_FUNCTION(SSL_SESSION_set_ex_data) \ LIGHTUP_FUNCTION(SSL_SESSION_get0_hostname) \ LIGHTUP_FUNCTION(SSL_SESSION_set1_hostname) \ FALLBACK_FUNCTION(SSL_session_reused) \ @@ -609,6 +612,7 @@ extern bool g_libSslUses32BitTime; REQUIRED_FUNCTION(SSL_set_ex_data) \ FALLBACK_FUNCTION(SSL_set_options) \ REQUIRED_FUNCTION(SSL_set_session) \ + REQUIRED_FUNCTION(SSL_get_session) \ REQUIRED_FUNCTION(SSL_set_verify) \ REQUIRED_FUNCTION(SSL_shutdown) \ LEGACY_FUNCTION(SSL_state) \ @@ -1109,6 +1113,7 @@ extern TYPEOF(OPENSSL_gmtime)* OPENSSL_gmtime_ptr; #define SSL_free SSL_free_ptr #define SSL_get_ciphers SSL_get_ciphers_ptr #define SSL_get_client_CA_list SSL_get_client_CA_list_ptr +#define SSL_get_certificate SSL_get_certificate_ptr #define SSL_get_current_cipher SSL_get_current_cipher_ptr #define SSL_get_error SSL_get_error_ptr #define SSL_get_ex_data SSL_get_ex_data_ptr @@ -1133,6 +1138,8 @@ extern TYPEOF(OPENSSL_gmtime)* OPENSSL_gmtime_ptr; #define SSL_SESSION_get0_hostname SSL_SESSION_get0_hostname_ptr #define SSL_SESSION_set1_hostname SSL_SESSION_set1_hostname_ptr #define SSL_session_reused SSL_session_reused_ptr +#define SSL_SESSION_get_ex_data SSL_SESSION_get_ex_data_ptr +#define SSL_SESSION_set_ex_data SSL_SESSION_set_ex_data_ptr #define SSL_set_accept_state SSL_set_accept_state_ptr #define SSL_set_bio SSL_set_bio_ptr #define SSL_set_cert_cb SSL_set_cert_cb_ptr @@ -1142,6 +1149,7 @@ extern TYPEOF(OPENSSL_gmtime)* OPENSSL_gmtime_ptr; #define SSL_set_ex_data SSL_set_ex_data_ptr #define SSL_set_options SSL_set_options_ptr #define SSL_set_session SSL_set_session_ptr +#define SSL_get_session SSL_get_session_ptr #define SSL_set_verify SSL_set_verify_ptr #define SSL_shutdown SSL_shutdown_ptr #define SSL_state SSL_state_ptr diff --git a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c index e320d1c73d776b..dc0478ef2a230e 100644 --- a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c +++ b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.c @@ -595,6 +595,11 @@ X509* CryptoNative_SslGetPeerCertificate(SSL* ssl) return cert; } +X509* CryptoNative_SslGetCertificate(SSL* ssl) +{ + return SSL_get_certificate(ssl); +} + X509Stack* CryptoNative_SslGetPeerCertChain(SSL* ssl) { // No error queue impact. @@ -711,6 +716,11 @@ const char* CryptoNative_SslGetServerName(SSL* ssl) return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); } +SSL_SESSION* CryptoNative_SslGetSession(SSL* ssl) +{ + return SSL_get_session(ssl); +} + int32_t CryptoNative_SslSetSession(SSL* ssl, SSL_SESSION* session) { return SSL_set_session(ssl, session); @@ -748,6 +758,16 @@ int CryptoNative_SslSessionSetHostname(SSL_SESSION* session, const char* hostnam return 0; } +void CryptoNative_SslSessionSetData(SSL_SESSION* session, void* val) +{ + SSL_SESSION_set_ex_data(session, g_ssl_sess_cert_index, val); +} + +void* CryptoNative_SslSessionGetData(SSL_SESSION* session) +{ + return SSL_SESSION_get_ex_data(session, g_ssl_sess_cert_index); +} + int32_t CryptoNative_SslCtxSetEncryptionPolicy(SSL_CTX* ctx, EncryptionPolicy policy) { // No error queue impact. diff --git a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h index 9c9d7026119c5c..8566c7b8ff9b35 100644 --- a/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h +++ b/src/native/libs/System.Security.Cryptography.Native/pal_ssl.h @@ -5,6 +5,10 @@ #include "pal_compiler.h" #include "opensslshim.h" +// index for storing an opaque pointer of used (client) certificate in SSL_SESSION. +// we need dedicated index in order to tell OpenSSL how to copy the pointer during SSL_SESSION_dup. +extern int g_ssl_sess_cert_index; + /* These values should be kept in sync with System.Security.Authentication.SslProtocols. */ @@ -183,6 +187,11 @@ OpenSSL holds reference to it and it must not be freed. */ PALEXPORT const char* CryptoNative_SslGetServerName(SSL* ssl); +/* +Returns session associated with given ssl. +*/ +PALEXPORT SSL_SESSION* CryptoNative_SslGetSession(SSL* ssl); + /* This function will attach existing ssl session for possible TLS resume. */ @@ -326,6 +335,13 @@ Returns the certificate presented by the peer. */ PALEXPORT X509* CryptoNative_SslGetPeerCertificate(SSL* ssl); +/* +Shims the SSL_get_certificate method. + +Returns the certificate representing local peer. +*/ +PALEXPORT X509* CryptoNative_SslGetCertificate(SSL* ssl); + /* Shims the SSL_get_peer_cert_chain method. @@ -450,6 +466,16 @@ Shims the SSL_session_reused macro. */ PALEXPORT int32_t CryptoNative_SslSessionReused(SSL* ssl); +/* +Sets the app data pointer for the given session instance. +*/ +PALEXPORT void CryptoNative_SslSessionSetData(SSL_SESSION* session, void* val); + +/* +Returns the stored application data pointer. +*/ +PALEXPORT void* CryptoNative_SslSessionGetData(SSL_SESSION* session); + /* Adds the given certificate to the extra chain certificates associated with ctx.