diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs
index aed4821dae..65c84a96f1 100644
--- a/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs
+++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs
@@ -14,5 +14,6 @@ internal class EnvironmentVariables
public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT");
public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET");
public static string IdentityServerThumbprint => Environment.GetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT");
+ public static string MachineLearningDefaultClientId => Environment.GetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID");
}
}
diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs
index f69f34de7a..02b11bc930 100644
--- a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs
+++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs
@@ -10,12 +10,16 @@ namespace Microsoft.Identity.Client.ManagedIdentity
{
internal class MachineLearningManagedIdentitySource : AbstractManagedIdentity
{
+ private const string MachineLearning = "Machine Learning";
+
private const string MachineLearningMsiApiVersion = "2017-09-01";
private const string SecretHeaderName = "secret";
private readonly Uri _endpoint;
private readonly string _secret;
+ public const string UnsupportedIdTypeError = "Only client id is supported for user-assigned managed identity in Machine Learning."; // referenced in unit test
+
public static AbstractManagedIdentity Create(RequestContext requestContext)
{
requestContext.Logger.Info(() => "[Managed Identity] Machine learning managed identity is available.");
@@ -47,15 +51,12 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger
MsalErrorMessage.ManagedIdentityEndpointInvalidUriError,
"MSI_ENDPOINT", msiEndpoint, "Machine learning");
- // Use the factory to create and throw the exception
- var exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
+ throw MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.InvalidManagedIdentityEndpoint,
errorMessage,
ex,
ManagedIdentitySource.MachineLearning,
null); // statusCode is null in this case
-
- throw exception;
}
logger.Info($"[Managed Identity] Environment variables validation passed for machine learning managed identity. Endpoint URI: {endpointUri}. Creating machine learning managed identity.");
@@ -73,21 +74,37 @@ protected override ManagedIdentityRequest CreateRequest(string resource)
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
{
+ case AppConfig.ManagedIdentityIdType.SystemAssigned:
+ _requestContext.Logger.Info("[Managed Identity] Adding system assigned client id to the request.");
+
+ // this environment variable is always set in an Azure Machine Learning source, but check if null just in case
+ if (EnvironmentVariables.MachineLearningDefaultClientId == null)
+ {
+ throw MsalServiceExceptionFactory.CreateManagedIdentityException(
+ MsalError.InvalidManagedIdentityIdType,
+ "The DEFAULT_IDENTITY_CLIENT_ID environment variable is null.",
+ null, // configuration error
+ ManagedIdentitySource.MachineLearning,
+ null); // statusCode is null in this case
+ }
+
+ // Use the new 2017 constant for older ML-based environment
+ request.QueryParameters[Constants.ManagedIdentityClientId2017] = EnvironmentVariables.MachineLearningDefaultClientId;
+ break;
+
case AppConfig.ManagedIdentityIdType.ClientId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned client id to the request.");
// Use the new 2017 constant for older ML-based environment
request.QueryParameters[Constants.ManagedIdentityClientId2017] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;
- case AppConfig.ManagedIdentityIdType.ResourceId:
- _requestContext.Logger.Info("[Managed Identity] Adding user assigned resource id to the request.");
- request.QueryParameters[Constants.ManagedIdentityResourceId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
- break;
-
- case AppConfig.ManagedIdentityIdType.ObjectId:
- _requestContext.Logger.Info("[Managed Identity] Adding user assigned object id to the request.");
- request.QueryParameters[Constants.ManagedIdentityObjectId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
- break;
+ default:
+ throw MsalServiceExceptionFactory.CreateManagedIdentityException(
+ MsalError.InvalidManagedIdentityIdType,
+ UnsupportedIdTypeError,
+ null, // configuration error
+ ManagedIdentitySource.MachineLearning,
+ null); // statusCode is null in this case
}
return request;
diff --git a/src/client/Microsoft.Identity.Client/MsalError.cs b/src/client/Microsoft.Identity.Client/MsalError.cs
index e712c02c77..a13974c616 100644
--- a/src/client/Microsoft.Identity.Client/MsalError.cs
+++ b/src/client/Microsoft.Identity.Client/MsalError.cs
@@ -1105,6 +1105,16 @@ public static class MsalError
///
public const string InvalidManagedIdentityResponse = "invalid_managed_identity_response";
+ ///
+ /// The managed identity's source does not select a specific id type.
+ ///
+ public const string InvalidManagedIdentityIdType = "invalid_managed_identity_id_type";
+
+ ///
+ /// The managed identity is missing a required environment variable.
+ ///
+ public const string MissingManagedIdentityEnvVar = "missing_managed_identity_env_var";
+
///
/// Managed Identity error response was received.
///
diff --git a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs
index f7cac1b493..182b96fa29 100644
--- a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs
+++ b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs
@@ -415,6 +415,7 @@ public static string InvalidTokenProviderResponseValue(string invalidValueName)
public const string ManagedIdentityNoResponseReceived = "[Managed Identity] Authentication unavailable. No response received from the managed identity endpoint.";
public const string ManagedIdentityInvalidResponse = "[Managed Identity] Invalid response, the authentication response received did not contain the expected fields.";
+ public const string ManagedIdentityInvalidIdType = "Only {0} supported for user-assigned managed identity in {1}";
public const string ManagedIdentityJsonParseFailure = "[Managed Identity] MSI returned 200 OK, but the response could not be parsed.";
public const string ManagedIdentityUnexpectedResponse = "[Managed Identity] Unexpected exception occurred when parsing the response. See the inner exception for details.";
public const string ManagedIdentityExactlyOneScopeExpected = "[Managed Identity] To acquire token for managed identity, exactly one scope must be passed.";
diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt
index 8b13789179..50d9f12956 100644
--- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt
+++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt
@@ -1 +1,2 @@
-
+const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
+const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt
index 8b13789179..50d9f12956 100644
--- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt
+++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt
@@ -1 +1,2 @@
-
+const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
+const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt
index e69de29bb2..50d9f12956 100644
--- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt
+++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt
@@ -0,0 +1,2 @@
+const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
+const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt
index e69de29bb2..50d9f12956 100644
--- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt
+++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt
@@ -0,0 +1,2 @@
+const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
+const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt
index 8b13789179..50d9f12956 100644
--- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt
+++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt
@@ -1 +1,2 @@
-
+const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
+const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt
index e69de29bb2..50d9f12956 100644
--- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt
+++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt
@@ -0,0 +1,2 @@
+const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
+const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
diff --git a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs
index c6e0627d0c..1283c66ad5 100644
--- a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs
+++ b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs
@@ -62,6 +62,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
case ManagedIdentitySource.MachineLearning:
Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint);
Environment.SetEnvironmentVariable("MSI_SECRET", secret);
+ Environment.SetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID", "fake_DEFAULT_IDENTITY_CLIENT_ID");
break;
}
}
diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs
index ae196aa310..38bd9239e3 100644
--- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs
+++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs
@@ -359,7 +359,7 @@ public static void AddRegionDiscoveryMockHandler(
});
}
- public static void AddManagedIdentityMockHandler(
+ public static MockHttpMessageHandler AddManagedIdentityMockHandler(
this MockHttpManager httpManager,
string expectedUrl,
string resource,
@@ -383,37 +383,42 @@ public static void AddManagedIdentityMockHandler(
MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(managedIdentitySourceType, resource);
- if (userAssignedIdentityId == UserAssignedIdentityId.ClientId)
+ if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning)
{
- if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning)
- {
- // For Machine Learning (App Service 2017), the param is "clientid"
- httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityClientId2017, userAssignedId);
- }
- else
- {
- // For App Service 2019, Azure Arc, IMDS, etc., the param is "client_id"
- httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityClientId, userAssignedId);
- }
+ // For Machine Learning (App Service 2017), the client id param is "clientid"
+ // it will always be a query parameter, no matter the source type
+ // use env var for SAMI, passed-in userAssignedId for UAMI
+ httpMessageHandler.ExpectedQueryParams.Add(
+ Constants.ManagedIdentityClientId2017,
+ userAssignedId ?? EnvironmentVariables.MachineLearningDefaultClientId);
}
-
- if (userAssignedIdentityId == UserAssignedIdentityId.ResourceId)
+ else if (userAssignedIdentityId == UserAssignedIdentityId.ClientId)
+ {
+ // For App Service 2019, Azure Arc, IMDS, etc., the param is "client_id"
+ httpMessageHandler.ExpectedQueryParams.Add(
+ Constants.ManagedIdentityClientId,
+ userAssignedId);
+ }
+ else if (userAssignedIdentityId == UserAssignedIdentityId.ResourceId)
{
httpMessageHandler.ExpectedQueryParams.Add(
managedIdentitySourceType == ManagedIdentitySource.Imds ?
Constants.ManagedIdentityResourceIdImds : Constants.ManagedIdentityResourceId,
userAssignedId);
}
-
- if (userAssignedIdentityId == UserAssignedIdentityId.ObjectId)
+ else if (userAssignedIdentityId == UserAssignedIdentityId.ObjectId)
{
- httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityObjectId, userAssignedId);
+ httpMessageHandler.ExpectedQueryParams.Add(
+ Constants.ManagedIdentityObjectId,
+ userAssignedId);
}
httpMessageHandler.ResponseMessage = responseMessage;
httpMessageHandler.ExpectedUrl = expectedUrl;
httpManager.AddMockHandler(httpMessageHandler);
+
+ return httpMessageHandler;
}
private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource)
diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs
index a63c469b17..137be702a3 100644
--- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs
+++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs
@@ -3,12 +3,11 @@
using System;
using System.Globalization;
-using System.Net;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.AppConfig;
+using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.ManagedIdentity;
-using Microsoft.Identity.Test.Common;
using Microsoft.Identity.Test.Common.Core.Helpers;
using Microsoft.Identity.Test.Common.Core.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
@@ -20,6 +19,84 @@ namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests
public class MachineLearningTests : TestBase
{
private const string MachineLearning = "Machine learning";
+ private const string MachineLearningEndpoint = "http://localhost:7071/msi/token";
+ internal const string Resource = "https://management.azure.com";
+
+ [DataTestMethod]
+ [DataRow(null, null)] // SAMI
+ [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI
+ public async Task MachineLearningUserAssignedHappyPathAndHasCorrectClientIdQueryParameterAsync(
+ string userAssignedId,
+ UserAssignedIdentityId userAssignedIdentityId)
+ {
+ using (new EnvVariableContext())
+ using (var httpManager = new MockHttpManager())
+ {
+ SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint);
+
+ ManagedIdentityId managedIdentityId = userAssignedId == null
+ ? ManagedIdentityId.SystemAssigned
+ : ManagedIdentityId.WithUserAssignedClientId(userAssignedId);
+ var miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId)
+ .WithHttpManager(httpManager);
+
+ // Disabling shared cache options to avoid cross test pollution.
+ miBuilder.Config.AccessorOptions = null;
+
+ var mi = miBuilder.Build();
+
+ MockHttpMessageHandler mockHandler = httpManager.AddManagedIdentityMockHandler(
+ MachineLearningEndpoint,
+ Resource,
+ MockHelpers.GetMsiSuccessfulResponse(),
+ ManagedIdentitySource.MachineLearning,
+ userAssignedId: userAssignedId,
+ userAssignedIdentityId);
+
+ AuthenticationResult result = await mi.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false);
+
+ Assert.IsNotNull(result);
+ Assert.IsNotNull(result.AccessToken);
+
+ // Verify query parameter is "clientid" and not "client_id"
+ Assert.IsTrue(mockHandler.ExpectedQueryParams.ContainsKey(Constants.ManagedIdentityClientId2017), "Query parameter should use 'clientid' and not 'client_id'");
+
+ // Verify the clientid value based on identity type
+ string expectedClientId = userAssignedId ?? EnvironmentVariables.MachineLearningDefaultClientId;
+ Assert.AreEqual(expectedClientId, mockHandler.ExpectedQueryParams[Constants.ManagedIdentityClientId2017],
+ "Clientid value should match the provided user assigned ID for UAMI or environment variable for SAMI");
+ }
+ }
+
+ [DataTestMethod]
+ [DataRow(TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)]
+ [DataRow(TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)]
+ public async Task MachineLearningUserAssignedNonClientIdThrowsAsync(
+ string userAssignedId,
+ UserAssignedIdentityId userAssignedIdentityId)
+ {
+ using (new EnvVariableContext())
+ using (var httpManager = new MockHttpManager())
+ {
+ SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint);
+
+ var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId)
+ .WithHttpManager(httpManager);
+
+ // Disabling shared cache options to avoid cross test pollution.
+ miBuilder.Config.AccessorOptions = null;
+
+ var mi = miBuilder.Build();
+
+ MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () =>
+ await mi.AcquireTokenForManagedIdentity(Resource)
+ .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false);
+
+ Assert.IsNotNull(ex);
+ Assert.AreEqual(ManagedIdentitySource.MachineLearning.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]);
+ Assert.AreEqual(MsalError.InvalidManagedIdentityIdType, ex.ErrorCode);
+ }
+ }
[TestMethod]
public async Task MachineLearningTestsInvalidEndpointAsync()
diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs
index d8d385e848..d4dba3501e 100644
--- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs
+++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs
@@ -127,8 +127,6 @@ public async Task ManagedIdentityHappyPathAsync(
[DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId .ResourceId)]
[DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)]
[DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.ClientId, UserAssignedIdentityId.ClientId)]
- [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)]
- [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)]
public async Task ManagedIdentityUserAssignedHappyPathAsync(
string endpoint,
ManagedIdentitySource managedIdentitySource,