diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs index aed4821dae..65c84a96f1 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs @@ -14,5 +14,6 @@ internal class EnvironmentVariables public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT"); public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET"); public static string IdentityServerThumbprint => Environment.GetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT"); + public static string MachineLearningDefaultClientId => Environment.GetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID"); } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs index f69f34de7a..02b11bc930 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs @@ -10,12 +10,16 @@ namespace Microsoft.Identity.Client.ManagedIdentity { internal class MachineLearningManagedIdentitySource : AbstractManagedIdentity { + private const string MachineLearning = "Machine Learning"; + private const string MachineLearningMsiApiVersion = "2017-09-01"; private const string SecretHeaderName = "secret"; private readonly Uri _endpoint; private readonly string _secret; + public const string UnsupportedIdTypeError = "Only client id is supported for user-assigned managed identity in Machine Learning."; // referenced in unit test + public static AbstractManagedIdentity Create(RequestContext requestContext) { requestContext.Logger.Info(() => "[Managed Identity] Machine learning managed identity is available."); @@ -47,15 +51,12 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger MsalErrorMessage.ManagedIdentityEndpointInvalidUriError, "MSI_ENDPOINT", msiEndpoint, "Machine learning"); - // Use the factory to create and throw the exception - var exception = MsalServiceExceptionFactory.CreateManagedIdentityException( + throw MsalServiceExceptionFactory.CreateManagedIdentityException( MsalError.InvalidManagedIdentityEndpoint, errorMessage, ex, ManagedIdentitySource.MachineLearning, null); // statusCode is null in this case - - throw exception; } logger.Info($"[Managed Identity] Environment variables validation passed for machine learning managed identity. Endpoint URI: {endpointUri}. Creating machine learning managed identity."); @@ -73,21 +74,37 @@ protected override ManagedIdentityRequest CreateRequest(string resource) switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType) { + case AppConfig.ManagedIdentityIdType.SystemAssigned: + _requestContext.Logger.Info("[Managed Identity] Adding system assigned client id to the request."); + + // this environment variable is always set in an Azure Machine Learning source, but check if null just in case + if (EnvironmentVariables.MachineLearningDefaultClientId == null) + { + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.InvalidManagedIdentityIdType, + "The DEFAULT_IDENTITY_CLIENT_ID environment variable is null.", + null, // configuration error + ManagedIdentitySource.MachineLearning, + null); // statusCode is null in this case + } + + // Use the new 2017 constant for older ML-based environment + request.QueryParameters[Constants.ManagedIdentityClientId2017] = EnvironmentVariables.MachineLearningDefaultClientId; + break; + case AppConfig.ManagedIdentityIdType.ClientId: _requestContext.Logger.Info("[Managed Identity] Adding user assigned client id to the request."); // Use the new 2017 constant for older ML-based environment request.QueryParameters[Constants.ManagedIdentityClientId2017] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId; break; - case AppConfig.ManagedIdentityIdType.ResourceId: - _requestContext.Logger.Info("[Managed Identity] Adding user assigned resource id to the request."); - request.QueryParameters[Constants.ManagedIdentityResourceId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId; - break; - - case AppConfig.ManagedIdentityIdType.ObjectId: - _requestContext.Logger.Info("[Managed Identity] Adding user assigned object id to the request."); - request.QueryParameters[Constants.ManagedIdentityObjectId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId; - break; + default: + throw MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.InvalidManagedIdentityIdType, + UnsupportedIdTypeError, + null, // configuration error + ManagedIdentitySource.MachineLearning, + null); // statusCode is null in this case } return request; diff --git a/src/client/Microsoft.Identity.Client/MsalError.cs b/src/client/Microsoft.Identity.Client/MsalError.cs index e712c02c77..a13974c616 100644 --- a/src/client/Microsoft.Identity.Client/MsalError.cs +++ b/src/client/Microsoft.Identity.Client/MsalError.cs @@ -1105,6 +1105,16 @@ public static class MsalError /// public const string InvalidManagedIdentityResponse = "invalid_managed_identity_response"; + /// + /// The managed identity's source does not select a specific id type. + /// + public const string InvalidManagedIdentityIdType = "invalid_managed_identity_id_type"; + + /// + /// The managed identity is missing a required environment variable. + /// + public const string MissingManagedIdentityEnvVar = "missing_managed_identity_env_var"; + /// /// Managed Identity error response was received. /// diff --git a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs index f7cac1b493..182b96fa29 100644 --- a/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs +++ b/src/client/Microsoft.Identity.Client/MsalErrorMessage.cs @@ -415,6 +415,7 @@ public static string InvalidTokenProviderResponseValue(string invalidValueName) public const string ManagedIdentityNoResponseReceived = "[Managed Identity] Authentication unavailable. No response received from the managed identity endpoint."; public const string ManagedIdentityInvalidResponse = "[Managed Identity] Invalid response, the authentication response received did not contain the expected fields."; + public const string ManagedIdentityInvalidIdType = "Only {0} supported for user-assigned managed identity in {1}"; public const string ManagedIdentityJsonParseFailure = "[Managed Identity] MSI returned 200 OK, but the response could not be parsed."; public const string ManagedIdentityUnexpectedResponse = "[Managed Identity] Unexpected exception occurred when parsing the response. See the inner exception for details."; public const string ManagedIdentityExactlyOneScopeExpected = "[Managed Identity] To acquire token for managed identity, exactly one scope must be passed."; diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt index 8b13789179..50d9f12956 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - +const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string +const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt index 8b13789179..50d9f12956 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - +const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string +const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt index e69de29bb2..50d9f12956 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string +const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt index e69de29bb2..50d9f12956 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string +const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt index 8b13789179..50d9f12956 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ - +const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string +const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt index e69de29bb2..50d9f12956 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string +const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string diff --git a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs index c6e0627d0c..1283c66ad5 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs @@ -62,6 +62,7 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity case ManagedIdentitySource.MachineLearning: Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint); Environment.SetEnvironmentVariable("MSI_SECRET", secret); + Environment.SetEnvironmentVariable("DEFAULT_IDENTITY_CLIENT_ID", "fake_DEFAULT_IDENTITY_CLIENT_ID"); break; } } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index ae196aa310..38bd9239e3 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -359,7 +359,7 @@ public static void AddRegionDiscoveryMockHandler( }); } - public static void AddManagedIdentityMockHandler( + public static MockHttpMessageHandler AddManagedIdentityMockHandler( this MockHttpManager httpManager, string expectedUrl, string resource, @@ -383,37 +383,42 @@ public static void AddManagedIdentityMockHandler( MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(managedIdentitySourceType, resource); - if (userAssignedIdentityId == UserAssignedIdentityId.ClientId) + if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning) { - if (managedIdentitySourceType == ManagedIdentitySource.MachineLearning) - { - // For Machine Learning (App Service 2017), the param is "clientid" - httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityClientId2017, userAssignedId); - } - else - { - // For App Service 2019, Azure Arc, IMDS, etc., the param is "client_id" - httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityClientId, userAssignedId); - } + // For Machine Learning (App Service 2017), the client id param is "clientid" + // it will always be a query parameter, no matter the source type + // use env var for SAMI, passed-in userAssignedId for UAMI + httpMessageHandler.ExpectedQueryParams.Add( + Constants.ManagedIdentityClientId2017, + userAssignedId ?? EnvironmentVariables.MachineLearningDefaultClientId); } - - if (userAssignedIdentityId == UserAssignedIdentityId.ResourceId) + else if (userAssignedIdentityId == UserAssignedIdentityId.ClientId) + { + // For App Service 2019, Azure Arc, IMDS, etc., the param is "client_id" + httpMessageHandler.ExpectedQueryParams.Add( + Constants.ManagedIdentityClientId, + userAssignedId); + } + else if (userAssignedIdentityId == UserAssignedIdentityId.ResourceId) { httpMessageHandler.ExpectedQueryParams.Add( managedIdentitySourceType == ManagedIdentitySource.Imds ? Constants.ManagedIdentityResourceIdImds : Constants.ManagedIdentityResourceId, userAssignedId); } - - if (userAssignedIdentityId == UserAssignedIdentityId.ObjectId) + else if (userAssignedIdentityId == UserAssignedIdentityId.ObjectId) { - httpMessageHandler.ExpectedQueryParams.Add(Constants.ManagedIdentityObjectId, userAssignedId); + httpMessageHandler.ExpectedQueryParams.Add( + Constants.ManagedIdentityObjectId, + userAssignedId); } httpMessageHandler.ResponseMessage = responseMessage; httpMessageHandler.ExpectedUrl = expectedUrl; httpManager.AddMockHandler(httpMessageHandler); + + return httpMessageHandler; } private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs index a63c469b17..137be702a3 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs @@ -3,12 +3,11 @@ using System; using System.Globalization; -using System.Net; using System.Threading.Tasks; using Microsoft.Identity.Client; using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; using Microsoft.Identity.Client.ManagedIdentity; -using Microsoft.Identity.Test.Common; using Microsoft.Identity.Test.Common.Core.Helpers; using Microsoft.Identity.Test.Common.Core.Mocks; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -20,6 +19,84 @@ namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests public class MachineLearningTests : TestBase { private const string MachineLearning = "Machine learning"; + private const string MachineLearningEndpoint = "http://localhost:7071/msi/token"; + internal const string Resource = "https://management.azure.com"; + + [DataTestMethod] + [DataRow(null, null)] // SAMI + [DataRow(TestConstants.ClientId, UserAssignedIdentityId.ClientId)] // UAMI + public async Task MachineLearningUserAssignedHappyPathAndHasCorrectClientIdQueryParameterAsync( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint); + + ManagedIdentityId managedIdentityId = userAssignedId == null + ? ManagedIdentityId.SystemAssigned + : ManagedIdentityId.WithUserAssignedClientId(userAssignedId); + var miBuilder = ManagedIdentityApplicationBuilder.Create(managedIdentityId) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + MockHttpMessageHandler mockHandler = httpManager.AddManagedIdentityMockHandler( + MachineLearningEndpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.MachineLearning, + userAssignedId: userAssignedId, + userAssignedIdentityId); + + AuthenticationResult result = await mi.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + + // Verify query parameter is "clientid" and not "client_id" + Assert.IsTrue(mockHandler.ExpectedQueryParams.ContainsKey(Constants.ManagedIdentityClientId2017), "Query parameter should use 'clientid' and not 'client_id'"); + + // Verify the clientid value based on identity type + string expectedClientId = userAssignedId ?? EnvironmentVariables.MachineLearningDefaultClientId; + Assert.AreEqual(expectedClientId, mockHandler.ExpectedQueryParams[Constants.ManagedIdentityClientId2017], + "Clientid value should match the provided user assigned ID for UAMI or environment variable for SAMI"); + } + } + + [DataTestMethod] + [DataRow(TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)] + [DataRow(TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] + public async Task MachineLearningUserAssignedNonClientIdThrowsAsync( + string userAssignedId, + UserAssignedIdentityId userAssignedIdentityId) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint); + + var miBuilder = CreateMIABuilder(userAssignedId, userAssignedIdentityId) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(Resource) + .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false); + + Assert.IsNotNull(ex); + Assert.AreEqual(ManagedIdentitySource.MachineLearning.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]); + Assert.AreEqual(MsalError.InvalidManagedIdentityIdType, ex.ErrorCode); + } + } [TestMethod] public async Task MachineLearningTestsInvalidEndpointAsync() diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index d8d385e848..d4dba3501e 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -127,8 +127,6 @@ public async Task ManagedIdentityHappyPathAsync( [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId .ResourceId)] [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.ClientId, UserAssignedIdentityId.ClientId)] - [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)] - [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] public async Task ManagedIdentityUserAssignedHappyPathAsync( string endpoint, ManagedIdentitySource managedIdentitySource,