diff --git a/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs b/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs index 2c9fe7e587..ab19425b9e 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/ApplicationConfiguration.cs @@ -11,6 +11,7 @@ using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Extensibility; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Instance; using Microsoft.Identity.Client.Instance.Discovery; using Microsoft.Identity.Client.Internal.Broker; @@ -124,6 +125,8 @@ public string ClientVersion public Func> AppTokenProvider; + internal IRetryPolicyFactory RetryPolicyFactory { get; set; } + #region ClientCredentials // Indicates if claims or assertions are used within the configuration @@ -207,6 +210,5 @@ public X509Certificate2 ClientCredentialCertificate public IDeviceAuthManager DeviceAuthManagerForTest { get; set; } public bool IsInstanceDiscoveryEnabled { get; internal set; } = true; #endregion - } } diff --git a/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs b/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs index 2fc2fa6f95..b60ae2dbe0 100644 --- a/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs +++ b/src/client/Microsoft.Identity.Client/AppConfig/BaseAbstractApplicationBuilder.cs @@ -14,6 +14,8 @@ using Microsoft.Identity.Client.Utils; using Microsoft.IdentityModel.Abstractions; using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Http.Retry; + #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; #else @@ -31,6 +33,12 @@ public abstract class BaseAbstractApplicationBuilder internal BaseAbstractApplicationBuilder(ApplicationConfiguration configuration) { Config = configuration; + + // Ensure the default retry policy factory is set if the test factory was not provided + if (Config.RetryPolicyFactory == null) + { + Config.RetryPolicyFactory = new RetryPolicyFactory(); + } } internal ApplicationConfiguration Config { get; } @@ -227,6 +235,17 @@ public T WithClientVersion(string clientVersion) return this as T; } + /// + /// Internal only: Allows tests to inject a custom retry policy factory. + /// + /// The retry policy factory to use. + /// The builder for chaining. + internal T WithRetryPolicyFactory(IRetryPolicyFactory factory) + { + Config.RetryPolicyFactory = factory; + return (T)this; + } + internal virtual ApplicationConfiguration BuildConfiguration() { ResolveAuthority(); diff --git a/src/client/Microsoft.Identity.Client/Http/HttpManager.cs b/src/client/Microsoft.Identity.Client/Http/HttpManager.cs index 4e54e426f5..f50d325498 100644 --- a/src/client/Microsoft.Identity.Client/Http/HttpManager.cs +++ b/src/client/Microsoft.Identity.Client/Http/HttpManager.cs @@ -13,6 +13,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Http.Retry; namespace Microsoft.Identity.Client.Http { @@ -110,10 +111,11 @@ public async Task SendRequestAsync( logger.Error("The HTTP request failed. " + exception.Message); timeoutException = exception; } - - while (!_disableInternalRetries && retryPolicy.PauseForRetry(response, timeoutException, retryCount)) + + while (!_disableInternalRetries && await retryPolicy.PauseForRetryAsync(response, timeoutException, retryCount, logger).ConfigureAwait(false)) { - logger.Warning($"Retry condition met. Retry count: {retryCount++} after waiting {retryPolicy.DelayInMilliseconds}ms."); + retryCount++; + return await SendRequestAsync( endpoint, headers, diff --git a/src/client/Microsoft.Identity.Client/Http/IHttpManager.cs b/src/client/Microsoft.Identity.Client/Http/IHttpManager.cs index 87cd955a2a..04e60f0619 100644 --- a/src/client/Microsoft.Identity.Client/Http/IHttpManager.cs +++ b/src/client/Microsoft.Identity.Client/Http/IHttpManager.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Http.Retry; namespace Microsoft.Identity.Client.Http { diff --git a/src/client/Microsoft.Identity.Client/Http/IRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/IRetryPolicy.cs deleted file mode 100644 index db3b466759..0000000000 --- a/src/client/Microsoft.Identity.Client/Http/IRetryPolicy.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Microsoft.Identity.Client.Http -{ - internal interface IRetryPolicy - { - int DelayInMilliseconds { get; } - bool PauseForRetry(HttpResponse response, Exception exception, int retryCount); - } -} diff --git a/src/client/Microsoft.Identity.Client/Http/LinearRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/LinearRetryPolicy.cs deleted file mode 100644 index fc9ce94dec..0000000000 --- a/src/client/Microsoft.Identity.Client/Http/LinearRetryPolicy.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Reflection; -using System.Threading.Tasks; - -namespace Microsoft.Identity.Client.Http -{ - internal class LinearRetryPolicy : IRetryPolicy - { - // referenced in unit tests, cannot be private - public static int numRetries { get; private set; } = 0; - public const int DefaultStsMaxRetries = 1; - // this will be overridden in the unit tests so that they run faster - public static int DefaultStsRetryDelayMs { get; set; } = 1000; - - private int _maxRetries; - private readonly Func _retryCondition; - public int DelayInMilliseconds { private set; get; } - - public LinearRetryPolicy(int delayMilliseconds, int maxRetries, Func retryCondition) - { - DelayInMilliseconds = delayMilliseconds; - _maxRetries = maxRetries; - _retryCondition = retryCondition; - } - - public bool PauseForRetry(HttpResponse response, Exception exception, int retryCount) - { - // referenced in the unit tests - numRetries = retryCount + 1; - - return retryCount < _maxRetries && _retryCondition(response, exception); - } - } -} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/DefaultRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/DefaultRetryPolicy.cs new file mode 100644 index 0000000000..63b06c0e77 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/DefaultRetryPolicy.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.Http.Retry +{ + class DefaultRetryPolicy : IRetryPolicy + { + // referenced in unit tests + public const int DefaultStsMaxRetries = 1; + public const int DefaultManagedIdentityMaxRetries = 3; + + private const int DefaultStsRetryDelayMs = 1000; + private const int DefaultManagedIdentityRetryDelayMs = 1000; + + public readonly int _defaultRetryDelayMs; + private readonly int _maxRetries; + private readonly Func _retryCondition; + private readonly LinearRetryStrategy _linearRetryStrategy = new LinearRetryStrategy(); + + public DefaultRetryPolicy(RequestType requestType) + { + switch (requestType) + { + case RequestType.ManagedIdentityDefault: + _defaultRetryDelayMs = DefaultManagedIdentityRetryDelayMs; + _maxRetries = DefaultManagedIdentityMaxRetries; + _retryCondition = HttpRetryConditions.DefaultManagedIdentity; + break; + case RequestType.STS: + _defaultRetryDelayMs = DefaultStsRetryDelayMs; + _maxRetries = DefaultStsMaxRetries; + _retryCondition = HttpRetryConditions.Sts; + break; + default: + throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type"); + } + } + + internal virtual Task DelayAsync(int milliseconds) + { + return Task.Delay(milliseconds); + } + + public async Task PauseForRetryAsync(HttpResponse response, Exception exception, int retryCount, ILoggerAdapter logger) + { + // Check if the status code is retriable and if the current retry count is less than max retries + if (_retryCondition(response, exception) && + retryCount < _maxRetries) + { + // Use HeadersAsDictionary to check for "Retry-After" header + string retryAfter = string.Empty; + if (response?.HeadersAsDictionary != null) + { + response.HeadersAsDictionary.TryGetValue("Retry-After", out retryAfter); + } + + int retryAfterDelay = _linearRetryStrategy.CalculateDelay(retryAfter, _defaultRetryDelayMs); + + logger.Warning($"Retrying request in {retryAfterDelay}ms (retry attempt: {retryCount + 1})"); + + // Pause execution for the calculated delay + await DelayAsync(retryAfterDelay).ConfigureAwait(false); + + return true; + } + + // If the status code is not retriable or max retries have been reached, do not retry + return false; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/ExponentialRetryStrategy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/ExponentialRetryStrategy.cs new file mode 100644 index 0000000000..bdc4abc008 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/ExponentialRetryStrategy.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.Identity.Client.Http.Retry +{ + internal class ExponentialRetryStrategy + { + // Minimum backoff time in milliseconds + private int _minExponentialBackoff; + // Maximum backoff time in milliseconds + private int _maxExponentialBackoff; + // Maximum backoff time in milliseconds + private int _exponentialDeltaBackoff; + + public ExponentialRetryStrategy(int minExponentialBackoff, int maxExponentialBackoff, int exponentialDeltaBackoff) + { + _minExponentialBackoff = minExponentialBackoff; + _maxExponentialBackoff = maxExponentialBackoff; + _exponentialDeltaBackoff = exponentialDeltaBackoff; + } + + /// + /// Calculates the exponential delay based on the current retry attempt. + /// + /// The current retry attempt number. + /// The calculated exponential delay in milliseconds. + /// + /// The delay is calculated using the formula: + /// - If is 0, it returns the minimum backoff time. + /// - Otherwise, it calculates the delay as the minimum of: + /// - (2^(currentRetry - 1)) * deltaBackoff + /// - maxBackoff + /// This ensures that the delay increases exponentially with each retry attempt, + /// but does not exceed the maximum backoff time. + /// + public int CalculateDelay(int currentRetry) + { + // Attempt 1 + if (currentRetry == 0) + { + return _minExponentialBackoff; + } + + // Attempt 2+ + int exponentialDelay = Math.Min( + (int)(Math.Pow(2, currentRetry - 1) * _exponentialDeltaBackoff), + _maxExponentialBackoff + ); + + return exponentialDelay; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/HttpRetryCondition.cs b/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs similarity index 61% rename from src/client/Microsoft.Identity.Client/Http/HttpRetryCondition.cs rename to src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs index cf83c25516..800a5280df 100644 --- a/src/client/Microsoft.Identity.Client/Http/HttpRetryCondition.cs +++ b/src/client/Microsoft.Identity.Client/Http/Retry/HttpRetryCondition.cs @@ -4,15 +4,15 @@ using System; using System.Threading.Tasks; -namespace Microsoft.Identity.Client.Http +namespace Microsoft.Identity.Client.Http.Retry { internal static class HttpRetryConditions { /// /// Retry policy specific to managed identity flow. - /// Avoid changing this, as it's breaking change. + /// Avoid changing this, as it's a breaking change. /// - public static bool ManagedIdentity(HttpResponse response, Exception exception) + public static bool DefaultManagedIdentity(HttpResponse response, Exception exception) { if (exception != null) { @@ -21,12 +21,32 @@ public static bool ManagedIdentity(HttpResponse response, Exception exception) return (int)response.StatusCode switch { - //Not Found + // Not Found, Request Timeout, Too Many Requests, Server Error, Service Unavailable, Gateway Timeout 404 or 408 or 429 or 500 or 503 or 504 => true, _ => false, }; } + /// + /// Retry policy specific to IMDS Managed Identity. + /// + public static bool Imds(HttpResponse response, Exception exception) + { + if (exception != null) + { + return exception is TaskCanceledException ? true : false; + } + + return (int)response.StatusCode switch + { + // Not Found, Request Timeout, Gone, Too Many Requests + 404 or 408 or 410 or 429 => true, + // Server Error range + >= 500 and <= 599 => true, + _ => false, + }; + } + /// /// Retry condition for /token and /authorize endpoints /// diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/IRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/IRetryPolicy.cs new file mode 100644 index 0000000000..ea8c80b809 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/IRetryPolicy.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.Http.Retry +{ + // Interface for implementing retry logic for HTTP requests. + // Determines if a retry should occur and handles pause logic between retries. + internal interface IRetryPolicy + { + /// + /// Determines whether a retry should be attempted for a given HTTP response or exception, + /// and performs any necessary pause or delay logic before the next retry attempt. + /// + /// The HTTP response received from the request. + /// The exception encountered during the request. + /// The current retry attempt count. + /// The logger used for diagnostic and informational messages. + /// A task that returns true if a retry should be performed; otherwise, false. + Task PauseForRetryAsync(HttpResponse response, Exception exception, int retryCount, ILoggerAdapter logger); + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/IRetryPolicyFactory.cs b/src/client/Microsoft.Identity.Client/Http/Retry/IRetryPolicyFactory.cs new file mode 100644 index 0000000000..c40d9c49b2 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/IRetryPolicyFactory.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Identity.Client.Http.Retry +{ + internal interface IRetryPolicyFactory + { + IRetryPolicy GetRetryPolicy(RequestType requestType); + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs new file mode 100644 index 0000000000..6ed44e0745 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/ImdsRetryPolicy.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; + +namespace Microsoft.Identity.Client.Http.Retry +{ + // https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/main/docs/imds_retry_based_on_errors.md + internal class ImdsRetryPolicy : IRetryPolicy + { + // referenced in unit tests + public const int ExponentialStrategyNumRetries = 3; + public const int LinearStrategyNumRetries = 7; + + private const int MinExponentialBackoffMs = 1000; + private const int MaxExponentialBackoffMs = 4000; + private const int ExponentialDeltaBackoffMs = 2000; + private const int HttpStatusGoneRetryAfterMs = 10000; + + private int _maxRetries; + + private readonly ExponentialRetryStrategy _exponentialRetryStrategy = new ExponentialRetryStrategy( + ImdsRetryPolicy.MinExponentialBackoffMs, + ImdsRetryPolicy.MaxExponentialBackoffMs, + ImdsRetryPolicy.ExponentialDeltaBackoffMs + ); + + internal virtual Task DelayAsync(int milliseconds) + { + return Task.Delay(milliseconds); + } + + public async Task PauseForRetryAsync(HttpResponse response, Exception exception, int retryCount, ILoggerAdapter logger) + { + int httpStatusCode = (int)response.StatusCode; + + if (retryCount == 0) + { + // Calculate the maxRetries based on the status code, once per request + _maxRetries = httpStatusCode == (int)HttpStatusCode.Gone + ? LinearStrategyNumRetries + : ExponentialStrategyNumRetries; + } + + // Check if the status code is retriable and if the current retry count is less than max retries + if (HttpRetryConditions.Imds(response, exception) && + retryCount < _maxRetries) + { + int retryAfterDelay = httpStatusCode == (int)HttpStatusCode.Gone + ? HttpStatusGoneRetryAfterMs + : _exponentialRetryStrategy.CalculateDelay(retryCount); + + logger.Warning($"Retrying request in {retryAfterDelay}ms (retry attempt: {retryCount + 1})"); + + // Pause execution for the calculated delay + await DelayAsync(retryAfterDelay).ConfigureAwait(false); + + return true; + } + + // If the status code is not retriable or max retries have been reached, do not retry + return false; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/LinearRetryStrategy.cs b/src/client/Microsoft.Identity.Client/Http/Retry/LinearRetryStrategy.cs new file mode 100644 index 0000000000..e6fc10c797 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/LinearRetryStrategy.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.Identity.Client.Http.Retry +{ + internal class LinearRetryStrategy + { + /// + /// Calculates the number of milliseconds to sleep based on the `Retry-After` HTTP header. + /// + /// The value of the `Retry-After` HTTP header. This can be either a number of seconds or an HTTP date string. + /// The minimum delay in milliseconds to return if the header is not present or invalid. + /// The number of milliseconds to sleep before retrying the request. + public int CalculateDelay(string retryHeader, int minimumDelay) + { + if (string.IsNullOrEmpty(retryHeader)) + { + return minimumDelay; + } + + // Try parsing the retry-after header as seconds + if (double.TryParse(retryHeader, out double seconds)) + { + int millisToSleep = (int)Math.Round(seconds * 1000); + return Math.Max(minimumDelay, millisToSleep); + } + + // If parsing as seconds fails, try parsing as an HTTP date + if (DateTime.TryParse(retryHeader, out DateTime retryDate)) + { + DateTime.TryParse(DateTime.UtcNow.ToString("R"), out DateTime nowDate); + + int millisToSleep = (int)(retryDate - nowDate).TotalMilliseconds; + return Math.Max(minimumDelay, millisToSleep); + } + + // If all parsing fails, return the minimum delay + return minimumDelay; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs b/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs new file mode 100644 index 0000000000..dd62d4c886 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/Http/Retry/RetryPolicyFactory.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.Identity.Client.Http.Retry +{ + internal class RetryPolicyFactory : IRetryPolicyFactory + { + public virtual IRetryPolicy GetRetryPolicy(RequestType requestType) + { + switch (requestType) + { + case RequestType.STS: + case RequestType.ManagedIdentityDefault: + return new DefaultRetryPolicy(requestType); + case RequestType.Imds: + return new ImdsRetryPolicy(); + default: + throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type."); + } + } + } +} diff --git a/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs b/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs index b2b72a4a45..bcc3ad76e8 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Region/RegionManager.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.TelemetryCore.Internal.Events; using Microsoft.Identity.Client.Utils; @@ -46,11 +47,6 @@ public RegionInfo(string region, RegionAutodetectionSource regionSource, string private static bool s_failedAutoDiscovery = false; private static string s_regionDiscoveryDetails; - private readonly LinearRetryPolicy _linearRetryPolicy = new LinearRetryPolicy( - LinearRetryPolicy.DefaultStsRetryDelayMs, - LinearRetryPolicy.DefaultStsMaxRetries, - HttpRetryConditions.Sts); - public RegionManager( IHttpManager httpManager, int imdsCallTimeout = 2000, @@ -81,9 +77,12 @@ public async Task GetAzureRegionAsync(RequestContext requestContext) requestContext.ApiEvent != null, "Do not call GetAzureRegionAsync outside of a request. This can happen if you perform instance discovery outside a request, for example as part of validating input params."); + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS); + // MSAL always performs region auto-discovery, even if the user configured an actual region // in order to detect inconsistencies and report via telemetry - var discoveredRegion = await DiscoverAndCacheAsync(logger, requestContext.UserCancellationToken).ConfigureAwait(false); + var discoveredRegion = await DiscoverAndCacheAsync(logger, requestContext.UserCancellationToken, retryPolicy).ConfigureAwait(false); RecordTelemetry(requestContext.ApiEvent, azureRegionConfig, discoveredRegion); @@ -158,7 +157,7 @@ private static bool IsTelemetryRecorded(ApiEvent apiEvent) apiEvent.RegionOutcome == default); } - private async Task DiscoverAndCacheAsync(ILoggerAdapter logger, CancellationToken requestCancellationToken) + private async Task DiscoverAndCacheAsync(ILoggerAdapter logger, CancellationToken requestCancellationToken, IRetryPolicy retryPolicy) { var regionInfo = GetCachedRegion(logger); if (regionInfo != null) @@ -166,12 +165,12 @@ private async Task DiscoverAndCacheAsync(ILoggerAdapter logger, Canc return regionInfo; } - var result = await DiscoverAsync(logger, requestCancellationToken).ConfigureAwait(false); + var result = await DiscoverAsync(logger, requestCancellationToken, retryPolicy).ConfigureAwait(false); return result; } - private async Task DiscoverAsync(ILoggerAdapter logger, CancellationToken requestCancellationToken) + private async Task DiscoverAsync(ILoggerAdapter logger, CancellationToken requestCancellationToken, IRetryPolicy retryPolicy) { RegionInfo result = null; @@ -213,14 +212,15 @@ private async Task DiscoverAsync(ILoggerAdapter logger, Cancellation mtlsCertificate: null, validateServerCertificate: null, cancellationToken: GetCancellationToken(requestCancellationToken), - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); // A bad request occurs when the version in the IMDS call is no longer supported. if (response.StatusCode == HttpStatusCode.BadRequest) { - string apiVersion = await GetImdsUriApiVersionAsync(logger, headers, requestCancellationToken).ConfigureAwait(false); // Get the latest version + string apiVersion = await GetImdsUriApiVersionAsync(logger, headers, requestCancellationToken, retryPolicy).ConfigureAwait(false); // Get the latest version imdsUri = BuildImdsUri(apiVersion); + response = await _httpManager.SendRequestAsync( imdsUri, headers, @@ -231,7 +231,7 @@ private async Task DiscoverAsync(ILoggerAdapter logger, Cancellation mtlsCertificate: null, validateServerCertificate: null, cancellationToken: GetCancellationToken(requestCancellationToken), - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); // Call again with updated version } @@ -320,7 +320,7 @@ private static bool ValidateRegion(string region, string source, ILoggerAdapter return true; } - private async Task GetImdsUriApiVersionAsync(ILoggerAdapter logger, Dictionary headers, CancellationToken userCancellationToken) + private async Task GetImdsUriApiVersionAsync(ILoggerAdapter logger, Dictionary headers, CancellationToken userCancellationToken, IRetryPolicy retryPolicy) { Uri imdsErrorUri = new(ImdsEndpoint); @@ -334,7 +334,7 @@ private async Task GetImdsUriApiVersionAsync(ILoggerAdapter logger, Dict mtlsCertificate: null, validateServerCertificate: null, cancellationToken: GetCancellationToken(userCancellationToken), - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); // When IMDS endpoint is called without the api version query param, bad request response comes back with latest version. diff --git a/src/client/Microsoft.Identity.Client/Instance/Validation/AdfsAuthorityValidator.cs b/src/client/Microsoft.Identity.Client/Instance/Validation/AdfsAuthorityValidator.cs index 471f79ef38..585a172a38 100644 --- a/src/client/Microsoft.Identity.Client/Instance/Validation/AdfsAuthorityValidator.cs +++ b/src/client/Microsoft.Identity.Client/Instance/Validation/AdfsAuthorityValidator.cs @@ -6,7 +6,7 @@ using System.Net; using System.Threading.Tasks; using Microsoft.Identity.Client.Core; -using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.OAuth2; @@ -29,10 +29,8 @@ public async Task ValidateAuthorityAsync( var resource = $"https://{authorityInfo.Host}"; string webFingerUrl = Constants.FormatAdfsWebFingerUrl(authorityInfo.Host, resource); - LinearRetryPolicy _linearRetryPolicy = new LinearRetryPolicy( - LinearRetryPolicy.DefaultStsRetryDelayMs, - LinearRetryPolicy.DefaultStsMaxRetries, - HttpRetryConditions.Sts); + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS); Http.HttpResponse httpResponse = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( new Uri(webFingerUrl), @@ -44,7 +42,7 @@ public async Task ValidateAuthorityAsync( mtlsCertificate: null, validateServerCertificate: null, cancellationToken: _requestContext.UserCancellationToken, - retryPolicy: _linearRetryPolicy + retryPolicy: retryPolicy ) .ConfigureAwait(false); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs index 17b3c56f3a..3748745b4d 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs @@ -14,6 +14,8 @@ using System.Text; using System.Security.Cryptography.X509Certificates; using System.Net.Security; +using Microsoft.Identity.Client.Http.Retry; + #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; @@ -57,6 +59,8 @@ public virtual async Task AuthenticateAsync( _requestContext.Logger.Info("[Managed Identity] Sending request to managed identity endpoints."); + IRetryPolicy retryPolicy = _requestContext.ServiceBundle.Config.RetryPolicyFactory.GetRetryPolicy(request.RequestType); + try { if (request.Method == HttpMethod.Get) @@ -72,7 +76,7 @@ public virtual async Task AuthenticateAsync( mtlsCertificate: null, validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, - retryPolicy: request.RetryPolicy).ConfigureAwait(false); + retryPolicy: retryPolicy).ConfigureAwait(false); } else { @@ -87,7 +91,7 @@ public virtual async Task AuthenticateAsync( mtlsCertificate: null, validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, - retryPolicy: request.RetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs index a643ce7880..8071a13944 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs @@ -9,6 +9,7 @@ using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.PlatformsCommon.Shared; @@ -123,10 +124,8 @@ protected override async Task HandleResponseAsync( _requestContext.Logger.Verbose(() => "[Managed Identity] Adding authorization header to the request."); request.Headers.Add("Authorization", authHeaderValue); - LinearRetryPolicy _linearRetryPolicy = new LinearRetryPolicy( - LinearRetryPolicy.DefaultStsRetryDelayMs, - LinearRetryPolicy.DefaultStsMaxRetries, - HttpRetryConditions.Sts); + IRetryPolicyFactory retryPolicyFactory = _requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.ManagedIdentityDefault); response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync( request.ComputeUri(), @@ -138,7 +137,7 @@ protected override async Task HandleResponseAsync( mtlsCertificate: null, validateServerCertificate: null, cancellationToken: cancellationToken, - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); return await base.HandleResponseAsync(parameters, response, cancellationToken).ConfigureAwait(false); diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index 6cfb8854e6..e4c6384103 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -9,7 +9,6 @@ using System.Threading.Tasks; using Microsoft.Identity.Client.ApiConfig.Parameters; using Microsoft.Identity.Client.Core; -using Microsoft.Identity.Client.Extensibility; using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Internal; @@ -82,6 +81,8 @@ protected override ManagedIdentityRequest CreateRequest(string resource) break; } + request.RequestType = RequestType.Imds; + return request; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs index 3541730c7b..fb08b37822 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs @@ -4,18 +4,12 @@ using System; using System.Collections.Generic; using System.Net.Http; -using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.ManagedIdentity { internal class ManagedIdentityRequest { - // referenced in unit tests, cannot be private - public const int DefaultManagedIdentityMaxRetries = 3; - // this will be overridden in the unit tests so that they run faster - public static int DefaultManagedIdentityRetryDelayMs { get; set; } = 1000; - private readonly Uri _baseEndpoint; public HttpMethod Method { get; } @@ -26,21 +20,16 @@ internal class ManagedIdentityRequest public IDictionary QueryParameters { get; } - public IRetryPolicy RetryPolicy { get; set; } + public RequestType RequestType { get; set; } - public ManagedIdentityRequest(HttpMethod method, Uri endpoint, IRetryPolicy retryPolicy = null) + public ManagedIdentityRequest(HttpMethod method, Uri endpoint, RequestType requestType = RequestType.ManagedIdentityDefault) { Method = method; _baseEndpoint = endpoint; Headers = new Dictionary(); BodyParameters = new Dictionary(); QueryParameters = new Dictionary(); - - IRetryPolicy defaultRetryPolicy = new LinearRetryPolicy( - DefaultManagedIdentityRetryDelayMs, - DefaultManagedIdentityMaxRetries, - HttpRetryConditions.ManagedIdentity); - RetryPolicy = retryPolicy ?? defaultRetryPolicy; + RequestType = requestType; } public Uri ComputeUri() diff --git a/src/client/Microsoft.Identity.Client/OAuth2/OAuth2Client.cs b/src/client/Microsoft.Identity.Client/OAuth2/OAuth2Client.cs index 462ee6c4a3..ae6384fe5f 100644 --- a/src/client/Microsoft.Identity.Client/OAuth2/OAuth2Client.cs +++ b/src/client/Microsoft.Identity.Client/OAuth2/OAuth2Client.cs @@ -15,8 +15,8 @@ using Microsoft.Identity.Client.Instance.Oidc; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.Utils; -using Microsoft.Identity.Client.Internal.Broker; using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client.Http.Retry; #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; @@ -41,10 +41,6 @@ internal class OAuth2Client private readonly IDictionary _bodyParameters = new Dictionary(); private readonly IHttpManager _httpManager; private readonly X509Certificate2 _mtlsCertificate; - private readonly LinearRetryPolicy _linearRetryPolicy = new LinearRetryPolicy( - LinearRetryPolicy.DefaultStsRetryDelayMs, - LinearRetryPolicy.DefaultStsMaxRetries, - HttpRetryConditions.Sts); public OAuth2Client(ILoggerAdapter logger, IHttpManager httpManager, X509Certificate2 mtlsCertificate) { @@ -123,6 +119,9 @@ internal async Task ExecuteRequestAsync( using (requestContext.Logger.LogBlockDuration($"[Oauth2Client] Sending {method} request ")) { + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS); + try { if (method == HttpMethod.Post) @@ -145,7 +144,7 @@ internal async Task ExecuteRequestAsync( mtlsCertificate: _mtlsCertificate, validateServerCertificate: null, cancellationToken: requestContext.UserCancellationToken, - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); } else @@ -160,7 +159,7 @@ internal async Task ExecuteRequestAsync( mtlsCertificate: null, validateServerCertificate: null, cancellationToken: requestContext.UserCancellationToken, - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); } } diff --git a/src/client/Microsoft.Identity.Client/RequestType.cs b/src/client/Microsoft.Identity.Client/RequestType.cs new file mode 100644 index 0000000000..fe92431689 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/RequestType.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.Identity.Client +{ + /// + /// Specifies the type of request being made to the identity provider. + /// + internal enum RequestType + { + /// + /// Security Token Service (STS) request, used for standard authentication flows. + /// + STS, + + /// + /// Managed Identity Default request, used when acquiring tokens for managed identities in Azure. + /// + ManagedIdentityDefault, + + /// + /// Instance Metadata Service (IMDS) request, used for obtaining tokens from the Azure VM metadata endpoint. + /// + Imds + } +} diff --git a/src/client/Microsoft.Identity.Client/WsTrust/WsTrustWebRequestManager.cs b/src/client/Microsoft.Identity.Client/WsTrust/WsTrustWebRequestManager.cs index 4b2fa4e8b8..c7302c4c77 100644 --- a/src/client/Microsoft.Identity.Client/WsTrust/WsTrustWebRequestManager.cs +++ b/src/client/Microsoft.Identity.Client/WsTrust/WsTrustWebRequestManager.cs @@ -4,16 +4,14 @@ using System; using System.Collections.Generic; using System.Globalization; -using System.IO; -using System.Linq; using System.Net.Http; using System.Text; using System.Threading.Tasks; using System.Xml.Linq; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.Internal; -using Microsoft.Identity.Client.TelemetryCore; using Microsoft.Identity.Client.Utils; namespace Microsoft.Identity.Client.WsTrust @@ -21,10 +19,6 @@ namespace Microsoft.Identity.Client.WsTrust internal class WsTrustWebRequestManager : IWsTrustWebRequestManager { private readonly IHttpManager _httpManager; - private readonly LinearRetryPolicy _linearRetryPolicy = new LinearRetryPolicy( - LinearRetryPolicy.DefaultStsRetryDelayMs, - LinearRetryPolicy.DefaultStsMaxRetries, - HttpRetryConditions.Sts); public WsTrustWebRequestManager(IHttpManager httpManager) { @@ -47,6 +41,9 @@ public async Task GetMexDocumentAsync(string federationMetadataUrl, var uri = new UriBuilder(federationMetadataUrl); + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS); + HttpResponse httpResponse = await _httpManager.SendRequestAsync( uri.Uri, msalIdParams, @@ -57,7 +54,7 @@ public async Task GetMexDocumentAsync(string federationMetadataUrl, mtlsCertificate: null, validateServerCertificate: null, cancellationToken: requestContext.UserCancellationToken, - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); if (httpResponse.StatusCode != System.Net.HttpStatusCode.OK) @@ -105,6 +102,9 @@ public async Task GetWsTrustResponseAsync( wsTrustRequest, Encoding.UTF8, "application/soap+xml"); + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS); + HttpResponse resp = await _httpManager.SendRequestAsync( wsTrustEndpoint.Uri, headers, @@ -115,7 +115,7 @@ public async Task GetWsTrustResponseAsync( mtlsCertificate: null, validateServerCertificate: null, cancellationToken: requestContext.UserCancellationToken, - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); if (resp.StatusCode != System.Net.HttpStatusCode.OK) @@ -182,6 +182,9 @@ public async Task GetUserRealmAsync( var uri = new UriBuilder(userRealmUriPrefix + userName + "?api-version=1.0").Uri; + IRetryPolicyFactory retryPolicyFactory = requestContext.ServiceBundle.Config.RetryPolicyFactory; + IRetryPolicy retryPolicy = retryPolicyFactory.GetRetryPolicy(RequestType.STS); + var httpResponse = await _httpManager.SendRequestAsync( uri, msalIdParams, @@ -192,7 +195,7 @@ public async Task GetUserRealmAsync( mtlsCertificate: null, validateServerCertificate: null, cancellationToken: requestContext.UserCancellationToken, - retryPolicy: _linearRetryPolicy) + retryPolicy: retryPolicy) .ConfigureAwait(false); if (httpResponse.StatusCode == System.Net.HttpStatusCode.OK) diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManager.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManager.cs index 232a152869..9f0bf34cae 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManager.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManager.cs @@ -16,7 +16,7 @@ using Microsoft.Identity.Client; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; -using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Microsoft.Identity.Test.Common.Core.Mocks diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 5a04bedf5b..bde5093fd3 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -367,12 +367,19 @@ public static void AddManagedIdentityMockHandler( ManagedIdentitySource managedIdentitySourceType, string userAssignedId = null, UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None, - HttpStatusCode statusCode = HttpStatusCode.OK + HttpStatusCode statusCode = HttpStatusCode.OK, + string retryAfterHeader = null // A number of seconds (e.g., "120"), or an HTTP-date in RFC1123 format (e.g., "Fri, 19 Apr 2025 15:00:00 GMT") ) { - HttpResponseMessage responseMessage = new HttpResponseMessage(statusCode); - HttpContent content = new StringContent(response); - responseMessage.Content = content; + HttpResponseMessage responseMessage = new HttpResponseMessage(statusCode) + { + Content = new StringContent(response) + }; + + if (retryAfterHeader != null) + { + responseMessage.Headers.TryAddWithoutValidation("Retry-After", retryAfterHeader); + } MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(managedIdentitySourceType, resource); diff --git a/tests/Microsoft.Identity.Test.Common/TestCommon.cs b/tests/Microsoft.Identity.Test.Common/TestCommon.cs index 95772c3da7..4411a37492 100644 --- a/tests/Microsoft.Identity.Test.Common/TestCommon.cs +++ b/tests/Microsoft.Identity.Test.Common/TestCommon.cs @@ -31,6 +31,7 @@ using Microsoft.Identity.Test.Common.Core.Mocks; using NSubstitute; using static Microsoft.Identity.Client.TelemetryCore.Internal.Events.ApiEvent; +using Microsoft.Identity.Client.Http.Retry; namespace Microsoft.Identity.Test.Common { @@ -90,7 +91,8 @@ public static IServiceBundle CreateServiceBundleWithCustomHttpManager( LegacyCacheCompatibilityEnabled = isLegacyCacheEnabled, MultiCloudSupportEnabled = isMultiCloudSupportEnabled, IsInstanceDiscoveryEnabled = isInstanceDiscoveryEnabled, - PlatformProxy = platformProxy + PlatformProxy = platformProxy, + RetryPolicyFactory = new RetryPolicyFactory() }; return new ServiceBundle(appConfig, clearCaches); } diff --git a/tests/Microsoft.Identity.Test.Common/TestConstants.cs b/tests/Microsoft.Identity.Test.Common/TestConstants.cs index 6248ac1d61..eef1747403 100644 --- a/tests/Microsoft.Identity.Test.Common/TestConstants.cs +++ b/tests/Microsoft.Identity.Test.Common/TestConstants.cs @@ -202,7 +202,15 @@ public static HashSet s_scope public const string Region = "centralus"; public const string InvalidRegion = "invalidregion"; public const int TimeoutInMs = 2000; - public const string ImdsUrl = "http://169.254.169.254/metadata/instance/compute/location"; + public const string ImdsHost = "169.254.169.254"; + public const string ImdsUrl = $"http://{ImdsHost}/metadata/instance/compute/location"; + + public const string AppServiceEndpoint = "http://127.0.0.1:41564/msi/token"; + public const string AzureArcEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; + public const string CloudShellEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; + public const string ImdsEndpoint = $"http://{ImdsHost}/metadata/identity/oauth2/token"; + public const string MachineLearningEndpoint = "http://localhost:7071/msi/token"; + public const string ServiceFabricEndpoint = "https://localhost:2377/metadata/identity/oauth2/token"; public const string UserAssertion = "fake_access_token"; public const string CodeVerifier = "someCodeVerifier"; diff --git a/tests/Microsoft.Identity.Test.Integration.netcore/Infrastructure/MsiProxyHttpManager.cs b/tests/Microsoft.Identity.Test.Integration.netcore/Infrastructure/MsiProxyHttpManager.cs index 7520306f52..e7cfd54dcf 100644 --- a/tests/Microsoft.Identity.Test.Integration.netcore/Infrastructure/MsiProxyHttpManager.cs +++ b/tests/Microsoft.Identity.Test.Integration.netcore/Infrastructure/MsiProxyHttpManager.cs @@ -3,19 +3,16 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; using System.Net.Security; using System.Security.Cryptography.X509Certificates; -using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Web; -using Microsoft.Identity.Client; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Test.LabInfrastructure; namespace Microsoft.Identity.Test.Integration.NetFx.Infrastructure diff --git a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs index b9c3428145..66cb40d66f 100644 --- a/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/CoreTests/HttpTests/HttpManagerTests.cs @@ -11,11 +11,10 @@ using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.Core; -using Microsoft.Identity.Client.Http; -using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Test.Common; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; using NSubstitute; @@ -24,10 +23,7 @@ namespace Microsoft.Identity.Test.Unit.CoreTests.HttpTests [TestClass] public class HttpManagerTests { - LinearRetryPolicy _stsLinearRetryPolicy = new LinearRetryPolicy( - LinearRetryPolicy.DefaultStsRetryDelayMs, - LinearRetryPolicy.DefaultStsMaxRetries, - HttpRetryConditions.Sts); + private readonly TestDefaultRetryPolicy _stsRetryPolicy = new TestDefaultRetryPolicy(RequestType.STS); [TestInitialize] public void TestInitialize() @@ -67,7 +63,7 @@ public async Task MtlsCertAsync() mtlsCertificate: cert, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy) + retryPolicy: _stsRetryPolicy) .ConfigureAwait(false); Assert.IsNotNull(response); @@ -107,7 +103,7 @@ await Assert.ThrowsExceptionAsync(() => mtlsCertificate: cert, validateServerCert: customCallback, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); } } @@ -144,7 +140,7 @@ public async Task TestHttpManagerWithValidationCallbackAsync() mtlsCertificate: null, validateServerCert: customCallback, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy) + retryPolicy: _stsRetryPolicy) .ConfigureAwait(false); Assert.IsNotNull(response); @@ -172,7 +168,7 @@ public async Task TestSendPostNullHeaderNullBodyAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy) + retryPolicy: _stsRetryPolicy) .ConfigureAwait(false); Assert.IsNotNull(response); @@ -214,7 +210,7 @@ public async Task TestSendPostNoFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy) + retryPolicy: _stsRetryPolicy) .ConfigureAwait(false); Assert.IsNotNull(response); @@ -246,7 +242,7 @@ public async Task TestSendGetNoFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy) + retryPolicy: _stsRetryPolicy) .ConfigureAwait(false); Assert.IsNotNull(response); @@ -282,13 +278,13 @@ await Assert.ThrowsExceptionAsync(() => mtlsCertificate: null, validateServerCert: null, cancellationToken: cts.Token, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); } } [TestMethod] - public async Task TestSendGetWithRetryFalseHttp500TypeFailureAsync() + public async Task TestSendGetWithHttp500TypeFailureWithInternalRetriesDisabledAsync() { using (var httpManager = new MockHttpManager(disableInternalRetries: true)) { @@ -305,10 +301,13 @@ public async Task TestSendGetWithRetryFalseHttp500TypeFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); - Assert.AreEqual(MsalError.ServiceNotAvailable, ex.ErrorCode); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); } } @@ -317,8 +316,12 @@ public async Task TestSendGetWithHttp500TypeFailureAsync() { using (var httpManager = new MockHttpManager()) { - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Get, HttpStatusCode.GatewayTimeout); - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Get, HttpStatusCode.InternalServerError); + // Simulate permanent errors (to trigger the maximum number of retries) + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddResiliencyMessageMockHandler(HttpMethod.Get, HttpStatusCode.GatewayTimeout); + } var ex = await Assert.ThrowsExceptionAsync(() => httpManager.SendRequestAsync( @@ -331,10 +334,12 @@ public async Task TestSendGetWithHttp500TypeFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); - Assert.AreEqual(MsalError.ServiceNotAvailable, ex.ErrorCode); + + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); } } @@ -362,11 +367,40 @@ public async Task NoResiliencyIfRetryAfterHeaderPresentAsync(bool useTimeSpanRet mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); - - Assert.AreEqual(0, httpManager.QueueSize, "HttpManager must not retry because a RetryAfter header is present"); Assert.AreEqual(MsalError.ServiceNotAvailable, exc.ErrorCode); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [TestMethod] + public async Task NoResiliencyIfHttpErrorNotRetriableAsync() + { + using (var httpManager = new MockHttpManager()) + { + httpManager.AddResiliencyMessageMockHandler(HttpMethod.Get, HttpStatusCode.BadRequest); + + var msalHttpResponse = await httpManager.SendRequestAsync( + new Uri(TestConstants.AuthorityHomeTenant + "oauth2/token"), + headers: null, + body: new StringContent("body"), + method: HttpMethod.Get, + logger: Substitute.For(), + doNotThrow: true, + mtlsCertificate: null, + validateServerCert: null, + cancellationToken: default, + retryPolicy: _stsRetryPolicy) + .ConfigureAwait(false); + Assert.AreEqual(HttpStatusCode.BadRequest, msalHttpResponse.StatusCode); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); } } @@ -375,23 +409,30 @@ public async Task TestSendGetWithHttp500TypeFailure2Async() { using (var httpManager = new MockHttpManager()) { - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.BadGateway); - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.BadGateway); + // Simulate permanent errors (to trigger the maximum number of retries) + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddResiliencyMessageMockHandler(HttpMethod.Get, HttpStatusCode.BadGateway); + } var msalHttpResponse = await httpManager.SendRequestAsync( new Uri(TestConstants.AuthorityHomeTenant + "oauth2/token"), headers: null, body: new StringContent("body"), - method: HttpMethod.Post, + method: HttpMethod.Get, logger: Substitute.For(), doNotThrow: true, mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy) + retryPolicy: _stsRetryPolicy) .ConfigureAwait(false); Assert.AreEqual(HttpStatusCode.BadGateway, msalHttpResponse.StatusCode); + + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); } } @@ -400,8 +441,12 @@ public async Task TestSendPostWithHttp500TypeFailureAsync() { using (var httpManager = new MockHttpManager()) { - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.GatewayTimeout); - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.ServiceUnavailable); + // Simulate permanent errors (to trigger the maximum number of retries) + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.ServiceUnavailable); + } var exc = await AssertException.TaskThrowsAsync(() => httpManager.SendRequestAsync( @@ -414,10 +459,12 @@ public async Task TestSendPostWithHttp500TypeFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); - Assert.AreEqual(MsalError.ServiceNotAvailable, exc.ErrorCode); + + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); } } @@ -426,8 +473,12 @@ public async Task TestSendGetWithRetryOnTimeoutFailureAsync() { using (var httpManager = new MockHttpManager()) { - httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Get); - httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Get); + // Simulate permanent errors (to trigger the maximum number of retries) + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Get); + } var exc = await AssertException.TaskThrowsAsync(() => httpManager.SendRequestAsync( @@ -440,11 +491,13 @@ public async Task TestSendGetWithRetryOnTimeoutFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); - Assert.AreEqual(MsalError.RequestTimeout, exc.ErrorCode); Assert.IsTrue(exc.InnerException is TaskCanceledException); + + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); } } @@ -453,8 +506,12 @@ public async Task TestSendPostWithRetryOnTimeoutFailureAsync() { using (var httpManager = new MockHttpManager()) { - httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Post); - httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Post); + // Simulate permanent errors (to trigger the maximum number of retries) + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultStsMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddRequestTimeoutResponseMessageMockHandler(HttpMethod.Post); + } var exc = await AssertException.TaskThrowsAsync(() => httpManager.SendRequestAsync( @@ -467,60 +524,13 @@ public async Task TestSendPostWithRetryOnTimeoutFailureAsync() mtlsCertificate: null, validateServerCert: null, cancellationToken: default, - retryPolicy: _stsLinearRetryPolicy)) + retryPolicy: _stsRetryPolicy)) .ConfigureAwait(false); Assert.AreEqual(MsalError.RequestTimeout, exc.ErrorCode); Assert.IsTrue(exc.InnerException is TaskCanceledException); - } - } - - [TestMethod] - [DataRow(true, false)] - [DataRow(false, false)] - [DataRow(true, true)] - [DataRow(false, true)] - public async Task TestRetryConfigWithHttp500TypeFailureAsync(bool disableInternalRetries, bool isManagedIdentity) - { - using (var httpManager = new MockHttpManager(disableInternalRetries: disableInternalRetries)) - { - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.ServiceUnavailable); - - if (!disableInternalRetries) - { - //Adding second response for retry - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.ServiceUnavailable); - - // Add 2 more response for the managed identity flow since 3 retries happen in this scenario - if (isManagedIdentity) - { - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.ServiceUnavailable); - httpManager.AddResiliencyMessageMockHandler(HttpMethod.Post, HttpStatusCode.ServiceUnavailable); - } - } - - LinearRetryPolicy retryPolicy = isManagedIdentity ? new LinearRetryPolicy( - ManagedIdentityRequest.DefaultManagedIdentityRetryDelayMs, - ManagedIdentityRequest.DefaultManagedIdentityMaxRetries, - HttpRetryConditions.ManagedIdentity) : _stsLinearRetryPolicy; - - var msalHttpResponse = await httpManager.SendRequestAsync( - new Uri(TestConstants.AuthorityHomeTenant + "oauth2/token"), - headers: null, - body: new StringContent("body"), - method: HttpMethod.Post, - logger: Substitute.For(), - doNotThrow: true, - mtlsCertificate: null, - validateServerCert: null, - cancellationToken: default, - retryPolicy: retryPolicy) - .ConfigureAwait(false); - Assert.IsNotNull(msalHttpResponse); - Assert.AreEqual(HttpStatusCode.ServiceUnavailable, msalHttpResponse.StatusCode); - //If a second request is sent when retry is configured to false, the test will fail since - //the MockHttpManager will not be able to serve another response. - //The MockHttpManager will also check for unused responses which will check if the retry did not occur when it should have. + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); } } } diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/ParallelRequestMockHandler.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/ParallelRequestMockHandler.cs index 740b1f7a03..22431fcf9b 100644 --- a/tests/Microsoft.Identity.Test.Unit/Helpers/ParallelRequestMockHandler.cs +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/ParallelRequestMockHandler.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Http; +using Microsoft.Identity.Client.Http.Retry; using Microsoft.Identity.Client.OAuth2; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.RequestsTests; diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs new file mode 100644 index 0000000000..75e95f7e14 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicies.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Threading.Tasks; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Http.Retry; + +namespace Microsoft.Identity.Test.Unit.Helpers +{ + internal class TestDefaultRetryPolicy : DefaultRetryPolicy + { + public TestDefaultRetryPolicy(RequestType requestType) : base(requestType) { } + + internal override Task DelayAsync(int milliseconds) + { + // No delay for tests + return Task.CompletedTask; + } + } + + internal class TestImdsRetryPolicy : ImdsRetryPolicy + { + public TestImdsRetryPolicy() : base() { } + + internal override Task DelayAsync(int milliseconds) + { + // No delay for tests + return Task.CompletedTask; + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs new file mode 100644 index 0000000000..9f5a656794 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/Helpers/TestRetryPolicyFactory.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Http.Retry; + +namespace Microsoft.Identity.Test.Unit.Helpers +{ + internal class TestRetryPolicyFactory : IRetryPolicyFactory + { + public virtual IRetryPolicy GetRetryPolicy(RequestType requestType) + { + switch (requestType) + { + case RequestType.STS: + case RequestType.ManagedIdentityDefault: + return new TestDefaultRetryPolicy(requestType); + case RequestType.Imds: + return new TestImdsRetryPolicy(); + default: + throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type."); + } + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs new file mode 100644 index 0000000000..f563a34fe9 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/DefaultRetryPolicyTests.cs @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Threading.Tasks; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Test.Common.Core.Helpers; +using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.Identity.Test.Unit.Helpers; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + /// + /// The Default Retry Policy applies to: + /// STS (Azure AD) (Tested in HttpManagerTests.cs) + /// Managed Identity Sources: App Service, Azure Arc, Cloud Shell, Machine Learning, Service Fabric + /// + [TestClass] + public class DefaultRetryPolicyTests : TestBase + { + private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + + [DataTestMethod] // see test class header: all sources that allow UAMI + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint)] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint)] + public async Task UAMIFails500OnceThenSucceeds200Async( + ManagedIdentitySource managedIdentitySource, + string endpoint) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + string userAssignedId = TestConstants.ClientId; + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.ClientId; + + ManagedIdentityId managedIdentityId = ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Initial request fails with 500 + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + "", + managedIdentitySource, + statusCode: HttpStatusCode.InternalServerError, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + // Final success + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + managedIdentitySource, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + AuthenticationResult result = + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + Assert.AreEqual(result.AccessToken, TestConstants.ATSecret); + + const int NumRequests = 2; // initial request + 1 retry + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [DataTestMethod] // see test class header: all sources that allow UAMI + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint)] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint)] + public async Task UAMIFails500PermanentlyAsync( + ManagedIdentitySource managedIdentitySource, + string endpoint) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + string userAssignedId = TestConstants.ClientId; + UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.ClientId; + + ManagedIdentityId managedIdentityId = ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Simulate permanent 500s (to trigger the maximum number of retries) + const int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultManagedIdentityMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + "", + managedIdentitySource, + statusCode: HttpStatusCode.InternalServerError, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + } + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); + } + } + + [DataTestMethod] + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint, null)] + [DataRow(ManagedIdentitySource.AzureArc, TestConstants.AzureArcEndpoint, null)] + [DataRow(ManagedIdentitySource.CloudShell, TestConstants.CloudShellEndpoint, null)] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint, null)] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint, null)] + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint, "3")] + [DataRow(ManagedIdentitySource.AzureArc, TestConstants.AzureArcEndpoint, "3")] + [DataRow(ManagedIdentitySource.CloudShell, TestConstants.CloudShellEndpoint, "3")] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint, "3")] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint, "3")] + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint, "date")] + [DataRow(ManagedIdentitySource.AzureArc, TestConstants.AzureArcEndpoint, "date")] + [DataRow(ManagedIdentitySource.CloudShell, TestConstants.CloudShellEndpoint, "date")] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint, "date")] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint, "date")] + public async Task SAMIFails500OnceWithVariousRetryAfterHeaderValuesThenSucceeds200Async( + ManagedIdentitySource managedIdentitySource, + string endpoint, + string retryAfterHeader) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Initial request fails with 500 + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + "", + managedIdentitySource, + statusCode: HttpStatusCode.InternalServerError, + retryAfterHeader: retryAfterHeader == "date" ? DateTime.UtcNow.AddSeconds(3).ToString("R") : retryAfterHeader); + + // Final success + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + managedIdentitySource); + + AuthenticationResult result = + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + Assert.AreEqual(result.AccessToken, TestConstants.ATSecret); + + const int NumRequests = 2; // initial request + 1 retry + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [DataTestMethod] // see test class header: all sources allow SAMI + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint)] + [DataRow(ManagedIdentitySource.AzureArc, TestConstants.AzureArcEndpoint)] + [DataRow(ManagedIdentitySource.CloudShell, TestConstants.CloudShellEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint)] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint)] + public async Task SAMIFails500Permanently( + ManagedIdentitySource managedIdentitySource, + string endpoint) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Simulate permanent 500s (to trigger the maximum number of retries) + int Num500Errors = 1 + TestDefaultRetryPolicy.DefaultManagedIdentityMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num500Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + "", + managedIdentitySource, + statusCode: HttpStatusCode.InternalServerError); + } + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + int requestsMade = Num500Errors - httpManager.QueueSize; + Assert.AreEqual(Num500Errors, requestsMade); + } + } + + [DataTestMethod] // see test class header: all sources allow SAMI + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint)] + [DataRow(ManagedIdentitySource.AzureArc, TestConstants.AzureArcEndpoint)] + [DataRow(ManagedIdentitySource.CloudShell, TestConstants.CloudShellEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint)] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint)] + public async Task SAMIFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync( + ManagedIdentitySource managedIdentitySource, + string endpoint) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + "", + managedIdentitySource, + statusCode: HttpStatusCode.BadRequest); + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [DataTestMethod] // see test class header: all sources allow SAMI + [DataRow(ManagedIdentitySource.AppService, TestConstants.AppServiceEndpoint)] + [DataRow(ManagedIdentitySource.AzureArc, TestConstants.AzureArcEndpoint)] + [DataRow(ManagedIdentitySource.CloudShell, TestConstants.CloudShellEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, TestConstants.MachineLearningEndpoint)] + [DataRow(ManagedIdentitySource.ServiceFabric, TestConstants.ServiceFabricEndpoint)] + public async Task SAMIFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( + ManagedIdentitySource managedIdentitySource, + string endpoint) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager(disableInternalRetries: true)) + { + SetEnvironmentVariables(managedIdentitySource, endpoint); + + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + httpManager.AddManagedIdentityMockHandler( + endpoint, + ManagedIdentityTests.Resource, + "", + managedIdentitySource, + statusCode: HttpStatusCode.InternalServerError); + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs index 6fe43611ee..13e22314af 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs @@ -9,6 +9,7 @@ using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; @@ -17,41 +18,424 @@ namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests [TestClass] public class ImdsTests : TestBase { + private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + + [DataTestMethod] + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task ImdsFails404TwiceThenSucceeds200Async( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Simulate two 404s (to trigger retries), then a successful response + const int Num404Errors = 2; + for (int i = 0; i < Num404Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.NotFound, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + } + + // Final success + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.Imds, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + AuthenticationResult result = + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + Assert.AreEqual(result.AccessToken, TestConstants.ATSecret); + + const int NumRequests = 1 + Num404Errors; // initial request + 2 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [DataTestMethod] + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task ImdsFails410FourTimesThenSucceeds200Async( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Simulate four 410s (to trigger retries), then a successful response + const int Num410Errors = 4; + for (int i = 0; i < Num410Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.Gone, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + } + + // Final success + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.Imds, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + AuthenticationResult result = + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + Assert.AreEqual(result.AccessToken, TestConstants.ATSecret); + + const int NumRequests = 1 + Num410Errors; // initial request + 4 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + [DataTestMethod] - [DataRow(HttpStatusCode.BadRequest, ImdsManagedIdentitySource.IdentityUnavailableError, 1, DisplayName = "BadRequest - Identity Unavailable")] - [DataRow(HttpStatusCode.BadGateway, ImdsManagedIdentitySource.GatewayError, 1, DisplayName = "BadGateway - Gateway Error")] - [DataRow(HttpStatusCode.GatewayTimeout, ImdsManagedIdentitySource.GatewayError, 4, DisplayName = "GatewayTimeout - Gateway Error Retries")] - public async Task ImdsErrorHandlingTestAsync(HttpStatusCode statusCode, string expectedErrorSubstring, int expectedAttempts) + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task ImdsFails410PermanentlyAsync( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) { using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { - SetEnvironmentVariables(ManagedIdentitySource.Imds, "http://169.254.169.254"); + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Simulate permanent 410s (to trigger the maximum number of retries) + const int Num410Errors = 1 + TestImdsRetryPolicy.LinearStrategyNumRetries; // initial request + maximum number of retries + for (int i = 0; i < Num410Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.Gone, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + } + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + int requestsMade = Num410Errors - httpManager.QueueSize; + Assert.AreEqual(Num410Errors, requestsMade); + } + } - var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager); + [DataTestMethod] + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task ImdsFails504PermanentlyAsync( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); - // Disabling shared cache options to avoid cross test pollution. + // Disable cache to avoid pollution miBuilder.Config.AccessorOptions = null; - var mi = miBuilder.Build(); + IManagedIdentityApplication mi = miBuilder.Build(); - // Adding multiple mock handlers to simulate retries for GatewayTimeout - for (int i = 0; i < expectedAttempts; i++) + // Simulate permanent 504s (to trigger the maximum number of retries) + const int Num504Errors = 1 + TestImdsRetryPolicy.ExponentialStrategyNumRetries; // initial request + maximum number of retries + for (int i = 0; i < Num504Errors; i++) { - httpManager.AddManagedIdentityMockHandler(ManagedIdentityTests.ImdsEndpoint, ManagedIdentityTests.Resource, - MockHelpers.GetMsiImdsErrorResponse(), ManagedIdentitySource.Imds, statusCode: statusCode); + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.GatewayTimeout, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); } - // Expecting a MsalServiceException indicating an error - MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => + MsalServiceException msalException = null; + try + { await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) - .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false); + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + int requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); + } + } + + [DataTestMethod] + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task ImdsFails400WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.BadRequest, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [DataTestMethod] + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task ImdsFails500AndRetryPolicyIsDisabledAndNotTriggeredAsync( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager(disableInternalRetries: true)) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.InternalServerError, + userAssignedId: userAssignedId, + userAssignedIdentityId: userAssignedIdentityId); + + MsalServiceException msalException = null; + try + { + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false); + } + catch (Exception ex) + { + msalException = ex as MsalServiceException; + } + Assert.IsNotNull(msalException); + + const int NumRequests = 1; // initial request + 0 retries + int requestsMade = NumRequests - httpManager.QueueSize; + Assert.AreEqual(NumRequests, requestsMade); + } + } + + [TestMethod] + + public async Task ImdsRetryPolicyLifeTimeIsPerRequestAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, TestConstants.ImdsEndpoint); + + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); + + // Disable cache to avoid pollution + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Simulate permanent errors (to trigger the maximum number of retries) + const int Num504Errors = 1 + TestImdsRetryPolicy.ExponentialStrategyNumRetries; // initial request + maximum number of retries + for (int i = 0; i < Num504Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.GatewayTimeout); + } + + MsalServiceException ex = + await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false)) + .ConfigureAwait(false); + Assert.IsNotNull(ex); + + int requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); + + for (int i = 0; i < Num504Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.GatewayTimeout); + } + + ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false)) + .ConfigureAwait(false); Assert.IsNotNull(ex); - Assert.AreEqual(ManagedIdentitySource.Imds.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]); - Assert.AreEqual(MsalError.ManagedIdentityRequestFailed, ex.ErrorCode); - Assert.IsTrue(ex.Message.Contains(expectedErrorSubstring), $"The error message is not as expected. Error message: {ex.Message}. Expected message should contain: {expectedErrorSubstring}"); + + // 3 retries (requestsMade would be 6 if retry policy was NOT per request) + requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); + + for (int i = 0; i < Num504Errors; i++) + { + httpManager.AddManagedIdentityMockHandler( + ManagedIdentityTests.ImdsEndpoint, + ManagedIdentityTests.Resource, + MockHelpers.GetMsiImdsErrorResponse(), + ManagedIdentitySource.Imds, + statusCode: HttpStatusCode.GatewayTimeout); + } + + ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync() + .ConfigureAwait(false)) + .ConfigureAwait(false); + Assert.IsNotNull(ex); + + // 3 retries (requestsMade would be 9 if retry policy was NOT per request) + requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index 809f725682..d8d385e848 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -6,20 +6,18 @@ using System.Linq; using System.Net; using System.Net.Http; -using System.Net.Security; using System.Net.Sockets; -using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; -using Microsoft.Identity.Client.Http; using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity; using Microsoft.Identity.Client.TelemetryCore.Internal.Events; using Microsoft.Identity.Test.Common; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; @@ -41,6 +39,8 @@ public class ManagedIdentityTests : TestBase internal const string ExpectedErrorCode = "ErrorCode"; internal const string ExpectedCorrelationId = "Some GUID"; + private readonly TestRetryPolicyFactory _testRetryPolicyFactory = new TestRetryPolicyFactory(); + [DataTestMethod] [DataRow("http://127.0.0.1:41564/msi/token/", ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] [DataRow(AppServiceEndpoint, ManagedIdentitySource.AppService, ManagedIdentitySource.AppService)] @@ -1244,13 +1244,12 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() [DataTestMethod] [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, HttpStatusCode.NotFound)] [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, HttpStatusCode.RequestTimeout)] - [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, 429)] + [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, 429)] // not defined in HttpStatusCode enum [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, HttpStatusCode.InternalServerError)] [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, HttpStatusCode.ServiceUnavailable)] [DataRow(ManagedIdentitySource.AppService, AppServiceEndpoint, HttpStatusCode.GatewayTimeout)] [DataRow(ManagedIdentitySource.AzureArc, AzureArcEndpoint, HttpStatusCode.GatewayTimeout)] [DataRow(ManagedIdentitySource.CloudShell, CloudShellEndpoint, HttpStatusCode.GatewayTimeout)] - [DataRow(ManagedIdentitySource.Imds, ImdsEndpoint, HttpStatusCode.GatewayTimeout)] [DataRow(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint, HttpStatusCode.GatewayTimeout)] [DataRow(ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint, HttpStatusCode.GatewayTimeout)] public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( @@ -1264,7 +1263,8 @@ public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( SetEnvironmentVariables(managedIdentitySource, endpoint); var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithHttpManager(httpManager); + .WithHttpManager(httpManager) + .WithRetryPolicyFactory(_testRetryPolicyFactory); // Disable cache to avoid pollution miBuilder.Config.AccessorOptions = null; @@ -1272,8 +1272,8 @@ public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( var mi = miBuilder.Build(); // Simulate permanent errors (to trigger the maximum number of retries) - const int NumErrors = ManagedIdentityRequest.DefaultManagedIdentityMaxRetries + 1; // initial request + maximum number of retries (3) - for (int i = 0; i < NumErrors; i++) + const int Num504Errors = 1 + TestDefaultRetryPolicy.DefaultManagedIdentityMaxRetries; // initial request + maximum number of retries + for (int i = 0; i < Num504Errors; i++) { httpManager.AddManagedIdentityMockHandler( endpoint, @@ -1283,54 +1283,58 @@ public async Task ManagedIdentityRetryPolicyLifeTimeIsPerRequestAsync( statusCode: statusCode); } - MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => - await mi.AcquireTokenForManagedIdentity(Resource) - .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false); + MsalServiceException ex = + await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(Resource) + .ExecuteAsync() + .ConfigureAwait(false)) + .ConfigureAwait(false); Assert.IsNotNull(ex); - // 4 total: request + 3 retries - Assert.AreEqual(LinearRetryPolicy.numRetries, 1 + ManagedIdentityRequest.DefaultManagedIdentityMaxRetries); - Assert.AreEqual(httpManager.QueueSize, 0); + int requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); - for (int i = 0; i < NumErrors; i++) + for (int i = 0; i < Num504Errors; i++) { httpManager.AddManagedIdentityMockHandler( endpoint, Resource, "", managedIdentitySource, - statusCode: HttpStatusCode.InternalServerError); + statusCode: statusCode); } ex = await Assert.ThrowsExceptionAsync(async () => - await mi.AcquireTokenForManagedIdentity(Resource) - .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false); + await mi.AcquireTokenForManagedIdentity(Resource) + .ExecuteAsync() + .ConfigureAwait(false)) + .ConfigureAwait(false); Assert.IsNotNull(ex); - // 4 total: request + 3 retries - // (numRetries would be x2 if retry policy was NOT per request) - Assert.AreEqual(LinearRetryPolicy.numRetries, 1 + ManagedIdentityRequest.DefaultManagedIdentityMaxRetries); - Assert.AreEqual(httpManager.QueueSize, 0); + // 3 retries (requestsMade would be 6 if retry policy was NOT per request) + requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); - for (int i = 0; i < NumErrors; i++) + for (int i = 0; i < Num504Errors; i++) { httpManager.AddManagedIdentityMockHandler( endpoint, Resource, "", managedIdentitySource, - statusCode: HttpStatusCode.InternalServerError); + statusCode: statusCode); } ex = await Assert.ThrowsExceptionAsync(async () => - await mi.AcquireTokenForManagedIdentity(Resource) - .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false); + await mi.AcquireTokenForManagedIdentity(Resource) + .ExecuteAsync() + .ConfigureAwait(false)) + .ConfigureAwait(false); Assert.IsNotNull(ex); - // 4 total: request + 3 retries - // (numRetries would be x3 if retry policy was NOT per request) - Assert.AreEqual(LinearRetryPolicy.numRetries, 1 + ManagedIdentityRequest.DefaultManagedIdentityMaxRetries); - Assert.AreEqual(httpManager.QueueSize, 0); + // 3 retries (requestsMade would be 9 if retry policy was NOT per request) + requestsMade = Num504Errors - httpManager.QueueSize; + Assert.AreEqual(Num504Errors, requestsMade); } }