|  | 
|  | 1 | +// Copyright (c) Microsoft. All rights reserved. | 
|  | 2 | +// Licensed under the MIT license. See LICENSE file in the project root for full license information. | 
|  | 3 | + | 
|  | 4 | +using System; | 
|  | 5 | +using System.Collections; | 
|  | 6 | +using System.Collections.Concurrent; | 
|  | 7 | +using System.Collections.Generic; | 
|  | 8 | +using System.Diagnostics.CodeAnalysis; | 
|  | 9 | +using System.Linq; | 
|  | 10 | +using System.Threading; | 
|  | 11 | +using System.Threading.Tasks; | 
|  | 12 | + | 
|  | 13 | +using Microsoft.AspNetCore.SignalR.Client; | 
|  | 14 | + | 
|  | 15 | +using Xunit; | 
|  | 16 | + | 
|  | 17 | +namespace Microsoft.Azure.SignalR.E2ETests.Common; | 
|  | 18 | + | 
|  | 19 | +public class TestHubConnectionFactory | 
|  | 20 | +{ | 
|  | 21 | +    public bool EnableStatefulReconnect { get; set; } | 
|  | 22 | + | 
|  | 23 | +    public ITestHubConnection NewConnection(string host, string? hub = null, string userId = "foo") | 
|  | 24 | +    { | 
|  | 25 | +        return new TestHubConnection(host, hub) | 
|  | 26 | +        { | 
|  | 27 | +            User = userId, | 
|  | 28 | +            EnableStatefulReconnect = EnableStatefulReconnect | 
|  | 29 | +        }; | 
|  | 30 | +    } | 
|  | 31 | + | 
|  | 32 | +    public ITestHubConnectionGroup NewConnectionGroup(string host, int count, string? hub = null, string userId = "foo") | 
|  | 33 | +    { | 
|  | 34 | +        return new TestHubConnectionGroup(host, count, hub) | 
|  | 35 | +        { | 
|  | 36 | +            User = userId, | 
|  | 37 | +            EnableStatefulReconnect = EnableStatefulReconnect | 
|  | 38 | +        }; | 
|  | 39 | +    } | 
|  | 40 | + | 
|  | 41 | +    private sealed class TestHubConnection(string host, string? hub = null) : ITestHubConnection | 
|  | 42 | +    { | 
|  | 43 | +        private readonly ConcurrentDictionary<string, TaskCompletionSource<string[]>> _expectedInvokes = new(); | 
|  | 44 | + | 
|  | 45 | +        private HubConnection? _hubConnection; | 
|  | 46 | + | 
|  | 47 | +        private volatile int _messageCount; | 
|  | 48 | + | 
|  | 49 | +        public string User { get; init; } = "foo"; | 
|  | 50 | + | 
|  | 51 | +        public int MessageCount => _messageCount; | 
|  | 52 | + | 
|  | 53 | +        public bool EnableStatefulReconnect { get; init; } | 
|  | 54 | + | 
|  | 55 | +        public string ConnectionId => _hubConnection?.ConnectionId ?? throw NotReady; | 
|  | 56 | + | 
|  | 57 | +        private static Exception NotReady { get; } = new InvalidOperationException("HubConnection is not in connected state."); | 
|  | 58 | + | 
|  | 59 | +        public Task StartAsync() | 
|  | 60 | +        { | 
|  | 61 | +            BuildConnectionIfNull(); | 
|  | 62 | +            return _hubConnection.StartAsync(); | 
|  | 63 | +        } | 
|  | 64 | + | 
|  | 65 | +        public Task StopAsync() => _hubConnection?.StopAsync() ?? Task.CompletedTask; | 
|  | 66 | + | 
|  | 67 | +        public Task SendAsync(string method, params string[] messages) | 
|  | 68 | +        { | 
|  | 69 | +            if (_hubConnection == null || _hubConnection.State != HubConnectionState.Connected) | 
|  | 70 | +            { | 
|  | 71 | +                throw NotReady; | 
|  | 72 | +            } | 
|  | 73 | +            return _hubConnection.SendCoreAsync(method, messages); | 
|  | 74 | +        } | 
|  | 75 | + | 
|  | 76 | +        public void Listen(params string[] methods) | 
|  | 77 | +        { | 
|  | 78 | +            BuildConnectionIfNull(); | 
|  | 79 | +            foreach (var method in methods) | 
|  | 80 | +            { | 
|  | 81 | +                _expectedInvokes.TryAdd(method, new TaskCompletionSource<string[]>()); | 
|  | 82 | +                _hubConnection.On(method, (Action<string>)(message => Invoke(method, message))); | 
|  | 83 | +            } | 
|  | 84 | +        } | 
|  | 85 | + | 
|  | 86 | +        public void ResetInvoke(string method) | 
|  | 87 | +        { | 
|  | 88 | +            _expectedInvokes.AddOrUpdate(method, | 
|  | 89 | +                                         new TaskCompletionSource<string[]>(), | 
|  | 90 | +                                         (method, ov) => ov.Task.IsCompleted ? new TaskCompletionSource<string[]>() : ov); | 
|  | 91 | +        } | 
|  | 92 | + | 
|  | 93 | +        public async Task ExpectInvokeAsync(string method, params string[] messages) | 
|  | 94 | +        { | 
|  | 95 | +            Assert.Equal(messages, await _expectedInvokes[method].Task); | 
|  | 96 | +        } | 
|  | 97 | + | 
|  | 98 | +        public void ResetMessageCount() | 
|  | 99 | +        { | 
|  | 100 | +            _messageCount = 0; | 
|  | 101 | +        } | 
|  | 102 | + | 
|  | 103 | +        private void Invoke(string method, params string[] messages) | 
|  | 104 | +        { | 
|  | 105 | +            Interlocked.Increment(ref _messageCount); | 
|  | 106 | +            if (_expectedInvokes.TryGetValue(method, out var source)) | 
|  | 107 | +            { | 
|  | 108 | +                source.TrySetResult(messages); | 
|  | 109 | +            } | 
|  | 110 | +        } | 
|  | 111 | + | 
|  | 112 | +        [MemberNotNull(nameof(_hubConnection))] | 
|  | 113 | +        private void BuildConnectionIfNull() | 
|  | 114 | +        { | 
|  | 115 | +            if (_hubConnection == null) | 
|  | 116 | +            { | 
|  | 117 | +                hub ??= nameof(TestHub); | 
|  | 118 | +                var url = $"{host}/{hub}?user={User}"; | 
|  | 119 | +                var builder = new HubConnectionBuilder().WithUrl(url); | 
|  | 120 | + | 
|  | 121 | +                if (EnableStatefulReconnect) | 
|  | 122 | +                { | 
|  | 123 | +                    builder = builder.WithStatefulReconnect(); | 
|  | 124 | +                } | 
|  | 125 | +                _hubConnection = builder.Build(); | 
|  | 126 | +            } | 
|  | 127 | +        } | 
|  | 128 | +    } | 
|  | 129 | + | 
|  | 130 | +    private sealed class TestHubConnectionGroup(string host, int count, string? hub = null) : ITestHubConnectionGroup | 
|  | 131 | +    { | 
|  | 132 | +        private List<ITestHubConnection>? _connections; | 
|  | 133 | + | 
|  | 134 | +        public bool EnableStatefulReconnect { get; init; } | 
|  | 135 | + | 
|  | 136 | +        public IEnumerable<ITestHubConnection> Connections | 
|  | 137 | +        { | 
|  | 138 | +            get | 
|  | 139 | +            { | 
|  | 140 | +                _connections ??= (from i in Enumerable.Range(0, count) | 
|  | 141 | +                                  select new TestHubConnection(host, hub) | 
|  | 142 | +                                  { | 
|  | 143 | +                                      User = User, | 
|  | 144 | +                                      EnableStatefulReconnect = EnableStatefulReconnect, | 
|  | 145 | +                                  } as ITestHubConnection).ToList(); | 
|  | 146 | +                return _connections; | 
|  | 147 | +            } | 
|  | 148 | +        } | 
|  | 149 | + | 
|  | 150 | +        public string User { get; init; } = string.Empty; | 
|  | 151 | + | 
|  | 152 | +        public int MessageCount => Connections.Select(x => x.MessageCount).Sum(); | 
|  | 153 | + | 
|  | 154 | +        public string ConnectionId => throw new NotImplementedException("Connection group does not have ConnectionId."); | 
|  | 155 | + | 
|  | 156 | +        public void Listen(params string[] methods) | 
|  | 157 | +        { | 
|  | 158 | +            foreach (var connection in Connections) | 
|  | 159 | +            { | 
|  | 160 | +                connection.Listen(methods); | 
|  | 161 | +            } | 
|  | 162 | +        } | 
|  | 163 | + | 
|  | 164 | +        public async Task StartAsync() => await Task.WhenAll(Connections.Select(x => x.StartAsync())); | 
|  | 165 | + | 
|  | 166 | +        public IEnumerator<ITestHubConnection> GetEnumerator() => Connections.GetEnumerator(); | 
|  | 167 | + | 
|  | 168 | +        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); | 
|  | 169 | + | 
|  | 170 | +        public async Task StopAsync() => await Task.WhenAll(Connections.Select(x => x.StopAsync())); | 
|  | 171 | + | 
|  | 172 | +        public void ResetInvoke(string method) | 
|  | 173 | +        { | 
|  | 174 | +            foreach (var connection in Connections) | 
|  | 175 | +            { | 
|  | 176 | +                connection.ResetInvoke(method); | 
|  | 177 | +            } | 
|  | 178 | +        } | 
|  | 179 | + | 
|  | 180 | +        public Task ExpectInvokeAsync(string method, params string[] messages) | 
|  | 181 | +        { | 
|  | 182 | +            return Task.WhenAll(Connections.Select(x => x.ExpectInvokeAsync(method, messages))); | 
|  | 183 | +        } | 
|  | 184 | + | 
|  | 185 | +        public Task SendAsync(string method, params string[] messages) => Task.WhenAll(Connections.Select(x => x.SendAsync(method, messages))); | 
|  | 186 | + | 
|  | 187 | +        public void ResetMessageCount() | 
|  | 188 | +        { | 
|  | 189 | +            foreach (var connection in Connections) | 
|  | 190 | +            { | 
|  | 191 | +                connection.ResetMessageCount(); | 
|  | 192 | +            } | 
|  | 193 | +        } | 
|  | 194 | +    } | 
|  | 195 | +} | 
0 commit comments