diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs index 3f13b716cd4c7c..32a02d94d45391 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs @@ -50,15 +50,15 @@ public NetworkStream(Socket socket, FileAccess access, bool ownsSocket) // allowing non-blocking sockets could result in non-deterministic failures from those // operations. A developer that requires using NetworkStream with a non-blocking socket can // temporarily flip Socket.Blocking as a workaround. - throw GetCustomException(SR.net_sockets_blocking); + throw new IOException(SR.net_sockets_blocking); } if (!socket.Connected) { - throw GetCustomException(SR.net_notconnected); + throw new IOException(SR.net_notconnected); } if (socket.SocketType != SocketType.Stream) { - throw GetCustomException(SR.net_notstream); + throw new IOException(SR.net_notstream); } _streamSocket = socket; @@ -227,13 +227,9 @@ public override int Read(byte[] buffer, int offset, int count) { return _streamSocket.Receive(buffer, offset, count, 0); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + throw WrapException(SR.net_io_readfailure, exception); } } @@ -250,23 +246,14 @@ public override int Read(Span buffer) ThrowIfDisposed(); if (!CanRead) throw new InvalidOperationException(SR.net_writeonlystream); - int bytesRead; - SocketError errorCode; try { - bytesRead = _streamSocket.Receive(buffer, SocketFlags.None, out errorCode); + return _streamSocket.Receive(buffer, SocketFlags.None); } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); - } - - if (errorCode != SocketError.Success) - { - var socketException = new SocketException((int)errorCode); - throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException); + throw WrapException(SR.net_io_readfailure, exception); } - return bytesRead; } public override unsafe int ReadByte() @@ -306,13 +293,9 @@ public override void Write(byte[] buffer, int offset, int count) // after ALL the requested number of bytes was transferred. _streamSocket.Send(buffer, offset, count, SocketFlags.None); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + throw WrapException(SR.net_io_writefailure, exception); } } @@ -330,20 +313,13 @@ public override void Write(ReadOnlySpan buffer) ThrowIfDisposed(); if (!CanWrite) throw new InvalidOperationException(SR.net_readonlystream); - SocketError errorCode; try { - _streamSocket.Send(buffer, SocketFlags.None, out errorCode); + _streamSocket.Send(buffer, SocketFlags.None); } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); - } - - if (errorCode != SocketError.Success) - { - var socketException = new SocketException((int)errorCode); - throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException); + throw WrapException(SR.net_io_writefailure, exception); } } @@ -424,13 +400,9 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, Asy callback, state); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + throw WrapException(SR.net_io_readfailure, exception); } } @@ -456,13 +428,9 @@ public override int EndRead(IAsyncResult asyncResult) { return _streamSocket.EndReceive(asyncResult); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + throw WrapException(SR.net_io_readfailure, exception); } } @@ -500,13 +468,9 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As callback, state); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + throw WrapException(SR.net_io_writefailure, exception); } } @@ -528,13 +492,9 @@ public override void EndWrite(IAsyncResult asyncResult) { _streamSocket.EndSend(asyncResult); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + throw WrapException(SR.net_io_writefailure, exception); } } @@ -570,13 +530,9 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel fromNetworkStream: true, cancellationToken).AsTask(); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + throw WrapException(SR.net_io_readfailure, exception); } } @@ -597,13 +553,9 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken fromNetworkStream: true, cancellationToken: cancellationToken); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_readfailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + throw WrapException(SR.net_io_readfailure, exception); } } @@ -638,13 +590,9 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati SocketFlags.None, cancellationToken).AsTask(); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + throw WrapException(SR.net_io_writefailure, exception); } } @@ -664,13 +612,9 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo SocketFlags.None, cancellationToken); } - catch (SocketException socketException) - { - throw GetExceptionFromSocketException(SR.Format(SR.net_io_writefailure, socketException.Message), socketException); - } catch (Exception exception) when (!(exception is OutOfMemoryException)) { - throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + throw WrapException(SR.net_io_writefailure, exception); } } @@ -728,14 +672,9 @@ private void ThrowIfDisposed() void ThrowObjectDisposedException() => throw new ObjectDisposedException(GetType().FullName); } - private static IOException GetExceptionFromSocketException(string message, SocketException innerException) - { - return new IOException(message, innerException); - } - - private static IOException GetCustomException(string message, Exception? innerException = null) + private static IOException WrapException(string resourceFormatString, Exception innerException) { - return new IOException(message, innerException); + return new IOException(SR.Format(resourceFormatString, innerException.Message), innerException); } } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs index 4f03d4ab4c3ea0..05f7f9b8c1d6d2 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs @@ -321,21 +321,28 @@ public async Task DisposeSocketDirectly_ReadWriteThrowNetworkException(bool deri Task acceptTask = listener.AcceptAsync(); await Task.WhenAll(acceptTask, client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port))); using Socket serverSocket = await acceptTask; + using NetworkStream server = derivedNetworkStream ? (NetworkStream)new DerivedNetworkStream(serverSocket) : new NetworkStream(serverSocket); - + serverSocket.Dispose(); - Assert.Throws(() => server.Read(new byte[1], 0, 1)); - Assert.Throws(() => server.Write(new byte[1], 0, 1)); + ExpectIOException(() => server.Read(new byte[1], 0, 1)); + ExpectIOException(() => server.Write(new byte[1], 0, 1)); - Assert.Throws(() => server.Read((Span)new byte[1])); - Assert.Throws(() => server.Write((ReadOnlySpan)new byte[1])); + ExpectIOException(() => server.Read((Span)new byte[1])); + ExpectIOException(() => server.Write((ReadOnlySpan)new byte[1])); - Assert.Throws(() => server.BeginRead(new byte[1], 0, 1, null, null)); - Assert.Throws(() => server.BeginWrite(new byte[1], 0, 1, null, null)); + ExpectIOException(() => server.BeginRead(new byte[1], 0, 1, null, null)); + ExpectIOException(() => server.BeginWrite(new byte[1], 0, 1, null, null)); - Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 1); }); - Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 1); }); + ExpectIOException(() => { _ = server.ReadAsync(new byte[1], 0, 1); }); + ExpectIOException(() => { _ = server.WriteAsync(new byte[1], 0, 1); }); + } + + static void ExpectIOException(Action action) + { + IOException ex = Assert.Throws(action); + Assert.IsType(ex.InnerException); } }