Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf);
}

var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, loggerFactory);
var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, null, loggerFactory);

if (hubs?.Count > 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ internal interface ICallerClientResultsManager : IClientResultsManager
bool TryCompleteResult(string connectionId, ErrorCompletionMessage message);

void RemoveInvocation(string invocationId);

void SetAckNumber(string invocationId, int ackNumber);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
Expand All @@ -9,6 +9,7 @@
using System.Threading;
using System.Threading.Tasks;

using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand All @@ -21,11 +22,13 @@ namespace Microsoft.Azure.SignalR;
internal class MultiEndpointMessageWriter : IServiceMessageWriter, IPresenceManager
{
private readonly ILogger _logger;
private readonly IClientInvocationManager _clientInvocationManager;

internal HubServiceEndpoint[] TargetEndpoints { get; }

public MultiEndpointMessageWriter(IReadOnlyCollection<ServiceEndpoint> targetEndpoints, ILoggerFactory loggerFactory)
public MultiEndpointMessageWriter(IReadOnlyCollection<ServiceEndpoint> targetEndpoints, IClientInvocationManager invocationManager, ILoggerFactory loggerFactory)
{
_clientInvocationManager = invocationManager;
_logger = loggerFactory.CreateLogger<MultiEndpointMessageWriter>();
var normalized = new List<HubServiceEndpoint>();
if (targetEndpoints != null)
Expand All @@ -52,6 +55,19 @@ public MultiEndpointMessageWriter(IReadOnlyCollection<ServiceEndpoint> targetEnd

public Task WriteAsync(ServiceMessage serviceMessage)
{
if (serviceMessage is ClientInvocationMessage invocationMessage)
{
// Accroding to target endpoints in method `WriteMultiEndpointMessageAsync`
_clientInvocationManager.Caller.SetAckNumber(invocationMessage.InvocationId, TargetEndpoints.Length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when TargetEndpoints.Length is 0, the result is OK?

Copy link
Contributor Author

@xingsy97 xingsy97 Feb 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class AckHandler could handle such condition correctly. Refer to its method SetExpectedCount

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we still need to SetAck if Length == 0?

if (TargetEndpoints.Length == 0)
{
_clientInvocationManager.Caller.TryCompleteResult(
invocationMessage.ConnectionId,
CompletionMessage.WithError(invocationMessage.InvocationId, "No available endpoint to send invocation message.")
);
}
}

return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta

private readonly object _lock = new object();

private readonly IClientInvocationManager _clientInvocationManager;

private (bool needRouter, IReadOnlyList<HubServiceEndpoint> endpoints) _routerEndpoints;

private int _started;
Expand All @@ -56,13 +58,15 @@ public MultiEndpointServiceConnectionContainer(
int? maxCount,
IServiceEndpointManager endpointManager,
IMessageRouter router,
IClientInvocationManager clientInvocationManager,
ILoggerFactory loggerFactory,
TimeSpan? scaleTimeout = null
) : this(
hub,
endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, maxCount, loggerFactory),
endpointManager,
router,
clientInvocationManager,
loggerFactory,
scaleTimeout)
{
Expand All @@ -73,6 +77,7 @@ internal MultiEndpointServiceConnectionContainer(
Func<HubServiceEndpoint, IServiceConnectionContainer> generator,
IServiceEndpointManager endpointManager,
IMessageRouter router,
IClientInvocationManager clientInvocationManager,
ILoggerFactory loggerFactory,
TimeSpan? scaleTimeout = null)
{
Expand All @@ -90,6 +95,7 @@ internal MultiEndpointServiceConnectionContainer(
_loggerFactory = loggerFactory;
_logger = loggerFactory?.CreateLogger<MultiEndpointServiceConnectionContainer>() ?? throw new ArgumentNullException(nameof(loggerFactory));
_serviceEndpointManager = endpointManager;
_clientInvocationManager = clientInvocationManager;
_scaleTimeout = scaleTimeout ?? Constants.Periods.DefaultScaleTimeout;

// Reserve generator for potential scale use.
Expand Down Expand Up @@ -158,7 +164,7 @@ public Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, Cancel
public IAsyncEnumerable<GroupMember> ListConnectionsInGroupAsync(string groupName, int? top = null, ulong? tracingId = null, CancellationToken token = default)
{
var targetEndpoints = _routerEndpoints.needRouter ? _router.GetEndpointsForGroup(groupName, _routerEndpoints.endpoints) : _routerEndpoints.endpoints;
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _loggerFactory);
var messageWriter = new MultiEndpointMessageWriter(targetEndpoints?.ToList(), _clientInvocationManager, _loggerFactory);
return messageWriter.ListConnectionsInGroupAsync(groupName, top, tracingId, token);
}

Expand Down Expand Up @@ -271,7 +277,7 @@ private static IServiceConnectionContainer CreateContainer(IServiceConnectionFac
private MultiEndpointMessageWriter CreateMessageWriter(ServiceMessage serviceMessage)
{
var targetEndpoints = GetRoutedEndpoints(serviceMessage)?.ToList();
return new MultiEndpointMessageWriter(targetEndpoints, _loggerFactory);
return new MultiEndpointMessageWriter(targetEndpoints, _clientInvocationManager, _loggerFactory);
}

private void OnAdd(HubServiceEndpoint endpoint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa

private readonly IServiceConnectionFactory _serviceConnectionFactory;

private readonly IClientInvocationManager _clientInvocationManager;

public ServiceConnectionContainerFactory(IServiceConnectionFactory serviceConnectionFactory,
IServiceEndpointManager serviceEndpointManager,
IMessageRouter router,
IServiceEndpointOptions options,
IClientInvocationManager clientInvocationManager,
ILoggerFactory loggerFactory)
{
_serviceConnectionFactory = serviceConnectionFactory;
_serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_options = options;
_clientInvocationManager = clientInvocationManager;
_loggerFactory = loggerFactory;
}

Expand All @@ -39,6 +43,7 @@ public IServiceConnectionContainer Create(string hub, TimeSpan? serviceScaleTime
_options.MaxHubServerConnectionCount,
_serviceEndpointManager,
_router,
_clientInvocationManager,
_loggerFactory,
serviceScaleTimeout);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ internal class MultiEndpointConnectionContainerFactory
private readonly IServiceEndpointManager _endpointManager;
private readonly int _connectionCount;
private readonly IEndpointRouter _router;
private readonly IClientInvocationManager _clientInvocationManager;

public MultiEndpointConnectionContainerFactory(IServiceConnectionFactory connectionFactory, ILoggerFactory loggerFactory, IServiceEndpointManager serviceEndpointManager, IOptions<ServiceManagerOptions> options, IEndpointRouter router = null)
public MultiEndpointConnectionContainerFactory(IServiceConnectionFactory connectionFactory, ILoggerFactory loggerFactory, IServiceEndpointManager serviceEndpointManager, IOptions<ServiceManagerOptions> options, IEndpointRouter router = null, IClientInvocationManager clientInvocationManager = null)
{
_connectionFactory = connectionFactory;
_loggerFactory = loggerFactory;
_endpointManager = serviceEndpointManager;
_connectionCount = options.Value.ConnectionCount;
_router = router;
_clientInvocationManager = clientInvocationManager;
}

public MultiEndpointServiceConnectionContainer Create(string hubName)
Expand All @@ -31,8 +33,9 @@ public MultiEndpointServiceConnectionContainer Create(string hubName)
endpoint => new WeakServiceConnectionContainer(_connectionFactory, _connectionCount, endpoint, _loggerFactory.CreateLogger<WeakServiceConnectionContainer>()),
_endpointManager,
_router,
_clientInvocationManager,
_loggerFactory);
return container;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public override ServiceHubContext WithEndpoints(IEnumerable<ServiceEndpoint> end
private sealed class MessageWriterServiceContainerWrapper : MultiEndpointMessageWriter, IServiceConnectionContainer
{
public MessageWriterServiceContainerWrapper(IReadOnlyCollection<ServiceEndpoint> targetEndpoints, ILoggerFactory loggerFactory)
: base(targetEndpoints, loggerFactory) { }
: base(targetEndpoints, null, loggerFactory) { }

public Task StartAsync() => Task.CompletedTask;

Expand Down Expand Up @@ -125,4 +125,4 @@ public void Dispose()
#endregion Not supported method or properties
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#if NET7_0_OR_GREATER
using System;
Expand Down Expand Up @@ -43,13 +43,8 @@ public Task<T> AddInvocation<T>(string hub, string connectionId, string invocati
cancellationToken,
() => TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Canceled")));

var serviceEndpoints = _serviceEndpointManager.GetEndpoints(hub);
var ackNumber = _endpointRouter.GetEndpointsForConnection(connectionId, serviceEndpoints).Count();

var multiAck = _ackHandler.CreateMultiAck(out var ackId);

_ackHandler.SetExpectedCount(ackId, ackNumber);

// When the caller server is also the client router, Azure SignalR service won't send a ServiceMappingMessage to server.
// To handle this condition, CallerClientResultsManager itself should record this mapping information rather than waiting for a ServiceMappingMessage sent by service. Only in this condition, this method is called with instanceId != null.
var result = _pendingInvocations.TryAdd(invocationId,
Expand Down Expand Up @@ -206,6 +201,14 @@ public void RemoveInvocation(string invocationId)
_pendingInvocations.TryRemove(invocationId, out _);
}

public void SetAckNumber(string invocationId, int ackNumber)
{
if (_pendingInvocations.TryGetValue(invocationId, out var item))
{
_ackHandler.SetExpectedCount(item.AckId, ackNumber);
}
}

// Unused, here to honor the IInvocationBinder interface but should never be called
public IReadOnlyList<Type> GetParameterTypes(string methodName) => throw new NotImplementedException();

Expand All @@ -218,4 +221,4 @@ private record PendingInvocation(Type Type, string ConnectionId, object Tcs, int
}
}
}
#endif
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ private IServiceConnectionContainer GetServiceConnectionContainer(ConnectionDele
_serviceEndpointManager,
_router,
_options,
_clientInvocationManager,
_loggerFactory
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri

var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId);
var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId)));
await WriteAsync(message);
// The ack number of invocation will be set inside `WriteAsync`. So adding invocation should be first.
var task = _clientInvocationManager.Caller.AddInvocation<T>(_hub, connectionId, invocationId, cancellationToken);
await WriteAsync(message);
if (ServiceConnectionContainer is not MultiEndpointServiceConnectionContainer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds tricky, a ServiceLifttimeManager should not have knowledge of what the container type is, it looks like a strong assumption and is fragile. What if we later change to use other container?

{
// `WriteAsync` in test class `TestServiceConnectionHandler` does not set ack number. Set the number manually.
_clientInvocationManager.Caller.SetAckNumber(invocationId, 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't want such logic spread into multiple layers, choose one place for such logic

}

// Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub,
Func<HubServiceEndpoint, IServiceConnectionContainer> generator,
IServiceEndpointManager endpoint,
IEndpointRouter router,
ILoggerFactory loggerFactory) : base(hub, generator, endpoint, router, loggerFactory)
ILoggerFactory loggerFactory) : base(hub, generator, endpoint, router, null, loggerFactory)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public async Task ListConnectionsInGroup(int? top, int resultCount, params int?[
endpoint.ConnectionContainer = containerMock.Object;
targetEndpoints.Add(endpoint);
}
var multiEndpointWriter = new MultiEndpointMessageWriter(targetEndpoints, Mock.Of<ILoggerFactory>());
var multiEndpointWriter = new MultiEndpointMessageWriter(targetEndpoints, null, Mock.Of<ILoggerFactory>());
var resultMembers = new List<GroupMember>();
await foreach (var member in multiEndpointWriter.ListConnectionsInGroupAsync("group", top))
{
Expand Down
Loading
Loading