Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
private readonly ILogger _logger;
private readonly Subchannel _subchannel;
private readonly TimeSpan _socketPingInterval;
private readonly TimeSpan _connectionIdleTimeout;
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask> _socketConnect;
private readonly List<ActiveStream> _activeStreams;
private readonly Timer _socketConnectedTimer;
Expand All @@ -64,20 +65,23 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
internal Socket? _initialSocket;
private BalancerAddress? _initialSocketAddress;
private List<ReadOnlyMemory<byte>>? _initialSocketData;
private DateTime? _initialSocketCreatedTime;
private bool _disposed;
private BalancerAddress? _currentAddress;

public SocketConnectivitySubchannelTransport(
Subchannel subchannel,
TimeSpan socketPingInterval,
TimeSpan? connectTimeout,
TimeSpan connectionIdleTimeout,
ILoggerFactory loggerFactory,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
{
_logger = loggerFactory.CreateLogger<SocketConnectivitySubchannelTransport>();
_subchannel = subchannel;
_socketPingInterval = socketPingInterval;
ConnectTimeout = connectTimeout;
_connectionIdleTimeout = connectionIdleTimeout;
_socketConnect = socketConnect ?? OnConnect;
_activeStreams = new List<ActiveStream>();
_socketConnectedTimer = NonCapturingTimer.Create(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
Expand Down Expand Up @@ -125,6 +129,7 @@ private void DisconnectUnsynchronized()
_initialSocket = null;
_initialSocketAddress = null;
_initialSocketData = null;
_initialSocketCreatedTime = null;
_lastEndPointIndex = 0;
_currentAddress = null;
}
Expand Down Expand Up @@ -162,6 +167,7 @@ public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
_initialSocket = socket;
_initialSocketAddress = currentAddress;
_initialSocketData = null;
_initialSocketCreatedTime = DateTime.UtcNow;

// Schedule ping. Don't set a periodic interval to avoid any chance of timer causing the target method to run multiple times in paralle.
// This could happen because of execution delays (e.g. hitting a debugger breakpoint).
Expand Down Expand Up @@ -338,6 +344,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
Socket? socket = null;
BalancerAddress? socketAddress = null;
List<ReadOnlyMemory<byte>>? socketData = null;
DateTime? socketCreatedTime = null;
lock (Lock)
{
if (_initialSocket != null)
Expand All @@ -347,9 +354,11 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
socket = _initialSocket;
socketAddress = _initialSocketAddress;
socketData = _initialSocketData;
socketCreatedTime = _initialSocketCreatedTime;
_initialSocket = null;
_initialSocketAddress = null;
_initialSocketData = null;
_initialSocketCreatedTime = null;

// Double check the address matches the socket address and only use socket on match.
// Not sure if this is possible in practice, but better safe than sorry.
Expand All @@ -365,10 +374,23 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat

if (socket != null)
{
if (IsSocketInBadState(socket, address))
Debug.Assert(socketCreatedTime != null);

var closeSocket = false;

if (DateTime.UtcNow > socketCreatedTime.Value.Add(_connectionIdleTimeout))
{
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _connectionIdleTimeout);
closeSocket = true;
}
else if (IsSocketInBadState(socket, address))
{
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
closeSocket = true;
}

if (closeSocket)
{
socket.Dispose();
socket = null;
socketData = null;
Expand Down Expand Up @@ -530,6 +552,9 @@ internal static class SocketConnectivitySubchannelTransportLog
private static readonly Action<ILogger, int, BalancerAddress, Exception?> _closingUnusableSocketOnCreateStream =
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Debug, new EventId(16, "ClosingUnusableSocketOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it can't be used. The socket either can't receive data or it has received unexpected data.");

private static readonly Action<ILogger, int, BalancerAddress, TimeSpan, Exception?> _closingSocketFromIdleTimeoutOnCreateStream =
LoggerMessage.Define<int, BalancerAddress, TimeSpan>(LogLevel.Debug, new EventId(16, "ClosingSocketFromIdleTimeoutOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it exceeds the idle timeout of {IdleTimeout}.");

public static void ConnectingSocket(ILogger logger, int subchannelId, BalancerAddress address)
{
_connectingSocket(logger, subchannelId, address, null);
Expand Down Expand Up @@ -609,5 +634,10 @@ public static void ClosingUnusableSocketOnCreateStream(ILogger logger, int subch
{
_closingUnusableSocketOnCreateStream(logger, subchannelId, address, null);
}

public static void ClosingSocketFromIdleTimeoutOnCreateStream(ILogger logger, int subchannelId, BalancerAddress address, TimeSpan idleTimeout)
{
_closingSocketFromIdleTimeoutOnCreateStream(logger, subchannelId, address, idleTimeout, null);
}
}
#endif
14 changes: 10 additions & 4 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
internal Uri Address { get; }
internal HttpMessageInvoker HttpInvoker { get; }
internal TimeSpan? ConnectTimeout { get; }
internal TimeSpan? ConnectionIdleTimeout { get; }
internal HttpHandlerType HttpHandlerType { get; }
internal TimeSpan InitialReconnectBackoff { get; }
internal TimeSpan? MaxReconnectBackoff { get; }
Expand Down Expand Up @@ -125,7 +126,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr

var resolverFactory = GetResolverFactory(channelOptions);
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
(HttpHandlerType, ConnectTimeout, ConnectionIdleTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);

SubchannelTransportFactory = channelOptions.ResolveService<ISubchannelTransportFactory>(new SubChannelTransportFactory(this));

Expand Down Expand Up @@ -154,7 +155,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
throw new ArgumentException($"Address '{address.OriginalString}' doesn't have a host. Address should include a scheme, host, and optional port. For example, 'https://localhost:5001'.");
}
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
(HttpHandlerType, ConnectTimeout, ConnectionIdleTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
#endif

HttpInvoker = channelOptions.HttpClient ?? CreateInternalHttpInvoker(channelOptions.HttpHandler);
Expand Down Expand Up @@ -243,12 +244,14 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
{
HttpHandlerType type;
TimeSpan? connectTimeout;
TimeSpan? connectionIdleTimeout;

#if NET5_0_OR_GREATER
var socketsHttpHandler = HttpRequestHelpers.GetHttpHandlerType<SocketsHttpHandler>(channelOptions.HttpHandler)!;

type = HttpHandlerType.SocketsHttpHandler;
connectTimeout = socketsHttpHandler.ConnectTimeout;
connectionIdleTimeout = socketsHttpHandler.PooledConnectionIdleTimeout;

// Check if the SocketsHttpHandler is being shared by channels.
// It has already been setup by another channel (i.e. ConnectCallback is set) then
Expand All @@ -261,6 +264,7 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
{
type = HttpHandlerType.Custom;
connectTimeout = null;
connectionIdleTimeout = null;
}
}

Expand All @@ -282,8 +286,9 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
#else
type = HttpHandlerType.SocketsHttpHandler;
connectTimeout = null;
connectionIdleTimeout = null;
#endif
return new HttpHandlerContext(type, connectTimeout);
return new HttpHandlerContext(type, connectTimeout, connectionIdleTimeout);
}
if (HttpRequestHelpers.GetHttpHandlerType<HttpClientHandler>(channelOptions.HttpHandler) != null)
{
Expand Down Expand Up @@ -837,6 +842,7 @@ public ISubchannelTransport Create(Subchannel subchannel)
subchannel,
SocketConnectivitySubchannelTransport.SocketPingInterval,
_channel.ConnectTimeout,
_channel.ConnectionIdleTimeout ?? TimeSpan.FromMinutes(1),
_channel.LoggerFactory,
socketConnect: null);
}
Expand Down Expand Up @@ -895,7 +901,7 @@ public static void AddressPathUnused(ILogger logger, string address)
}
}

private readonly record struct HttpHandlerContext(HttpHandlerType HttpHandlerType, TimeSpan? ConnectTimeout = null);
private readonly record struct HttpHandlerContext(HttpHandlerType HttpHandlerType, TimeSpan? ConnectTimeout = null, TimeSpan? ConnectionIdleTimeout = null);
}

internal enum HttpHandlerType
Expand Down
15 changes: 10 additions & 5 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,14 @@ public static Task<GrpcChannel> CreateChannel(
bool? connect = null,
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null)
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
{
var resolver = new TestResolver();
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
resolver.UpdateAddresses(e);

return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout);
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
}

public static async Task<GrpcChannel> CreateChannel(
Expand All @@ -152,12 +153,13 @@ public static async Task<GrpcChannel> CreateChannel(
bool? connect = null,
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null)
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
{
var services = new ServiceCollection();
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, socketConnect));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());

var serviceConfig = new ServiceConfig();
Expand Down Expand Up @@ -214,12 +216,14 @@ internal class TestSubchannelTransportFactory : ISubchannelTransportFactory
{
private readonly TimeSpan _socketPingInterval;
private readonly TimeSpan? _connectTimeout;
private readonly TimeSpan _connectionIdleTimeout;
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? _socketConnect;

public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, TimeSpan connectionIdleTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
{
_socketPingInterval = socketPingInterval;
_connectTimeout = connectTimeout;
_connectionIdleTimeout = connectionIdleTimeout;
_socketConnect = socketConnect;
}

Expand All @@ -230,6 +234,7 @@ public ISubchannelTransport Create(Subchannel subchannel)
subchannel,
_socketPingInterval,
_connectTimeout,
_connectionIdleTimeout,
subchannel._manager.LoggerFactory,
_socketConnect);
#else
Expand Down
39 changes: 39 additions & 0 deletions test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,45 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => connectTask).DefaultTimeout();
}

[Test]
public async Task Active_UnaryCall_ConnectionIdleTimeout_SocketRecreated()
{
// Ignore errors
SetExpectedErrorsFilter(writeContext =>
{
return true;
});

Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
{
return Task.FromResult(new HelloReply { Message = request.Name });
}

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));

var connectionIdleTimeout = TimeSpan.FromSeconds(1);
var channel = await BalancerHelpers.CreateChannel(
LoggerFactory,
new PickFirstConfig(),
new[] { endpoint.Address },
connectionIdleTimeout: connectionIdleTimeout).DefaultTimeout();

Logger.LogInformation("Connecting channel.");
await channel.ConnectAsync();

await Task.Delay(connectionIdleTimeout);

var client = TestClientFactory.Create(channel, endpoint.Method);
var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();

// Assert
Assert.AreEqual("Test!", response.Message);

AssertHasLog(LogLevel.Debug, "ClosingSocketFromIdleTimeoutOnCreateStream", "Subchannel id '1' socket 127.0.0.1:50051 is being closed because it exceeds the idle timeout of 00:00:01.");
AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
}

[Test]
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
{
Expand Down
48 changes: 48 additions & 0 deletions test/Grpc.Net.Client.Tests/Balancer/StreamWrapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,54 @@ namespace Grpc.Net.Client.Tests.Balancer;
[TestFixture]
public class StreamWrapperTests
{
[Test]
public async Task ReadAsync_ExactSize_Read()
{
// Arrange
var ms = new MemoryStream(new byte[] { 4 });
var data = new List<ReadOnlyMemory<byte>>
{
new byte[] { 1, 2, 3 }
};
var streamWrapper = new StreamWrapper(ms, s => { }, data);
var buffer = new byte[3];

// Act & Assert
Assert.AreEqual(3, await streamWrapper.ReadAsync(buffer));
Assert.AreEqual(1, buffer[0]);
Assert.AreEqual(2, buffer[1]);
Assert.AreEqual(3, buffer[2]);

Assert.AreEqual(1, await streamWrapper.ReadAsync(buffer));
Assert.AreEqual(4, buffer[0]);

Assert.AreEqual(0, await streamWrapper.ReadAsync(buffer));
}

[Test]
public async Task ReadAsync_BiggerThanNeeded_Read()
{
// Arrange
var ms = new MemoryStream(new byte[] { 4 });
var data = new List<ReadOnlyMemory<byte>>
{
new byte[] { 1, 2, 3 }
};
var streamWrapper = new StreamWrapper(ms, s => { }, data);
var buffer = new byte[4];

// Act & Assert
Assert.AreEqual(3, await streamWrapper.ReadAsync(buffer));
Assert.AreEqual(1, buffer[0]);
Assert.AreEqual(2, buffer[1]);
Assert.AreEqual(3, buffer[2]);

Assert.AreEqual(1, await streamWrapper.ReadAsync(buffer));
Assert.AreEqual(4, buffer[0]);

Assert.AreEqual(0, await streamWrapper.ReadAsync(buffer));
}

[Test]
public async Task ReadAsync_MultipleInitialData_ReadInOrder()
{
Expand Down
28 changes: 28 additions & 0 deletions test/Grpc.Net.Client.Tests/GrpcChannelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,34 @@ public void Build_InsecureCredentialsWithHttps_ThrowsError()
Assert.AreEqual("Channel is configured with insecure channel credentials and can't use a HttpClient with a 'https' scheme.", ex.Message);
}

#if SUPPORT_LOAD_BALANCING
[Test]
public void Build_ConnectTimeout_ReadFromSocketsHttpHandler()
{
// Arrange & Act
var channel = GrpcChannel.ForAddress("https://localhost", CreateGrpcChannelOptions(o => o.HttpHandler = new SocketsHttpHandler
{
ConnectTimeout = TimeSpan.FromSeconds(1)
}));

// Assert
Assert.AreEqual(TimeSpan.FromSeconds(1), channel.ConnectTimeout);
}

[Test]
public void Build_ConnectionIdleTimeout_ReadFromSocketsHttpHandler()
{
// Arrange & Act
var channel = GrpcChannel.ForAddress("https://localhost", CreateGrpcChannelOptions(o => o.HttpHandler = new SocketsHttpHandler
{
PooledConnectionIdleTimeout = TimeSpan.FromSeconds(1)
}));

// Assert
Assert.AreEqual(TimeSpan.FromSeconds(1), channel.ConnectionIdleTimeout);
}
#endif

[Test]
public void Build_HttpClientAndHttpHandler_ThrowsError()
{
Expand Down