Skip to content

Commit 0897cce

Browse files
ImdsV2: Integrated .WithMtlsProofOfPossession (#5490)
1 parent 774e01e commit 0897cce

File tree

7 files changed

+71
-50
lines changed

7 files changed

+71
-50
lines changed

src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter
2020

2121
public string RevokedTokenHash { get; set; }
2222

23+
public bool IsMtlsPopRequested { get; set; }
24+
2325
public void LogParameters(ILoggerAdapter logger)
2426
{
2527
if (logger.IsLoggingEnabled(LogLevel.Info))

src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
9595
logger.Info("[ManagedIdentityRequest] Access token retrieved from cache.");
9696

9797
try
98-
{
98+
{
9999
var proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem);
100100

101101
// If needed, refreshes token in the background
@@ -141,7 +141,7 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
141141
}
142142

143143
private async Task<AuthenticationResult> GetAccessTokenAsync(
144-
CancellationToken cancellationToken,
144+
CancellationToken cancellationToken,
145145
ILoggerAdapter logger)
146146
{
147147
AuthenticationResult authResult;
@@ -161,7 +161,7 @@ private async Task<AuthenticationResult> GetAccessTokenAsync(
161161
// 1) ForceRefresh is requested
162162
// 2) Proactive refresh is in effect
163163
// 3) Claims are present (revocation flow)
164-
if (_managedIdentityParameters.ForceRefresh ||
164+
if (_managedIdentityParameters.ForceRefresh ||
165165
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo == CacheRefreshReason.ProactivelyRefreshed ||
166166
!string.IsNullOrEmpty(_managedIdentityParameters.Claims))
167167
{
@@ -198,6 +198,8 @@ private async Task<AuthenticationResult> SendTokenRequestForManagedIdentityAsync
198198

199199
await ResolveAuthorityAsync().ConfigureAwait(false);
200200

201+
_managedIdentityParameters.IsMtlsPopRequested = AuthenticationRequestParameters.IsMtlsPopRequested;
202+
201203
ManagedIdentityResponse managedIdentityResponse =
202204
await _managedIdentityClient
203205
.SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken)

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ internal abstract class AbstractManagedIdentity
3131

3232
protected readonly RequestContext _requestContext;
3333

34+
protected bool _isMtlsPopRequested;
35+
3436
internal const string TimeoutError = "[Managed Identity] Authentication unavailable. The request to the managed identity endpoint timed out.";
3537
internal readonly ManagedIdentitySource _sourceType;
36-
38+
3739
protected AbstractManagedIdentity(RequestContext requestContext, ManagedIdentitySource sourceType)
3840
{
3941
_requestContext = requestContext;
@@ -55,6 +57,8 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
5557
// Convert the scopes to a resource string.
5658
string resource = parameters.Resource;
5759

60+
_isMtlsPopRequested = parameters.IsMtlsPopRequested;
61+
5862
ManagedIdentityRequest request = await CreateRequestAsync(resource).ConfigureAwait(false);
5963

6064
// Automatically add claims / capabilities if this MI source supports them
@@ -83,7 +87,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
8387
logger: _requestContext.Logger,
8488
doNotThrow: true,
8589
mtlsCertificate: request.MtlsCertificate,
86-
validateServerCertificate: GetValidationCallback(),
90+
validateServerCertificate: GetValidationCallback(),
8791
cancellationToken: cancellationToken,
8892
retryPolicy: retryPolicy).ConfigureAwait(false);
8993
}
@@ -98,7 +102,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
98102
logger: _requestContext.Logger,
99103
doNotThrow: true,
100104
mtlsCertificate: request.MtlsCertificate,
101-
validateServerCertificate: GetValidationCallback(),
105+
validateServerCertificate: GetValidationCallback(),
102106
cancellationToken: cancellationToken,
103107
retryPolicy: retryPolicy)
104108
.ConfigureAwait(false);
@@ -172,8 +176,8 @@ protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response)
172176
throw exception;
173177
}
174178

175-
if (managedIdentityResponse == null ||
176-
managedIdentityResponse.AccessToken.IsNullOrEmpty() ||
179+
if (managedIdentityResponse == null ||
180+
managedIdentityResponse.AccessToken.IsNullOrEmpty() ||
177181
managedIdentityResponse.ExpiresOn.IsNullOrEmpty())
178182
{
179183
_requestContext.Logger.Error("[Managed Identity] Response is either null or insufficient for authentication.");

src/client/Microsoft.Identity.Client/ManagedIdentity/V2/ImdsV2ManagedIdentitySource.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
259259

260260
var keyInfo = await _requestContext.ServiceBundle.PlatformProxy.ManagedIdentityKeyProvider
261261
.GetOrCreateKeyAsync(_requestContext.Logger, _requestContext.UserCancellationToken).ConfigureAwait(false);
262-
262+
263263
var (csr, privateKey) = _requestContext.ServiceBundle.Config.CsrFactory.Generate(keyInfo.Key, csrMetadata.ClientId, csrMetadata.TenantId, csrMetadata.CuId);
264264

265265
var certificateRequestResponse = await ExecuteCertificateRequestAsync(csr).ConfigureAwait(false);
@@ -280,10 +280,12 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
280280
request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue);
281281
request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true");
282282

283+
var tokenType = _isMtlsPopRequested ? "mtls_pop" : "bearer";
284+
283285
request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId);
284286
request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials);
285287
request.BodyParameters.Add("scope", "https://management.azure.com/.default");
286-
request.BodyParameters.Add("token_type", "bearer");
288+
request.BodyParameters.Add("token_type", tokenType);
287289

288290
request.RequestType = RequestType.STS;
289291

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ public static string GetBridgedHybridSpaTokenResponse(string spaAccountId)
122122
",\"id_token_expires_in\":\"3600\"}";
123123
}
124124

125-
public static string GetMsiSuccessfulResponse(int expiresInHours = 1, bool useIsoFormat = false)
125+
public static string GetMsiSuccessfulResponse(
126+
int expiresInHours = 1,
127+
bool useIsoFormat = false,
128+
bool mTLSPop = false)
126129
{
127130
string expiresOn;
128131

@@ -137,9 +140,11 @@ public static string GetMsiSuccessfulResponse(int expiresInHours = 1, bool useIs
137140
expiresOn = DateTimeHelpers.DateTimeToUnixTimestamp(DateTime.UtcNow.AddHours(expiresInHours));
138141
}
139142

143+
var tokenType = mTLSPop ? "mtls_pop" : "Bearer";
144+
140145
return
141-
"{\"access_token\":\"" + TestConstants.ATSecret + "\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"https://management.azure.com/\",\"token_type\":" +
142-
"\"Bearer\",\"client_id\":\"client_id\"}";
146+
"{\"access_token\":\"" + TestConstants.ATSecret + "\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"https://management.azure.com/\"," +
147+
"\"token_type\":\"" + tokenType + "\",\"client_id\":\"client_id\"}";
143148
}
144149

145150
public static string GetMsiErrorBadJson()
@@ -725,7 +730,7 @@ public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse(
725730
PresentRequestHeaders = presentRequestHeaders,
726731
ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK)
727732
{
728-
Content = new StringContent(GetMsiSuccessfulResponse()),
733+
Content = new StringContent(GetMsiSuccessfulResponse(mTLSPop: mTLSPop)),
729734
}
730735
};
731736

tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsV2Tests.cs

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.Identity.Client.ManagedIdentity;
1313
using Microsoft.Identity.Client.ManagedIdentity.KeyProviders;
1414
using Microsoft.Identity.Client.ManagedIdentity.V2;
15+
using Microsoft.Identity.Client.MtlsPop;
1516
using Microsoft.Identity.Client.PlatformsCommon.Shared;
1617
using Microsoft.Identity.Test.Common.Core.Mocks;
1718
using Microsoft.Identity.Test.Unit.Helpers;
@@ -34,7 +35,7 @@ public class ImdsV2Tests : TestBase
3435
enablePiiLogging: false
3536
);
3637
public const string Bearer = "Bearer";
37-
public const string MTLSPoP = "MTLSPoP";
38+
public const string MTLSPoP = "mtls_pop";
3839

3940
private void AddMocksToGetEntraToken(
4041
MockHttpManager httpManager,
@@ -256,26 +257,28 @@ public async Task mTLSPopTokenHappyPath(
256257
{
257258
var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false);
258259

259-
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop
260+
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);
260261

261262
var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
262-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
263+
.WithMtlsProofOfPossession()
263264
.ExecuteAsync().ConfigureAwait(false);
264265

265266
Assert.IsNotNull(result);
266267
Assert.IsNotNull(result.AccessToken);
267-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
268-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
268+
Assert.AreEqual(result.TokenType, MTLSPoP);
269+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
269270
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
270271

271-
result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
272+
// TODO: broken until Gladwin's PR is merged in
273+
/*result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
274+
.WithMtlsProofOfPossession()
272275
.ExecuteAsync().ConfigureAwait(false);
273276
274277
Assert.IsNotNull(result);
275278
Assert.IsNotNull(result.AccessToken);
276-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
277-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
278-
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);
279+
Assert.AreEqual(result.TokenType, MTLSPoP);
280+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
281+
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);*/
279282
}
280283
}
281284

@@ -293,53 +296,55 @@ public async Task mTLSPopTokenTokenIsPerIdentity(
293296
#region Identity 1
294297
var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false);
295298

296-
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop
299+
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);
297300

298301
var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
299-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
302+
.WithMtlsProofOfPossession()
300303
.ExecuteAsync().ConfigureAwait(false);
301304

302305
Assert.IsNotNull(result);
303306
Assert.IsNotNull(result.AccessToken);
304-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
305-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
307+
Assert.AreEqual(result.TokenType, MTLSPoP);
308+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
306309
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
307310

308-
result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
309-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
311+
// TODO: broken until Gladwin's PR is merged in
312+
/*result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
313+
.WithMtlsProofOfPossession()
310314
.ExecuteAsync().ConfigureAwait(false);
311315
312316
Assert.IsNotNull(result);
313317
Assert.IsNotNull(result.AccessToken);
314-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
315-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
316-
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);
318+
Assert.AreEqual(result.TokenType, MTLSPoP);
319+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
320+
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);*/
317321
#endregion Identity 1
318322

319323
#region Identity 2
320324
var managedIdentityApp2 = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId, addProbeMock: false, addSourceCheck: false).ConfigureAwait(false); // source is already cached
321325

322-
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId/*, mTLSPop: true*/); // TODO: implement mTLS Pop
326+
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);
323327

324328
var result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
325-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
329+
.WithMtlsProofOfPossession()
326330
.ExecuteAsync().ConfigureAwait(false);
327331

328332
Assert.IsNotNull(result2);
329333
Assert.IsNotNull(result2.AccessToken);
330-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
331-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
334+
Assert.AreEqual(result.TokenType, MTLSPoP);
335+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
332336
Assert.AreEqual(TokenSource.IdentityProvider, result2.AuthenticationResultMetadata.TokenSource);
333337

334-
result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
335-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
338+
// TODO: broken until Gladwin's PR is merged in
339+
/*result2 = await managedIdentityApp2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
340+
.WithMtlsProofOfPossession()
336341
.ExecuteAsync().ConfigureAwait(false);
337342
338343
Assert.IsNotNull(result2);
339344
Assert.IsNotNull(result2.AccessToken);
340-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
341-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
342-
Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource);
345+
Assert.AreEqual(result.TokenType, MTLSPoP);
346+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
347+
Assert.AreEqual(TokenSource.Cache, result2.AuthenticationResultMetadata.TokenSource);*/
343348
#endregion Identity 2
344349

345350
// TODO: Assert.AreEqual(CertificateCache.Count, 2);
@@ -359,30 +364,30 @@ public async Task mTLSPopTokenIsReAcquiredWhenCertificatIsExpired(
359364
{
360365
var managedIdentityApp = await CreateManagedIdentityAsync(httpManager, userAssignedIdentityId, userAssignedId).ConfigureAwait(false);
361366

362-
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate/*, mTLSPop: true*/); // TODO: implement mTLS Pop
367+
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, TestConstants.ExpiredRawCertificate, mTLSPop: true);
363368

364369
var result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
365-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
370+
.WithMtlsProofOfPossession()
366371
.ExecuteAsync().ConfigureAwait(false);
367372

368373
Assert.IsNotNull(result);
369374
Assert.IsNotNull(result.AccessToken);
370-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
371-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
375+
Assert.AreEqual(result.TokenType, MTLSPoP);
376+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
372377
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
373378

374379
// TODO: Add functionality to check cert expiration in the cache
375380
/**
376-
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, // mTLSPop: true); // TODO: implement mTLS Pop
381+
AddMocksToGetEntraToken(httpManager, userAssignedIdentityId, userAssignedId, mTLSPop: true);
377382
378383
result = await managedIdentityApp.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
379-
// .WithMtlsProofOfPossession() // TODO: implement mTLS Pop
384+
.WithMtlsProofOfPossession()
380385
.ExecuteAsync().ConfigureAwait(false);
381386
382387
Assert.IsNotNull(result);
383388
Assert.IsNotNull(result.AccessToken);
384-
// Assert.AreEqual(result.TokenType, MTLSPoP); // TODO: implement mTLS Pop
385-
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop
389+
Assert.AreEqual(result.TokenType, MTLSPoP);
390+
// Assert.IsNotNull(result.BindingCertificate); // TODO: implement mTLS Pop BindingCertificate
386391
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
387392
388393
Assert.AreEqual(CertificateCache.Count, 1); // expired cert was removed from the cache
@@ -484,7 +489,7 @@ public void TestCsrGeneration_OnlyVmId()
484489
{
485490
VmId = TestConstants.VmId
486491
};
487-
492+
488493
var rsa = InMemoryManagedIdentityKeyProvider.CreateRsaKeyPair();
489494
var (csr, _) = Csr.Generate(rsa, TestConstants.ClientId, TestConstants.TenantId, cuid);
490495
CsrValidator.ValidateCsrContent(csr, TestConstants.ClientId, TestConstants.TenantId, cuid);

0 commit comments

Comments
 (0)