diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs index 38b097ebd961bd..21af8954f42587 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs @@ -54,10 +54,10 @@ private void CloseInternal() // Ensure a Read or Auth operation is not in progress, // block potential future read and auth operations since SslStream is disposing. - // This leaves the _nestedRead = 1 and _nestedAuth = 1, but that's ok, since + // This leaves the _nestedRead = 2 and _nestedAuth = 2, but that's ok, since // subsequent operations check the _exception sentinel first - if (Interlocked.Exchange(ref _nestedRead, 1) == 0 && - Interlocked.Exchange(ref _nestedAuth, 1) == 0) + if (Interlocked.Exchange(ref _nestedRead, StreamDisposed) == StreamNotInUse && + Interlocked.Exchange(ref _nestedAuth, StreamDisposed) == StreamNotInUse) { _buffer.ReturnBuffer(); } @@ -162,19 +162,22 @@ private async Task ReplyOnReAuthenticationAsync(byte[]? buffer, Canc private async Task RenegotiateAsync(CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { - if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) + if (Interlocked.CompareExchange(ref _nestedAuth, StreamInUse, StreamNotInUse) != StreamNotInUse) { + ObjectDisposedException.ThrowIf(_nestedAuth == StreamDisposed, this); throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate")); } - if (Interlocked.Exchange(ref _nestedRead, 1) == 1) + if (Interlocked.CompareExchange(ref _nestedRead, StreamInUse, StreamNotInUse) != StreamNotInUse) { + ObjectDisposedException.ThrowIf(_nestedRead == StreamDisposed, this); throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } - if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) + // Write is different since we do not do anything special in Dispose + if (Interlocked.Exchange(ref _nestedWrite, StreamInUse) != StreamNotInUse) { - _nestedRead = 0; + _nestedRead = StreamNotInUse; throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write")); } @@ -231,8 +234,8 @@ private async Task RenegotiateAsync(CancellationToken cancellationTo _buffer.ReturnBuffer(); } - _nestedRead = 0; - _nestedWrite = 0; + _nestedRead = StreamNotInUse; + _nestedWrite = StreamNotInUse; _isRenego = false; // We will not release _nestedAuth at this point to prevent another renegotiation attempt. } @@ -248,7 +251,7 @@ private async Task ForceAuthenticationAsync(bool receiveFirst, byte[ if (reAuthenticationData == null) { // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently. - if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) + if (Interlocked.Exchange(ref _nestedAuth, StreamInUse) == StreamInUse) { throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate")); } @@ -335,7 +338,7 @@ private async Task ForceAuthenticationAsync(bool receiveFirst, byte[ { if (reAuthenticationData == null) { - _nestedAuth = 0; + _nestedAuth = StreamNotInUse; _isRenego = false; } } @@ -494,7 +497,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError { ProcessHandshakeSuccess(); - if (_nestedAuth != 1) + if (_nestedAuth != StreamInUse) { if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"Ignoring unsolicited renegotiated certificate."); // ignore certificates received outside of handshake or requested renegotiation. @@ -763,13 +766,16 @@ private SecurityStatusPal DecryptData(int frameSize) private async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) where TIOAdapter : IReadWriteAdapter { - if (Interlocked.Exchange(ref _nestedRead, 1) == 1) + // Throw first if we already have exception. + // Check for disposal is not atomic so we will check again below. + ThrowIfExceptionalOrNotAuthenticated(); + + if (Interlocked.CompareExchange(ref _nestedRead, StreamInUse, StreamNotInUse) != StreamNotInUse) { + ObjectDisposedException.ThrowIf(_nestedRead == StreamDisposed, this); throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } - ThrowIfExceptionalOrNotAuthenticated(); - try { int processedLength = 0; @@ -904,7 +910,7 @@ private async ValueTask ReadAsyncInternal(Memory buffer, finally { ReturnReadBufferIfEmpty(); - _nestedRead = 0; + _nestedRead = StreamNotInUse; } } @@ -919,7 +925,7 @@ private async ValueTask WriteAsyncInternal(ReadOnlyMemory buff return; } - if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) + if (Interlocked.Exchange(ref _nestedWrite, StreamInUse) == StreamInUse) { throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write")); } @@ -942,7 +948,7 @@ private async ValueTask WriteAsyncInternal(ReadOnlyMemory buff } finally { - _nestedWrite = 0; + _nestedWrite = StreamNotInUse; } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index 5ec9c5fcb1cb7b..e609d3dbdccafd 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -170,6 +170,11 @@ public void ReturnBuffer() } } + // used to track ussage in _nested* variables bellow + private const int StreamNotInUse = 0; + private const int StreamInUse = 1; + private const int StreamDisposed = 2; + private int _nestedWrite; private int _nestedRead; @@ -703,7 +708,7 @@ public override async ValueTask DisposeAsync() public override int ReadByte() { ThrowIfExceptionalOrNotAuthenticated(); - if (Interlocked.Exchange(ref _nestedRead, 1) == 1) + if (Interlocked.Exchange(ref _nestedRead, StreamInUse) == StreamInUse) { throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read")); } @@ -724,7 +729,7 @@ public override int ReadByte() // Regardless of whether we were able to read a byte from the buffer, // reset the read tracking. If we weren't able to read a byte, the // subsequent call to Read will set the flag again. - _nestedRead = 0; + _nestedRead = StreamNotInUse; } // Otherwise, fall back to reading a byte via Read, the same way Stream.ReadByte does. diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs index d3f7f17512f65e..de7aa502933b02 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs @@ -2,8 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; -using System.Net.Test.Common; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -12,13 +12,13 @@ namespace System.Net.Security.Tests { using Configuration = System.Net.Test.Common.Configuration; - public abstract class SslStreamDisposeTest + public class SslStreamDisposeTest { [Fact] public async Task DisposeAsync_NotConnected_ClosesStream() { bool disposed = false; - var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true), false, delegate { return true; }); + var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true, canReadFunc: () => true, canWriteFunc: () => true), false, delegate { return true; }); Assert.False(disposed); await stream.DisposeAsync(); @@ -50,5 +50,57 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout( await serverStream.DisposeAsync(); Assert.NotEqual(0, trackingStream2.TimesCalled(nameof(Stream.DisposeAsync))); } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Dispose_PendingReadAsync_ThrowsODE(bool bufferedRead) + { + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.CancelAfter(TestConfiguration.PassingTestTimeout); + + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(leaveInnerStreamOpen: true); + using (client) + using (server) + using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate()) + using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate()) + { + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions() + { + TargetHost = Guid.NewGuid().ToString("N"), + }; + clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true; + + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() + { + ServerCertificate = serverCertificate, + }; + + await TestConfiguration.WhenAllOrAnyFailedWithTimeout( + client.AuthenticateAsClientAsync(clientOptions, default), + server.AuthenticateAsServerAsync(serverOptions, default)); + + await TestHelper.PingPong(client, server, cts.Token); + + await server.WriteAsync("PINGPONG"u8.ToArray(), cts.Token); + var readBuffer = new byte[1024]; + + Task? task = null; + if (bufferedRead) + { + // This will read everything into internal buffer. Following ReadAsync will not need IO. + task = client.ReadAsync(readBuffer, 0, 4, cts.Token); + client.Dispose(); + int readLength = await task.ConfigureAwait(false); + Assert.Equal(4, readLength); + } + else + { + client.Dispose(); + } + + await Assert.ThrowsAnyAsync(() => client.ReadAsync(readBuffer, cts.Token).AsTask()); + } + } } } diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs index aa2e8e46f23d9f..fd0b91f07ea77b 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs @@ -51,10 +51,10 @@ public static bool AllowAnyServerCertificate(object sender, X509Certificate cert return true; } - public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams() + public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams(bool leaveInnerStreamOpen = false) { (Stream clientStream, Stream serverStream) = GetConnectedStreams(); - return (new SslStream(clientStream), new SslStream(serverStream)); + return (new SslStream(clientStream, leaveInnerStreamOpen), new SslStream(serverStream, leaveInnerStreamOpen)); } public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()