Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
using System;
using System.Collections.Frozen;
using System.IO;
using System.IO.Pipelines;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.ObjectPool;
using Microsoft.IO;
using Microsoft.Shared.Diagnostics;
using Microsoft.Shared.Pools;

namespace Microsoft.Extensions.Http.Logging.Internal;

Expand All @@ -22,15 +20,14 @@ internal sealed class HttpResponseBodyReader
/// </summary>
internal readonly TimeSpan ResponseReadTimeout;

private static readonly ObjectPool<BufferWriter<byte>> _bufferWriterPool = BufferWriterPool.SharedBufferWriterPool;
private const int ChunkSize = 8 * 1024;
private readonly FrozenSet<string> _readableResponseContentTypes;
private readonly int _responseReadLimit;

private readonly RecyclableMemoryStreamManager _streamManager;

public HttpResponseBodyReader(LoggingOptions responseOptions, IDebuggerState? debugger = null)
{
_streamManager = new RecyclableMemoryStreamManager();
_ = Throw.IfNull(responseOptions);

_readableResponseContentTypes = responseOptions.ResponseBodyContentTypes.ToFrozenSet(StringComparer.OrdinalIgnoreCase);
_responseReadLimit = responseOptions.BodySizeLimit;

Expand All @@ -54,90 +51,186 @@ public ValueTask<string> ReadAsync(HttpResponseMessage response, CancellationTok
return new(Constants.UnreadableContent);
}

return ReadFromStreamWithTimeoutAsync(response, ResponseReadTimeout, _responseReadLimit, _streamManager,
cancellationToken).Preserve();
return ReadFromStreamWithTimeoutAsync(response, ResponseReadTimeout, _responseReadLimit, cancellationToken).Preserve();
}

private static async ValueTask<string> ReadFromStreamWithTimeoutAsync(HttpResponseMessage response, TimeSpan readTimeout, int readSizeLimit, CancellationToken cancellationToken)
{
using var joinedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
joinedTokenSource.CancelAfter(readTimeout);

// TimeSpan.Zero cannot be set from user's code as
// validation prevents values less than one millisecond
// However, this is useful during unit tests
if (readTimeout <= TimeSpan.Zero)
{
// cancel immediately, async cancel not required in tests
#pragma warning disable CA1849 // Call async methods when in an async method
joinedTokenSource.Cancel();
#pragma warning restore CA1849 // Call async methods when in an async method
}

try
{
return await ReadFromStreamAsync(response, readSizeLimit, joinedTokenSource.Token).ConfigureAwait(false);
}

// when readTimeout occurred: joined token source is cancelled and cancellationToken is not
catch (OperationCanceledException) when (joinedTokenSource.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
{
return Constants.ReadCancelled;
}
}

private static async ValueTask<string> ReadFromStreamAsync(HttpResponseMessage response, int readSizeLimit,
RecyclableMemoryStreamManager streamManager, CancellationToken cancellationToken)
private static async ValueTask<string> ReadFromStreamAsync(HttpResponseMessage response, int readSizeLimit, CancellationToken cancellationToken)
{
#if NET5_0_OR_GREATER
#if NET6_0_OR_GREATER
var streamToReadFrom = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
#else
var streamToReadFrom = await response.Content.ReadAsStreamAsync().WaitAsync(cancellationToken).ConfigureAwait(false);
#endif

var bufferWriter = _bufferWriterPool.Get();
var memory = bufferWriter.GetMemory(readSizeLimit).Slice(0, readSizeLimit);
#if !NETCOREAPP3_1_OR_GREATER
byte[] buffer = memory.ToArray();
#endif
try
var pipe = new Pipe();

var bufferedString = await BufferStreamAndWriteToPipeAsync(streamToReadFrom, pipe.Writer, readSizeLimit, cancellationToken).ConfigureAwait(false);

// if stream is seekable we can just rewind it and return the buffered string
if (streamToReadFrom.CanSeek)
{
streamToReadFrom.Seek(0, SeekOrigin.Begin);

await pipe.Reader.CompleteAsync().ConfigureAwait(false);

return bufferedString;
}

// if stream is not seekable we need to write the rest of the stream to the pipe
// and create a new response content with the pipe reader as stream
_ = Task.Run(async () =>
{
#if NETCOREAPP3_1_OR_GREATER
var charsWritten = await streamToReadFrom.ReadAsync(memory, cancellationToken).ConfigureAwait(false);
bufferWriter.Advance(charsWritten);
return Encoding.UTF8.GetString(memory.Slice(0, charsWritten).Span);
await WriteStreamToPipeAsync(streamToReadFrom, pipe.Writer, cancellationToken).ConfigureAwait(false);
}, CancellationToken.None).ConfigureAwait(false);

// use the pipe reader as stream for the new content
var newContent = new StreamContent(pipe.Reader.AsStream());
foreach (var header in response.Content.Headers)
{
_ = newContent.Headers.TryAddWithoutValidation(header.Key, header.Value);
}

response.Content = newContent;

return bufferedString;
}

#if NET6_0_OR_GREATER
private static async Task<string> BufferStreamAndWriteToPipeAsync(Stream stream, PipeWriter writer, int bufferSize, CancellationToken cancellationToken)
{
var memory = writer.GetMemory(bufferSize)[..bufferSize];

#if NET8_0_OR_GREATER
int bytesRead = await stream.ReadAtLeastAsync(memory, bufferSize, false, cancellationToken).ConfigureAwait(false);
#else
var charsWritten = await streamToReadFrom.ReadAsync(buffer, 0, readSizeLimit, cancellationToken).ConfigureAwait(false);
bufferWriter.Advance(charsWritten);
return Encoding.UTF8.GetString(buffer.AsMemory(0, charsWritten).ToArray());
int bytesRead = 0;
while (bytesRead < bufferSize)
{
int read = await stream.ReadAsync(memory.Slice(bytesRead), cancellationToken).ConfigureAwait(false);
if (read == 0)
{
break;
}

bytesRead += read;
}
#endif

if (bytesRead == 0)
{
return string.Empty;
}
finally

writer.Advance(bytesRead);

return Encoding.UTF8.GetString(memory[..bytesRead].Span);
}

private static async Task WriteStreamToPipeAsync(Stream stream, PipeWriter writer, CancellationToken cancellationToken)
{
while (true)
{
if (streamToReadFrom.CanSeek)
Memory<byte> memory = writer.GetMemory(ChunkSize)[..ChunkSize];

int bytesRead = await stream.ReadAsync(memory, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
streamToReadFrom.Seek(0, SeekOrigin.Begin);
break;
}
else

writer.Advance(bytesRead);

FlushResult result = await writer.FlushAsync(cancellationToken).ConfigureAwait(false);
if (result.IsCompleted)
{
var freshStream = streamManager.GetStream();
#if NETCOREAPP3_1_OR_GREATER
var remainingSpace = memory.Slice(bufferWriter.WrittenCount, memory.Length - bufferWriter.WrittenCount);
var writtenCount = await streamToReadFrom.ReadAsync(remainingSpace, cancellationToken)
.ConfigureAwait(false);

await freshStream.WriteAsync(memory.Slice(0, writtenCount + bufferWriter.WrittenCount), cancellationToken)
.ConfigureAwait(false);
break;
}
}

await writer.CompleteAsync().ConfigureAwait(false);
}
#else
var writtenCount = await streamToReadFrom.ReadAsync(buffer, bufferWriter.WrittenCount,
buffer.Length - bufferWriter.WrittenCount, cancellationToken).ConfigureAwait(false);
private static async Task<string> BufferStreamAndWriteToPipeAsync(Stream stream, PipeWriter writer, int bufferSize, CancellationToken cancellationToken)
{
var sb = new StringBuilder();

await freshStream.WriteAsync(buffer, 0, writtenCount + bufferWriter.WrittenCount, cancellationToken).ConfigureAwait(false);
#endif
freshStream.Seek(0, SeekOrigin.Begin);
int bytesRead = 0;

while (bytesRead < bufferSize)
{
var chunkSize = Math.Min(ChunkSize, bufferSize - bytesRead);

var newContent = new StreamContent(freshStream);
var memory = writer.GetMemory(chunkSize).Slice(0, chunkSize);

foreach (var header in response.Content.Headers)
{
_ = newContent.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
byte[] buffer = memory.ToArray();

response.Content = newContent;
int read = await stream.ReadAsync(buffer, 0, chunkSize, cancellationToken).ConfigureAwait(false);
if (read == 0)
{
break;
}

_bufferWriterPool.Return(bufferWriter);
bytesRead += read;

buffer.CopyTo(memory);

writer.Advance(read);

_ = sb.Append(Encoding.UTF8.GetString(buffer.AsMemory(0, read).ToArray()));
}

return sb.ToString();
}

private static async ValueTask<string> ReadFromStreamWithTimeoutAsync(HttpResponseMessage response, TimeSpan readTimeout,
int readSizeLimit, RecyclableMemoryStreamManager streamManager, CancellationToken cancellationToken)
private static async Task WriteStreamToPipeAsync(Stream stream, PipeWriter writer, CancellationToken cancellationToken)
{
using var joinedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
joinedTokenSource.CancelAfter(readTimeout);

try
while (true)
{
return await ReadFromStreamAsync(response, readSizeLimit, streamManager, joinedTokenSource.Token)
.ConfigureAwait(false);
}
Memory<byte> memory = writer.GetMemory(ChunkSize).Slice(0, ChunkSize);
byte[] buffer = memory.ToArray();

// when readTimeout occurred:
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
return Constants.ReadCancelled;
int bytesRead = await stream.ReadAsync(buffer, 0, ChunkSize, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
break;
}

FlushResult result = await writer.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false);
if (result.IsCompleted)
{
break;
}
}

await writer.CompleteAsync().ConfigureAwait(false);
}
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.IO.RecyclableMemoryStream" />
<PackageReference Include="System.IO.Pipelines" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" />
<PackageReference Include="Microsoft.Extensions.Http" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ public async Task AddHttpClientLogging_WithNamedHttpClients_WorksCorrectly()
var collector = provider.GetFakeLogCollector();
var logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory);
var state = logRecord.StructuredState;
state.Should().Contain(kvp => kvp.Value == responseString);
state.Should().Contain(kvp => kvp.Value == "Request Value");
state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3");
state.Should().ContainValue(responseString);
state.Should().ContainValue("Request Value");
state.Should().ContainValue("Request Value 2,Request Value 3");

using var httpRequestMessage2 = new HttpRequestMessage
{
Expand All @@ -187,9 +187,9 @@ public async Task AddHttpClientLogging_WithNamedHttpClients_WorksCorrectly()
responseString = await SendRequest(namedClient2, httpRequestMessage2);
logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory);
state = logRecord.StructuredState;
state.Should().Contain(kvp => kvp.Value == responseString);
state.Should().Contain(kvp => kvp.Value == "Request Value");
state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3");
state.Should().ContainValue(responseString);
state.Should().ContainValue("Request Value");
state.Should().ContainValue("Request Value 2,Request Value 3");
}

private static async Task<string> SendRequest(HttpClient httpClient, HttpRequestMessage httpRequestMessage)
Expand Down Expand Up @@ -258,9 +258,9 @@ public async Task AddHttpClientLogging_WithTypedHttpClients_WorksCorrectly()
var logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory);
var state = logRecord.StructuredState;
state.Should().NotBeNull();
state.Should().Contain(kvp => kvp.Value == responseString);
state.Should().Contain(kvp => kvp.Value == "Request Value");
state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3");
state.Should().ContainValue(responseString);
state.Should().ContainValue("Request Value");
state.Should().ContainValue("Request Value 2,Request Value 3");

using var httpRequestMessage2 = new HttpRequestMessage
{
Expand All @@ -279,9 +279,9 @@ public async Task AddHttpClientLogging_WithTypedHttpClients_WorksCorrectly()

logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory);
state = logRecord.StructuredState;
state.Should().Contain(kvp => kvp.Value == responseString);
state.Should().Contain(kvp => kvp.Value == "Request Value");
state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3");
state.Should().ContainValue(responseString);
state.Should().ContainValue("Request Value");
state.Should().ContainValue("Request Value 2,Request Value 3");
}

[Theory]
Expand Down Expand Up @@ -654,6 +654,8 @@ public async Task AddDefaultHttpClientLogging_DisablesNetScope()
[InlineData(315_883)]
public async Task HttpClientLoggingHandler_LogsBodyDataUpToSpecifiedLimit(int limit)
{
const int LengthOfContentInTextFile = 64_751;

await using var provider = new ServiceCollection()
.AddFakeLogging()
.AddFakeRedaction()
Expand Down Expand Up @@ -686,17 +688,18 @@ public async Task HttpClientLoggingHandler_LogsBodyDataUpToSpecifiedLimit(int li
httpRequestMessage.Headers.Add("ReQuEStHeAdEr2", new List<string> { "Request Value 2", "Request Value 3" });

var content = await client.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead);
var responseStream = await content.Content.ReadAsStreamAsync();
var length = (int)responseStream.Length > limit ? limit : (int)responseStream.Length;
var buffer = new byte[length];
_ = await responseStream.ReadAsync(buffer, 0, length);
var responseString = Encoding.UTF8.GetString(buffer);
var responseString = await content.Content.ReadAsStringAsync();
var length = Math.Min(limit, responseString.Length);
var loggedBodyString = responseString.Substring(0, length);

// length of the content in the Text.txt file
responseString.Length.Should().Be(LengthOfContentInTextFile);

var collector = provider.GetFakeLogCollector();
var logRecord = collector.GetSnapshot().Single(l => l.Category == LoggingCategory);
var state = logRecord.StructuredState;
state.Should().Contain(kvp => kvp.Value == responseString);
state.Should().Contain(kvp => kvp.Value == "Request Value");
state.Should().Contain(kvp => kvp.Value == "Request Value 2,Request Value 3");
state.Should().ContainValue(loggedBodyString);
state.Should().ContainValue("Request Value");
state.Should().ContainValue("Request Value 2,Request Value 3");
}
}
Loading
Loading