Skip to content

Commit ef84205

Browse files
committed
CSHARP-5777: Avoid ThreadPool-dependent IO methods in sync API
1 parent d2ff8ab commit ef84205

File tree

4 files changed

+112
-99
lines changed

4 files changed

+112
-99
lines changed

src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs

Lines changed: 99 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -38,41 +38,12 @@ public static void EfficientCopyTo(this Stream input, Stream output)
3838

3939
public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
4040
{
41-
try
42-
{
43-
using var manualResetEvent = new ManualResetEventSlim();
44-
var readOperation = stream.BeginRead(
45-
buffer,
46-
offset,
47-
count,
48-
state => ((ManualResetEventSlim)state.AsyncState).Set(),
49-
manualResetEvent);
50-
51-
if (readOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
52-
{
53-
return stream.EndRead(readOperation);
54-
}
55-
}
56-
catch (OperationCanceledException)
57-
{
58-
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
59-
}
60-
catch (ObjectDisposedException)
61-
{
62-
throw new IOException();
63-
}
64-
65-
try
66-
{
67-
stream.Dispose();
68-
}
69-
catch
70-
{
71-
// Ignore any exceptions
72-
}
73-
74-
cancellationToken.ThrowIfCancellationRequested();
75-
throw new TimeoutException();
41+
return UseStreamWithTimeout(
42+
stream,
43+
(str, state) => str.Read(state.buffer, state.offset, state.count),
44+
(buffer, offset, count),
45+
timeout,
46+
cancellationToken);
7647
}
7748

7849
public static async Task<int> ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
@@ -219,43 +190,16 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] destination,
219190

220191
public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
221192
{
222-
try
223-
{
224-
using var manualResetEvent = new ManualResetEventSlim();
225-
var writeOperation = stream.BeginWrite(
226-
buffer,
227-
offset,
228-
count,
229-
state => ((ManualResetEventSlim)state.AsyncState).Set(),
230-
manualResetEvent);
231-
232-
if (writeOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken))
193+
UseStreamWithTimeout(
194+
stream,
195+
(str, state) =>
233196
{
234-
stream.EndWrite(writeOperation);
235-
return;
236-
}
237-
}
238-
catch (OperationCanceledException)
239-
{
240-
// Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed.
241-
}
242-
catch (ObjectDisposedException)
243-
{
244-
// It's possible to get ObjectDisposedException when the connection pool was closed with interruptInUseConnections set to true.
245-
throw new IOException();
246-
}
247-
248-
try
249-
{
250-
stream.Dispose();
251-
}
252-
catch
253-
{
254-
// Ignore any exceptions
255-
}
256-
257-
cancellationToken.ThrowIfCancellationRequested();
258-
throw new TimeoutException();
197+
str.Write(state.buffer, state.offset, state.count);
198+
return true;
199+
},
200+
(buffer, offset, count),
201+
timeout,
202+
cancellationToken);
259203
}
260204

261205
public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
@@ -325,5 +269,89 @@ public static async Task WriteBytesAsync(this Stream stream, OperationContext op
325269
count -= bytesToWrite;
326270
}
327271
}
272+
273+
private static TResult UseStreamWithTimeout<TResult, TState>(Stream stream, Func<Stream, TState, TResult> method, TState state, TimeSpan timeout, CancellationToken cancellationToken)
274+
{
275+
StreamDisposeCallbackState callbackState = null;
276+
Timer timer = null;
277+
CancellationTokenRegistration cancellationSubscription = default;
278+
if (timeout != Timeout.InfiniteTimeSpan)
279+
{
280+
callbackState = new StreamDisposeCallbackState(stream);
281+
timer = new Timer(DisposeStreamCallback, callbackState, timeout, Timeout.InfiniteTimeSpan);
282+
}
283+
284+
if (cancellationToken.CanBeCanceled)
285+
{
286+
callbackState ??= new StreamDisposeCallbackState(stream);
287+
cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState);
288+
}
289+
290+
try
291+
{
292+
var result = method(stream, state);
293+
if (callbackState?.TryChangeState(OperationState.Done) == false)
294+
{
295+
// if cannot change the state - then the stream was/will be disposed, throw here
296+
throw new IOException();
297+
}
298+
299+
return result;
300+
}
301+
catch (IOException)
302+
{
303+
if (callbackState?.OperationState == OperationState.Cancelled)
304+
{
305+
cancellationToken.ThrowIfCancellationRequested();
306+
throw new TimeoutException();
307+
}
308+
309+
throw;
310+
}
311+
finally
312+
{
313+
timer?.Dispose();
314+
cancellationSubscription.Dispose();
315+
}
316+
317+
static void DisposeStreamCallback(object state)
318+
{
319+
var disposeCallbackState = (StreamDisposeCallbackState)state;
320+
if (!disposeCallbackState.TryChangeState(OperationState.Cancelled))
321+
{
322+
// if cannot change the state - then I/O was already succeeded
323+
return;
324+
}
325+
326+
try
327+
{
328+
disposeCallbackState.Stream.Dispose();
329+
}
330+
catch (Exception)
331+
{
332+
// callbacks should not fail, suppress any exceptions here
333+
}
334+
}
335+
}
336+
337+
private record StreamDisposeCallbackState(Stream Stream)
338+
{
339+
private int _operationState = 0;
340+
341+
public OperationState OperationState
342+
{
343+
get => (OperationState)_operationState;
344+
}
345+
346+
public bool TryChangeState(OperationState newState) =>
347+
Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress;
348+
}
349+
350+
private enum OperationState
351+
{
352+
InProgress = 0,
353+
Done,
354+
Cancelled,
355+
}
328356
}
329357
}

tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -811,19 +811,8 @@ public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the
811811

812812
private void SetupStreamRead(Mock<Stream> streamMock, TaskCompletionSource<int> tcs)
813813
{
814-
streamMock.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
815-
.Returns((byte[] _, int __, int ___, AsyncCallback callback, object state) =>
816-
{
817-
var innerTcs = new TaskCompletionSource<int>(state);
818-
tcs.Task.ContinueWith(t =>
819-
{
820-
innerTcs.TrySetException(t.Exception.InnerException);
821-
callback(innerTcs.Task);
822-
});
823-
return innerTcs.Task;
824-
});
825-
streamMock.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
826-
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
814+
streamMock.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
815+
.Returns((byte[] _, int __, int ___) => tcs.Task.GetAwaiter().GetResult());
827816
streamMock.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
828817
.Returns(tcs.Task);
829818
streamMock.Setup(s => s.Close()).Callback(() => tcs.TrySetException(new ObjectDisposedException("stream")));

tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,18 @@ public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_part
9090
var bytes = new byte[] { 1, 2, 3 };
9191
var n = 0;
9292
var position = 0;
93-
Task<int> ReadPartial (byte[] buffer, int offset, int count)
93+
int ReadPartial (byte[] buffer, int offset, int count)
9494
{
9595
var length = partition[n++];
9696
Buffer.BlockCopy(bytes, position, buffer, offset, length);
9797
position += length;
98-
return Task.FromResult(length);
98+
return length;
9999
}
100100

101101
mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
102-
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
103-
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
104-
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
105-
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
106-
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
102+
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
103+
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
104+
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));
107105
var destination = new byte[3];
108106

109107
if (async)
@@ -267,20 +265,18 @@ public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_par
267265
var destination = new ByteArrayBuffer(new byte[3], 3);
268266
var n = 0;
269267
var position = 0;
270-
Task<int> ReadPartial (byte[] buffer, int offset, int count)
268+
int ReadPartial (byte[] buffer, int offset, int count)
271269
{
272270
var length = partition[n++];
273271
Buffer.BlockCopy(bytes, position, buffer, offset, length);
274272
position += length;
275-
return Task.FromResult(length);
273+
return length;
276274
}
277275

278276
mockStream.Setup(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
279-
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count));
280-
mockStream.Setup(s => s.BeginRead(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
281-
.Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count));
282-
mockStream.Setup(s => s.EndRead(It.IsAny<IAsyncResult>()))
283-
.Returns<IAsyncResult>(x => ((Task<int>)x).GetAwaiter().GetResult());
277+
.Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count)));
278+
mockStream.Setup(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
279+
.Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count));
284280

285281
if (async)
286282
{

tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open()
9696

9797
var mockStream = new Mock<Stream>();
9898
mockStream
99-
.Setup(s => s.BeginWrite(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), It.IsAny<AsyncCallback>(), It.IsAny<object>()))
99+
.Setup(s => s.Write(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
100100
.Callback(() => EnqueueEvent(HelloReceivedEvent))
101101
.Throws(new Exception("Stream is closed."));
102102

0 commit comments

Comments
 (0)