Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ internal class ImdsManagedIdentitySource : AbstractManagedIdentity

private readonly Uri _imdsEndpoint;

private static string s_cachedBaseEndpoint = null;

internal ImdsManagedIdentitySource(RequestContext requestContext) :
base(requestContext, ManagedIdentitySource.Imds)
{
Expand Down Expand Up @@ -181,25 +183,25 @@ public static Uri GetValidatedEndpoint(
string queryParams = null
)
{
UriBuilder builder;

if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint))
if (s_cachedBaseEndpoint == null)
{
logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint);
builder = new UriBuilder(EnvironmentVariables.PodIdentityEndpoint)
if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint))
{
Path = subPath
};
}
else
{
logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint.");
builder = new UriBuilder(DefaultImdsBaseEndpoint)
logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint);
s_cachedBaseEndpoint = EnvironmentVariables.PodIdentityEndpoint;
}
else
{
Path = subPath
};
logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint.");
s_cachedBaseEndpoint = DefaultImdsBaseEndpoint;
}
}


UriBuilder builder = new UriBuilder(s_cachedBaseEndpoint)
{
Path = subPath
};

if (!string.IsNullOrEmpty(queryParams))
{
builder.Query = queryParams;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
break;

case ManagedIdentitySource.Imds:
case ManagedIdentitySource.ImdsV2:
Environment.SetEnvironmentVariable("AZURE_POD_IDENTITY_AUTHORITY_HOST", endpoint);
break;

Expand All @@ -59,11 +60,15 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
Environment.SetEnvironmentVariable("IDENTITY_HEADER", secret);
Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", thumbprint);
break;

case ManagedIdentitySource.MachineLearning:
Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint);
Environment.SetEnvironmentVariable("MSI_SECRET", secret);
Environment.SetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID", "fake_DEFAULT_IDENTITY_CLIENT_ID");
break;

default:
throw new NotImplementedException($"Setting environment variables for {managedIdentitySource} is not implemented.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Microsoft.Identity.Client.MtlsPop;
using Microsoft.Identity.Client.PlatformsCommon.Interfaces;
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using Microsoft.Identity.Test.Common.Core.Helpers;
using Microsoft.Identity.Test.Common.Core.Mocks;
using Microsoft.Identity.Test.Unit.Helpers;
using Microsoft.Identity.Test.Unit.PublicApiTests;
Expand Down Expand Up @@ -152,8 +153,11 @@ public async Task BearerTokenHappyPath(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.InMemory).ConfigureAwait(false);

AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId);
Expand Down Expand Up @@ -185,8 +189,11 @@ public async Task BearerTokenTokenIsPerIdentity(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

#region Identity 1
var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false);

Expand Down Expand Up @@ -246,8 +253,11 @@ public async Task BearerTokenIsReAcquiredWhenCertificatIsExpired(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false);

AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate); // cert will be expired on second request
Expand Down Expand Up @@ -288,8 +298,11 @@ public async Task mTLSPopTokenHappyPath(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);
Expand Down Expand Up @@ -328,8 +341,11 @@ public async Task mTLSPopTokenTokenIsPerIdentity(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

#region Identity 1
var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

Expand Down Expand Up @@ -408,8 +424,11 @@ public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired(
UserAssignedIdentityId userAssignedIdentityId,
string userAssignedId)
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate, mTLSPop: true);
Expand Down Expand Up @@ -450,8 +469,11 @@ public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired(
[TestMethod]
public async Task GetCsrMetadataAsyncSucceeds()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var handler = httpManager.AddMockHandler(MockHelpers.MockCsrResponse());

await CreateManagedIdentityAsync(httpManager, addProbeMock: false).ConfigureAwait(false);
Expand All @@ -461,8 +483,11 @@ public async Task GetCsrMetadataAsyncSucceeds()
[TestMethod]
public async Task GetCsrMetadataAsyncSucceedsAfterRetry()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

// First attempt fails with INTERNAL_SERVER_ERROR (500)
httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.InternalServerError));

Expand All @@ -474,8 +499,11 @@ public async Task GetCsrMetadataAsyncSucceedsAfterRetry()
[TestMethod]
public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: null));

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false);
Expand All @@ -488,8 +516,11 @@ public async Task GetCsrMetadataAsyncFailsWithMissingServerHeader()
[TestMethod]
public async Task GetCsrMetadataAsyncFailsWithInvalidFormat()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

httpManager.AddMockHandler(MockHelpers.MockCsrResponse(responseServerHeader: "I_MDS/150.870.65.1854"));

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false);
Expand All @@ -502,8 +533,11 @@ public async Task GetCsrMetadataAsyncFailsWithInvalidFormat()
[TestMethod]
public async Task GetCsrMetadataAsyncFailsAfterMaxRetries()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

const int Num500Errors = 1 + TestCsrMetadataProbeRetryPolicy.ExponentialStrategyNumRetries;
for (int i = 0; i < Num500Errors; i++)
{
Expand All @@ -520,8 +554,11 @@ public async Task GetCsrMetadataAsyncFailsAfterMaxRetries()
[TestMethod]
public async Task GetCsrMetadataAsyncFails404WhichIsNonRetriableAndRetryPolicyIsNotTriggeredAsync()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

httpManager.AddMockHandler(MockHelpers.MockCsrResponse(HttpStatusCode.NotFound));

var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false);
Expand Down Expand Up @@ -606,8 +643,11 @@ public void AttachPrivateKeyToCert_NullPrivateKey_ThrowsArgumentNullException()
[TestMethod]
public async Task MtlsPop_AttestationProviderMissing_ThrowsClientException()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

// CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire.
Expand All @@ -627,8 +667,11 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
[TestMethod]
public async Task MtlsPop_AttestationProviderReturnsNull_ThrowsClientException()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

// CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire.
Expand All @@ -651,8 +694,11 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
[TestMethod]
public async Task MtlsPop_AttestationProviderReturnsEmptyToken_ThrowsClientException()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

// CreateManagedIdentityAsync does a probe; Add one more CSR response for the actual acquire.
Expand All @@ -675,8 +721,11 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
[TestMethod]
public async Task mTLSPop_RequestedWithoutKeyGuard_ThrowsClientException()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

// Force in-memory keys (i.e., not KeyGuard)
var managedIdentityApp = await CreateManagedIdentityAsync(
httpManager,
Expand Down