diff --git a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/Constants.cs b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/Constants.cs index 748dff5aa20..433d6faa3ea 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/Constants.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/Constants.cs @@ -7,5 +7,5 @@ internal static class Constants { public const string NoContent = "[no-content-type]"; public const string UnreadableContent = "[unreadable-content-type]"; - public const string ReadCancelled = "[read-cancelled]"; + public const string ReadCancelledByTimeout = "[read-timeout]"; } diff --git a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpRequestBodyReader.cs b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpRequestBodyReader.cs index 38ed1c57378..ed5a3c3f33d 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpRequestBodyReader.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpRequestBodyReader.cs @@ -79,7 +79,7 @@ private static async ValueTask ReadFromStreamWithTimeoutAsync(HttpReques // when readTimeout occurred: catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) { - return Constants.ReadCancelled; + return Constants.ReadCancelledByTimeout; } } diff --git a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpResponseBodyReader.cs b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpResponseBodyReader.cs index 9235603767d..0c5b6a672b1 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpResponseBodyReader.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Logging/Internal/HttpResponseBodyReader.cs @@ -3,15 +3,15 @@ using System; using System.Collections.Frozen; +using System.Collections.Generic; using System.IO; +using System.IO.Pipelines; using System.Net.Http; +using System.Net.Http.Headers; using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.ObjectPool; -using Microsoft.IO; using Microsoft.Shared.Diagnostics; -using Microsoft.Shared.Pools; namespace Microsoft.Extensions.Http.Logging.Internal; @@ -22,15 +22,18 @@ internal sealed class HttpResponseBodyReader /// internal readonly TimeSpan ResponseReadTimeout; - private static readonly ObjectPool> _bufferWriterPool = BufferWriterPool.SharedBufferWriterPool; + // The chunk size of 8192 bytes (8 KB) is chosen as a balance between memory usage and performance. + // It is large enough to efficiently handle typical HTTP response sizes without excessive memory allocation, + // while still being small enough to avoid large object heap allocations and reduce memory fragmentation. + private const int ChunkSize = 8 * 1024; + private readonly FrozenSet _readableResponseContentTypes; private readonly int _responseReadLimit; - private readonly RecyclableMemoryStreamManager _streamManager; - public HttpResponseBodyReader(LoggingOptions responseOptions, IDebuggerState? debugger = null) { - _streamManager = new RecyclableMemoryStreamManager(); + _ = Throw.IfNull(responseOptions); + _readableResponseContentTypes = responseOptions.ResponseBodyContentTypes.ToFrozenSet(StringComparer.OrdinalIgnoreCase); _responseReadLimit = responseOptions.BodySizeLimit; @@ -43,7 +46,7 @@ public HttpResponseBodyReader(LoggingOptions responseOptions, IDebuggerState? de public ValueTask ReadAsync(HttpResponseMessage response, CancellationToken cancellationToken) { - var contentType = response.Content.Headers.ContentType; + MediaTypeHeaderValue? contentType = response.Content.Headers.ContentType; if (contentType == null) { return new(Constants.NoContent); @@ -54,90 +57,186 @@ public ValueTask ReadAsync(HttpResponseMessage response, CancellationTok return new(Constants.UnreadableContent); } - return ReadFromStreamWithTimeoutAsync(response, ResponseReadTimeout, _responseReadLimit, _streamManager, - cancellationToken).Preserve(); + return ReadFromStreamWithTimeoutAsync(response, ResponseReadTimeout, _responseReadLimit, cancellationToken).Preserve(); } - private static async ValueTask ReadFromStreamAsync(HttpResponseMessage response, int readSizeLimit, - RecyclableMemoryStreamManager streamManager, CancellationToken cancellationToken) + private static async ValueTask ReadFromStreamWithTimeoutAsync(HttpResponseMessage response, TimeSpan readTimeout, int readSizeLimit, CancellationToken cancellationToken) { -#if NET5_0_OR_GREATER - var streamToReadFrom = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); -#else - var streamToReadFrom = await response.Content.ReadAsStreamAsync().WaitAsync(cancellationToken).ConfigureAwait(false); -#endif + using var joinedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + joinedTokenSource.CancelAfter(readTimeout); + + // TimeSpan.Zero cannot be set from user's code as + // validation prevents values less than one millisecond + // However, this is useful during unit tests + if (readTimeout <= TimeSpan.Zero) + { + // cancel immediately, async cancel not required in tests +#pragma warning disable CA1849 // Call async methods when in an async method + joinedTokenSource.Cancel(); +#pragma warning restore CA1849 // Call async methods when in an async method + } - var bufferWriter = _bufferWriterPool.Get(); - var memory = bufferWriter.GetMemory(readSizeLimit).Slice(0, readSizeLimit); -#if !NETCOREAPP3_1_OR_GREATER - byte[] buffer = memory.ToArray(); -#endif try { -#if NETCOREAPP3_1_OR_GREATER - var charsWritten = await streamToReadFrom.ReadAsync(memory, cancellationToken).ConfigureAwait(false); - bufferWriter.Advance(charsWritten); - return Encoding.UTF8.GetString(memory.Slice(0, charsWritten).Span); + return await ReadFromStreamAsync(response, readSizeLimit, joinedTokenSource.Token).ConfigureAwait(false); + } + + // when readTimeout occurred: joined token source is cancelled and cancellationToken is not + catch (OperationCanceledException) when (joinedTokenSource.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + return Constants.ReadCancelledByTimeout; + } + } + + private static async ValueTask ReadFromStreamAsync(HttpResponseMessage response, int readSizeLimit, CancellationToken cancellationToken) + { +#if NET6_0_OR_GREATER + Stream streamToReadFrom = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); #else - var charsWritten = await streamToReadFrom.ReadAsync(buffer, 0, readSizeLimit, cancellationToken).ConfigureAwait(false); - bufferWriter.Advance(charsWritten); - return Encoding.UTF8.GetString(buffer.AsMemory(0, charsWritten).ToArray()); + Stream streamToReadFrom = await response.Content.ReadAsStreamAsync().WaitAsync(cancellationToken).ConfigureAwait(false); #endif + + var pipe = new Pipe(); + + string bufferedString = await BufferStreamAndWriteToPipeAsync(streamToReadFrom, pipe.Writer, readSizeLimit, cancellationToken).ConfigureAwait(false); + + // if stream is seekable we can just rewind it and return the buffered string + if (streamToReadFrom.CanSeek) + { + streamToReadFrom.Seek(0, SeekOrigin.Begin); + + await pipe.Reader.CompleteAsync().ConfigureAwait(false); + + return bufferedString; } - finally + + // if stream is not seekable we need to write the rest of the stream to the pipe + // and create a new response content with the pipe reader as stream + _ = Task.Run(async () => { - if (streamToReadFrom.CanSeek) + await WriteStreamToPipeAsync(streamToReadFrom, pipe.Writer, cancellationToken).ConfigureAwait(false); + }, CancellationToken.None); + + // use the pipe reader as stream for the new content + var newContent = new StreamContent(pipe.Reader.AsStream()); + foreach (KeyValuePair> header in response.Content.Headers) + { + _ = newContent.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + + response.Content = newContent; + + return bufferedString; + } + +#if NET6_0_OR_GREATER + private static async Task BufferStreamAndWriteToPipeAsync(Stream stream, PipeWriter writer, int bufferSize, CancellationToken cancellationToken) + { + Memory memory = writer.GetMemory(bufferSize)[..bufferSize]; + +#if NET8_0_OR_GREATER + int bytesRead = await stream.ReadAtLeastAsync(memory, bufferSize, false, cancellationToken).ConfigureAwait(false); +#else + int bytesRead = 0; + while (bytesRead < bufferSize) + { + int read = await stream.ReadAsync(memory.Slice(bytesRead), cancellationToken).ConfigureAwait(false); + if (read == 0) { - streamToReadFrom.Seek(0, SeekOrigin.Begin); + break; } - else - { - var freshStream = streamManager.GetStream(); -#if NETCOREAPP3_1_OR_GREATER - var remainingSpace = memory.Slice(bufferWriter.WrittenCount, memory.Length - bufferWriter.WrittenCount); - var writtenCount = await streamToReadFrom.ReadAsync(remainingSpace, cancellationToken) - .ConfigureAwait(false); - - await freshStream.WriteAsync(memory.Slice(0, writtenCount + bufferWriter.WrittenCount), cancellationToken) - .ConfigureAwait(false); -#else - var writtenCount = await streamToReadFrom.ReadAsync(buffer, bufferWriter.WrittenCount, - buffer.Length - bufferWriter.WrittenCount, cancellationToken).ConfigureAwait(false); - await freshStream.WriteAsync(buffer, 0, writtenCount + bufferWriter.WrittenCount, cancellationToken).ConfigureAwait(false); + bytesRead += read; + } #endif - freshStream.Seek(0, SeekOrigin.Begin); - var newContent = new StreamContent(freshStream); + if (bytesRead == 0) + { + return string.Empty; + } + + writer.Advance(bytesRead); + + return Encoding.UTF8.GetString(memory[..bytesRead].Span); + } - foreach (var header in response.Content.Headers) - { - _ = newContent.Headers.TryAddWithoutValidation(header.Key, header.Value); - } + private static async Task WriteStreamToPipeAsync(Stream stream, PipeWriter writer, CancellationToken cancellationToken) + { + while (true) + { + Memory memory = writer.GetMemory(ChunkSize)[..ChunkSize]; - response.Content = newContent; + int bytesRead = await stream.ReadAsync(memory, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + break; } - _bufferWriterPool.Return(bufferWriter); + writer.Advance(bytesRead); + + FlushResult result = await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + if (result.IsCompleted) + { + break; + } } - } - private static async ValueTask ReadFromStreamWithTimeoutAsync(HttpResponseMessage response, TimeSpan readTimeout, - int readSizeLimit, RecyclableMemoryStreamManager streamManager, CancellationToken cancellationToken) + await writer.CompleteAsync().ConfigureAwait(false); + } +#else + private static async Task BufferStreamAndWriteToPipeAsync(Stream stream, PipeWriter writer, int bufferSize, CancellationToken cancellationToken) { - using var joinedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - joinedTokenSource.CancelAfter(readTimeout); + var sb = new StringBuilder(); - try + int bytesRead = 0; + + while (bytesRead < bufferSize) { - return await ReadFromStreamAsync(response, readSizeLimit, streamManager, joinedTokenSource.Token) - .ConfigureAwait(false); + int chunkSize = Math.Min(ChunkSize, bufferSize - bytesRead); + + Memory memory = writer.GetMemory(chunkSize).Slice(0, chunkSize); + + byte[] buffer = memory.ToArray(); + + int read = await stream.ReadAsync(buffer, 0, chunkSize, cancellationToken).ConfigureAwait(false); + if (read == 0) + { + break; + } + + bytesRead += read; + + buffer.CopyTo(memory); + + writer.Advance(read); + + _ = sb.Append(Encoding.UTF8.GetString(buffer.AsMemory(0, read).ToArray())); } - // when readTimeout occurred: - catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + return sb.ToString(); + } + + private static async Task WriteStreamToPipeAsync(Stream stream, PipeWriter writer, CancellationToken cancellationToken) + { + while (true) { - return Constants.ReadCancelled; + Memory memory = writer.GetMemory(ChunkSize).Slice(0, ChunkSize); + byte[] buffer = memory.ToArray(); + + int bytesRead = await stream.ReadAsync(buffer, 0, ChunkSize, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + break; + } + + FlushResult result = await writer.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false); + if (result.IsCompleted) + { + break; + } } + + await writer.CompleteAsync().ConfigureAwait(false); } +#endif } diff --git a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Microsoft.Extensions.Http.Diagnostics.csproj b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Microsoft.Extensions.Http.Diagnostics.csproj index f6c98baefce..cc5de094e47 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Microsoft.Extensions.Http.Diagnostics.csproj +++ b/src/Libraries/Microsoft.Extensions.Http.Diagnostics/Microsoft.Extensions.Http.Diagnostics.csproj @@ -38,7 +38,7 @@ - + diff --git a/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/AcceptanceTests.cs b/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/AcceptanceTests.cs index 9ae4ee7bd88..3143aab9185 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/AcceptanceTests.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/AcceptanceTests.cs @@ -171,9 +171,9 @@ public async Task AddHttpClientLogging_WithNamedHttpClients_WorksCorrectly() var collector = provider.GetFakeLogCollector(); var logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory); var state = logRecord.StructuredState; - state.Should().Contain(kvp => kvp.Value == responseString); - state.Should().Contain(kvp => kvp.Value == "Request Value"); - state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3"); + state.Should().ContainValue(responseString); + state.Should().ContainValue("Request Value"); + state.Should().ContainValue("Request Value 2,Request Value 3"); using var httpRequestMessage2 = new HttpRequestMessage { @@ -187,9 +187,9 @@ public async Task AddHttpClientLogging_WithNamedHttpClients_WorksCorrectly() responseString = await SendRequest(namedClient2, httpRequestMessage2); logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory); state = logRecord.StructuredState; - state.Should().Contain(kvp => kvp.Value == responseString); - state.Should().Contain(kvp => kvp.Value == "Request Value"); - state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3"); + state.Should().ContainValue(responseString); + state.Should().ContainValue("Request Value"); + state.Should().ContainValue("Request Value 2,Request Value 3"); } private static async Task SendRequest(HttpClient httpClient, HttpRequestMessage httpRequestMessage) @@ -258,9 +258,9 @@ public async Task AddHttpClientLogging_WithTypedHttpClients_WorksCorrectly() var logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory); var state = logRecord.StructuredState; state.Should().NotBeNull(); - state.Should().Contain(kvp => kvp.Value == responseString); - state.Should().Contain(kvp => kvp.Value == "Request Value"); - state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3"); + state.Should().ContainValue(responseString); + state.Should().ContainValue("Request Value"); + state.Should().ContainValue("Request Value 2,Request Value 3"); using var httpRequestMessage2 = new HttpRequestMessage { @@ -279,9 +279,9 @@ public async Task AddHttpClientLogging_WithTypedHttpClients_WorksCorrectly() logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory); state = logRecord.StructuredState; - state.Should().Contain(kvp => kvp.Value == responseString); - state.Should().Contain(kvp => kvp.Value == "Request Value"); - state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3"); + state.Should().ContainValue(responseString); + state.Should().ContainValue("Request Value"); + state.Should().ContainValue("Request Value 2,Request Value 3"); } [Theory] @@ -654,6 +654,8 @@ public async Task AddDefaultHttpClientLogging_DisablesNetScope() [InlineData(315_883)] public async Task HttpClientLoggingHandler_LogsBodyDataUpToSpecifiedLimit(int limit) { + const int LengthOfContentInTextFile = 64_751; + await using var provider = new ServiceCollection() .AddFakeLogging() .AddFakeRedaction() @@ -686,17 +688,18 @@ public async Task HttpClientLoggingHandler_LogsBodyDataUpToSpecifiedLimit(int li httpRequestMessage.Headers.Add("ReQuEStHeAdEr2", new List { "Request Value 2", "Request Value 3" }); var content = await client.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead); - var responseStream = await content.Content.ReadAsStreamAsync(); - var length = (int)responseStream.Length > limit ? limit : (int)responseStream.Length; - var buffer = new byte[length]; - _ = await responseStream.ReadAsync(buffer, 0, length); - var responseString = Encoding.UTF8.GetString(buffer); + var responseString = await content.Content.ReadAsStringAsync(); + var length = Math.Min(limit, responseString.Length); + var loggedBodyString = responseString.Substring(0, length); + + // length of the content in the Text.txt file + responseString.Length.Should().Be(LengthOfContentInTextFile); var collector = provider.GetFakeLogCollector(); var logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory); var state = logRecord.StructuredState; - state.Should().Contain(kvp => kvp.Value == responseString); - state.Should().Contain(kvp => kvp.Value == "Request Value"); - state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3"); + state.Should().ContainValue(loggedBodyString); + state.Should().ContainValue("Request Value"); + state.Should().ContainValue("Request Value 2,Request Value 3"); } } diff --git a/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpRequestBodyReaderTest.cs b/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpRequestBodyReaderTest.cs index 9282e9a4838..f95d16f2afc 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpRequestBodyReaderTest.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpRequestBodyReaderTest.cs @@ -193,7 +193,7 @@ public async Task Reader_ReadingTakesTooLong_Timesout() var requestBody = await httpRequestBodyReader.ReadAsync(httpRequest, CancellationToken.None); var returnedValue = requestBody; - var expectedValue = Constants.ReadCancelled; + var expectedValue = Constants.ReadCancelledByTimeout; returnedValue.Should().BeEquivalentTo(expectedValue); } diff --git a/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpResponseBodyReaderTest.cs b/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpResponseBodyReaderTest.cs index c23568ddf80..ec78df392fc 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpResponseBodyReaderTest.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Diagnostics.Tests/Logging/HttpResponseBodyReaderTest.cs @@ -20,6 +20,7 @@ namespace Microsoft.Extensions.Http.Logging.Test; public class HttpResponseBodyReaderTest { + private const string TextPlain = "text/plain"; private readonly Fixture _fixture; public HttpResponseBodyReaderTest() @@ -27,19 +28,26 @@ public HttpResponseBodyReaderTest() _fixture = new Fixture(); } + [Fact] + public void Reader_NullOptions_Throws() + { + var act = () => new HttpResponseBodyReader(null!); + act.Should().Throw(); + } + [Fact] public async Task Reader_SimpleContent_ReadsContent() { var options = new LoggingOptions { - ResponseBodyContentTypes = new HashSet { "text/plain" } + ResponseBodyContentTypes = new HashSet { TextPlain } }; var httpResponseBodyReader = new HttpResponseBodyReader(options); var expectedContentBody = _fixture.Create(); using var httpResponse = new HttpResponseMessage { - Content = new StringContent(expectedContentBody, Encoding.UTF8, "text/plain") + Content = new StringContent(expectedContentBody, Encoding.UTF8, TextPlain) }; var responseBody = await httpResponseBodyReader.ReadAsync(httpResponse, CancellationToken.None); @@ -48,11 +56,11 @@ public async Task Reader_SimpleContent_ReadsContent() } [Fact] - public async Task Reader_EmptyContent_ErrorMessage() + public async Task Reader_NoContentType_ErrorMessage() { var options = new LoggingOptions { - ResponseBodyContentTypes = new HashSet { "text/plain" } + ResponseBodyContentTypes = new HashSet { TextPlain } }; using var httpResponse = new HttpResponseMessage @@ -66,6 +74,24 @@ public async Task Reader_EmptyContent_ErrorMessage() responseBody.Should().Be(Constants.NoContent); } + [Fact] + public async Task Reader_EmptyContent_ReturnsEmptyString() + { + var options = new LoggingOptions + { + ResponseBodyContentTypes = new HashSet { TextPlain } + }; + using var httpResponse = new HttpResponseMessage + { + Content = new StringContent(string.Empty, Encoding.UTF8, TextPlain) + }; + + var httpResponseBodyReader = new HttpResponseBodyReader(options); + var responseBody = await httpResponseBodyReader.ReadAsync(httpResponse, CancellationToken.None); + + responseBody.Should().BeEmpty(); + } + [Theory] [CombinatorialData] public async Task Reader_UnreadableContent_ErrorMessage( @@ -75,7 +101,7 @@ public async Task Reader_UnreadableContent_ErrorMessage( { var options = new LoggingOptions { - ResponseBodyContentTypes = new HashSet { "text/plain" } + ResponseBodyContentTypes = new HashSet { TextPlain } }; var httpResponseBodyReader = new HttpResponseBodyReader(options); @@ -95,14 +121,14 @@ public async Task Reader_OperationCanceled_ThrowsTaskCanceledException() { var options = new LoggingOptions { - ResponseBodyContentTypes = new HashSet { "text/plain" } + ResponseBodyContentTypes = new HashSet { TextPlain } }; var httpResponseBodyReader = new HttpResponseBodyReader(options); var input = _fixture.Create(); using var httpResponse = new HttpResponseMessage { - Content = new StringContent(input, Encoding.UTF8, "text/plain") + Content = new StringContent(input, Encoding.UTF8, TextPlain) }; var token = new CancellationToken(true); @@ -119,19 +145,60 @@ public async Task Reader_BigContent_TrimsAtTheEnd([CombinatorialValues(32, 256, var options = new LoggingOptions { BodySizeLimit = limit, - ResponseBodyContentTypes = new HashSet { "text/plain" } + ResponseBodyContentTypes = new HashSet { TextPlain } }; var httpResponseBodyReader = new HttpResponseBodyReader(options); var bigContent = RandomStringGenerator.Generate(limit * 2); using var httpResponse = new HttpResponseMessage { - Content = new StringContent(bigContent, Encoding.UTF8, "text/plain") + Content = new StreamContent(new NotSeekableStream(new(Encoding.UTF8.GetBytes(bigContent)))) }; + httpResponse.Content.Headers.Add("Content-Type", TextPlain); var responseBody = await httpResponseBodyReader.ReadAsync(httpResponse, CancellationToken.None); responseBody.Should().Be(bigContent.Substring(0, limit)); + + // This should read from piped stream + var response = await httpResponse.Content.ReadAsStringAsync(); + + response.Should().Be(bigContent); + } + + [Fact] + public async Task Reader_ReaderCancelledAfterBuffering_ShouldCancelPipeReader() + { + const int BodySize = 10_000_000; + var options = new LoggingOptions + { + BodySizeLimit = 1, + ResponseBodyContentTypes = new HashSet { TextPlain } + }; + var httpResponseBodyReader = new HttpResponseBodyReader(options); + var bigContent = RandomStringGenerator.Generate(BodySize); + using var httpResponse = new HttpResponseMessage + { + Content = new StreamContent(new NotSeekableStream(new(Encoding.UTF8.GetBytes(bigContent)))) + }; + httpResponse.Content.Headers.Add("Content-Type", TextPlain); + + using var cts = new CancellationTokenSource(); + + var responseBody = await httpResponseBodyReader.ReadAsync(httpResponse, cts.Token); + + responseBody.Should().HaveLength(1); + + // This should read from piped stream + var responseStream = await httpResponse.Content.ReadAsStreamAsync(); + + var buffer = new byte[BodySize]; + + cts.Cancel(false); + + var act = async () => await responseStream.ReadAsync(buffer, 0, BodySize, cts.Token); + + await act.Should().ThrowAsync().Where(e => e.CancellationToken.IsCancellationRequested); } [Fact] @@ -139,12 +206,13 @@ public async Task Reader_ReadingTakesTooLong_TimesOut() { var options = new LoggingOptions { - ResponseBodyContentTypes = new HashSet { "text/plain" } + ResponseBodyContentTypes = new HashSet { TextPlain }, + BodyReadTimeout = TimeSpan.Zero }; var httpResponseBodyReader = new HttpResponseBodyReader(options); var streamMock = new Mock(); -#if NETCOREAPP3_1_OR_GREATER +#if NET6_0_OR_GREATER streamMock.Setup(x => x.ReadAsync(It.IsAny>(), It.IsAny())).Throws(); #else streamMock.Setup(x => x.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Throws(); @@ -154,11 +222,39 @@ public async Task Reader_ReadingTakesTooLong_TimesOut() Content = new StreamContent(streamMock.Object) }; - httpResponse.Content.Headers.Add("Content-type", "text/plain"); + httpResponse.Content.Headers.Add("Content-type", TextPlain); + + var responseBody = await httpResponseBodyReader.ReadAsync(httpResponse, CancellationToken.None); - var requestBody = await httpResponseBodyReader.ReadAsync(httpResponse, CancellationToken.None); + responseBody.Should().Be(Constants.ReadCancelledByTimeout); + } + + [Fact] + public async Task Reader_ReadingTakesTooLongAndOperationCancelled_Throws() + { + var options = new LoggingOptions + { + ResponseBodyContentTypes = new HashSet { TextPlain }, + BodyReadTimeout = TimeSpan.Zero + }; + var httpResponseBodyReader = new HttpResponseBodyReader(options); + var streamMock = new Mock(); + var token = new CancellationToken(true); + var exception = new OperationCanceledException(token); +#if NET6_0_OR_GREATER + streamMock.Setup(x => x.ReadAsync(It.IsAny>(), It.IsAny())).Throws(exception); +#else + streamMock.Setup(x => x.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Throws(exception); +#endif + using var httpResponse = new HttpResponseMessage + { + Content = new StreamContent(streamMock.Object) + }; + httpResponse.Content.Headers.Add("Content-type", TextPlain); + + var act = async () => await httpResponseBodyReader.ReadAsync(httpResponse, token); - requestBody.Should().Be(Constants.ReadCancelled); + await act.Should().ThrowAsync().Where(e => e.CancellationToken.IsCancellationRequested); } [Fact]