From 7fe26437c77680cf073e573fbb04d6deaf9b3355 Mon Sep 17 00:00:00 2001 From: Robbie Ginsburg Date: Thu, 2 Oct 2025 16:13:48 -0400 Subject: [PATCH] Added caching for the IMDS endpoint env variable + improved unit tests --- .../ImdsManagedIdentitySource.cs | 32 ++++++------ .../Core/Helpers/ManagedIdentityTestUtil.cs | 5 ++ .../ManagedIdentityTests/ImdsV2Tests.cs | 49 +++++++++++++++++++ 3 files changed, 71 insertions(+), 15 deletions(-) diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs index ecc7efab1f..b26fef740f 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs @@ -34,6 +34,8 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity private readonly Uri _imdsEndpoint; + private static string s_cachedBaseEndpoint = null; + internal ImdsManagedIdentitySource(RequestContext requestContext) : base(requestContext, ManagedIdentitySource.Imds) { @@ -181,25 +183,25 @@ public static Uri GetValidatedEndpoint( string queryParams = null ) { - UriBuilder builder; - - if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint)) + if (s_cachedBaseEndpoint == null) { - logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint); - builder = new UriBuilder(EnvironmentVariables.PodIdentityEndpoint) + if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint)) { - Path = subPath - }; - } - else - { - logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint."); - builder = new UriBuilder(DefaultImdsBaseEndpoint) + logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint); + s_cachedBaseEndpoint = EnvironmentVariables.PodIdentityEndpoint; + } + else { - Path = subPath - }; + logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint."); + s_cachedBaseEndpoint = DefaultImdsBaseEndpoint; + } } - + + UriBuilder builder = new UriBuilder(s_cachedBaseEndpoint) + { + Path = subPath + }; + if (!string.IsNullOrEmpty(queryParams)) { builder.Query = queryParams; diff --git a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs index fe8651f0e7..eb6fd842ef 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs @@ -42,6 +42,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity break; case ManagedIdentitySource.Imds: + case ManagedIdentitySource.ImdsV2: Environment.SetEnvironmentVariable("AZURE_POD_IDENTITY_AUTHORITY_HOST", endpoint); break; @@ -59,11 +60,15 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity Environment.SetEnvironmentVariable("IDENTITY_HEADER", secret); Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", thumbprint); break; + case ManagedIdentitySource.MachineLearning: Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint); Environment.SetEnvironmentVariable("MSI_SECRET", secret); Environment.SetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID", "fake_DEFAULT_IDENTITY_CLIENT_ID"); break; + + default: + throw new NotImplementedException($"Setting environment variables for {managedIdentitySource} is not implemented."); } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs index 8c281a1d39..01c07540d1 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs @@ -16,6 +16,7 @@ using Microsoft.Identity.Client.MtlsPop; using Microsoft.Identity.Client.PlatformsCommon.Interfaces; using Microsoft.Identity.Client.PlatformsCommon.Shared; +using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.Identity.Test.Unit.Helpers; using Microsoft.Identity.Test.Unit.PublicApiTests; @@ -152,8 +153,11 @@ public async Task BearerTokenHappyPath( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.InMemory).ConfigureAwait(false); AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId); @@ -185,8 +189,11 @@ public async Task BearerTokenTokenIsPerIdentity( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + #region Identity 1 var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); @@ -246,8 +253,11 @@ public async Task BearerTokenIsReAcquiredWhenCertificatIsExpired( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false); AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate); // cert will be expired on second request @@ -288,8 +298,11 @@ public async Task mTLSPopTokenHappyPath( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true); @@ -328,8 +341,11 @@ public async Task mTLSPopTokenTokenIsPerIdentity( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + #region Identity 1 var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); @@ -408,8 +424,11 @@ public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired( UserAssignedIdentityId userAssignedIdentityId, string userAssignedId) { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate, mTLSPop: true); @@ -450,8 +469,11 @@ public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired( [TestMethod] public async Task GetCsrMetadataAsyncSucceeds() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse()); await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false); @@ -461,8 +483,11 @@ public async Task GetCsrMetadataAsyncSucceeds() [TestMethod] public async Task GetCsrMetadataAsyncSucceedsAfterRetry() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + // First attempt fails with INTERNAL_SERVER_ERROR (500) httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError)); @@ -474,8 +499,11 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry() [TestMethod] public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null)); var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); @@ -488,8 +516,11 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader() [TestMethod] public async Task GetCsrMetadataAsyncFailsWithInvalidFormat() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "I_MDS/150.870.65.1854")); var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); @@ -502,8 +533,11 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidFormat() [TestMethod] public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + const int Num500Errors = 1 + TestCsrMetadataProbeRetryPolicy.ExponentialStrategyNumRetries; for (int i = 0; i < Num500Errors; i++) { @@ -520,8 +554,11 @@ public async Task GetCsrMetadataAsyncFailsAfterMaxRetries() [TestMethod] public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.NotFound)); var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); @@ -606,8 +643,11 @@ public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException() [TestMethod] public async Task MtlsPop_AttestationProviderMissing_ThrowsClientException() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. @@ -627,8 +667,11 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) [TestMethod] public async Task MtlsPop_AttestationProviderReturnsNull_ThrowsClientException() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. @@ -651,8 +694,11 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) [TestMethod] public async Task MtlsPop_AttestationProviderReturnsEmptyToken_ThrowsClientException() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false); // CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire. @@ -675,8 +721,11 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) [TestMethod] public async Task mTLSPop_RequestedWithoutKeyGuard_ThrowsClientException() { + using (new EnvVariableContext()) using (var httpManager = new MockHttpManager()) { + SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint); + // Force in-memory keys (i.e., not KeyGuard) var managedIdentityApp = await CreateManagedIdentityAsync( httpManager,