Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// The purpose of this cache is to allow re-use of Access Tokens fetched for a user interactively or with any other mode
/// to avoid interactive authentication request every-time, within application scope making use of MSAL's userTokenCache.
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap = new();
private static readonly ConcurrentDictionary<TokenCredentialKey, TokenCredentialData> s_tokenCredentialMap = new();
private static SemaphoreSlim s_pcaMapModifierSemaphore = new(1, 1);
private static SemaphoreSlim s_tokenCredentialMapModifierSemaphore = new(1, 1);
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
private readonly SqlClientLogger _logger = new SqlClientLogger();
private readonly SqlClientLogger _logger = new();
private Func<DeviceCodeResult, Task> _deviceCodeFlowCallback;
private ICustomWebUi _customWebUI = null;
private readonly string _applicationClientId = ActiveDirectoryAuthentication.AdoClientId;
Expand Down Expand Up @@ -66,6 +68,11 @@ public static void ClearUserTokenCache()
{
s_pcaMap.Clear();
}

if (!s_tokenCredentialMap.IsEmpty)
{
s_tokenCredentialMap.Clear();
}
}

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetDeviceCodeFlowCallback/*'/>
Expand Down Expand Up @@ -145,50 +152,40 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
**/

int seperatorIndex = parameters.Authority.LastIndexOf('/');
string authority = parameters.Authority.Remove(seperatorIndex + 1);
string audience = parameters.Authority.Substring(seperatorIndex + 1);
int separatorIndex = parameters.Authority.LastIndexOf('/');
string authority = parameters.Authority.Remove(separatorIndex + 1);
string audience = parameters.Authority.Substring(separatorIndex + 1);
string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault)
{
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
{
AuthorityHost = new Uri(authority),
SharedTokenCacheTenantId = audience,
VisualStudioCodeTenantId = audience,
VisualStudioTenantId = audience,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};

// Optionally set clientId when available
if (clientId is not null)
{
defaultAzureCredentialOptions.ManagedIdentityClientId = clientId;
defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId;
}
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
// Cache DefaultAzureCredenial based on scope, authority, audience, and clientId
TokenCredentialKey tokenCredentialKey = new(typeof(DefaultAzureCredential), authority, scope, audience, clientId);
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

TokenCredentialOptions tokenCredentialOptions = new TokenCredentialOptions() { AuthorityHost = new Uri(authority) };
TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(authority) };

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
{
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
// Cache ManagedIdentityCredential based on scope, authority, and clientId
TokenCredentialKey tokenCredentialKey = new(typeof(ManagedIdentityCredential), authority, scope, string.Empty, clientId);
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, string.Empty, tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

AuthenticationResult result = null;
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
{
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
// Cache ClientSecretCredential based on scope, authority, audience, and clientId
TokenCredentialKey tokenCredentialKey = new(typeof(ClientSecretCredential), authority, scope, audience, clientId);
AccessToken accessToken = await GetTokenAsync(tokenCredentialKey, parameters.Password, tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

/*
* Today, MSAL.NET uses another redirect URI by default in desktop applications that run on Windows
* (urn:ietf:wg:oauth:2.0:oob). In the future, we'll want to change this default, so we recommend
Expand All @@ -204,7 +201,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
redirectUri = "http://localhost";
}
#endif
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
PublicClientAppKey pcaKey = new(parameters.Authority, redirectUri, _applicationClientId
#if NETFRAMEWORK
, _iWin32WindowFunc
#endif
Expand All @@ -213,7 +210,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
#endif
);

IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
IPublicClientApplication app = await GetPublicClientAppInstanceAsync(pcaKey, cts.Token).ConfigureAwait(false);

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
Expand Down Expand Up @@ -248,7 +245,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
if (null != previousPw &&
previousPw is byte[] previousPwBytes &&
// Only get the cached token if the current password hash matches the previously used password hash
currPwHash.SequenceEqual(previousPwBytes))
AreEqual(currPwHash, previousPwBytes))
{
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
}
Expand Down Expand Up @@ -353,7 +350,7 @@ private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlo
{
if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive)
{
CancellationTokenSource ctsInteractive = new CancellationTokenSource();
CancellationTokenSource ctsInteractive = new();
#if NETCOREAPP
/*
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
Expand Down Expand Up @@ -447,16 +444,69 @@ public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirec
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
}

private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
private async Task<IPublicClientApplication> GetPublicClientAppInstanceAsync(PublicClientAppKey publicClientAppKey, CancellationToken cancellationToken)
{
if (!s_pcaMap.TryGetValue(publicClientAppKey, out IPublicClientApplication clientApplicationInstance))
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
await s_pcaMapModifierSemaphore.WaitAsync(cancellationToken);
try
{
// Double-check in case another thread added it while we waited for the semaphore
if (!s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance))
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
}
}
finally
{
s_pcaMapModifierSemaphore.Release();
}
}

return clientApplicationInstance;
}

private static async Task<AccessToken> GetTokenAsync(TokenCredentialKey tokenCredentialKey, string secret,
TokenRequestContext tokenRequestContext, CancellationToken cancellationToken)
{
if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out TokenCredentialData tokenCredentialInstance))
{
await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken);
try
{
// Double-check in case another thread added it while we waited for the semaphore
if (!s_tokenCredentialMap.TryGetValue(tokenCredentialKey, out tokenCredentialInstance))
{
tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret);
s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance);
}
}
finally
{
s_tokenCredentialMapModifierSemaphore.Release();
}
}

if (!AreEqual(tokenCredentialInstance._secretHash, GetHash(secret)))
{
// If the secret hash has changed, we need to remove the old token credential instance and create a new one.
await s_tokenCredentialMapModifierSemaphore.WaitAsync(cancellationToken);
try
{
s_tokenCredentialMap.TryRemove(tokenCredentialKey, out _);
tokenCredentialInstance = CreateTokenCredentialInstance(tokenCredentialKey, secret);
s_tokenCredentialMap.TryAdd(tokenCredentialKey, tokenCredentialInstance);
}
finally
{
s_tokenCredentialMapModifierSemaphore.Release();
}
}

return await tokenCredentialInstance._tokenCredential.GetTokenAsync(tokenRequestContext, cancellationToken);
}

private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
{
return parameters.Authority + "+" + parameters.UserId;
Expand All @@ -470,6 +520,24 @@ private static byte[] GetHash(string input)
return hashedBytes;
}

private static bool AreEqual(byte[] a1, byte[] a2)
{
if (ReferenceEquals(a1, a2))
{
return true;
}
else if (a1 is null || a2 is null)
{
return false;
}
else if (a1.Length != a2.Length)
{
return false;
}

return a1.AsSpan().SequenceEqual(a2.AsSpan());
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;
Expand Down Expand Up @@ -513,6 +581,59 @@ private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publ
return publicClientApplication;
}

private static TokenCredentialData CreateTokenCredentialInstance(TokenCredentialKey tokenCredentialKey, string secret)
{
if (tokenCredentialKey._tokenCredentialType == typeof(DefaultAzureCredential))
{
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
{
AuthorityHost = new Uri(tokenCredentialKey._authority),
SharedTokenCacheTenantId = tokenCredentialKey._audience,
VisualStudioCodeTenantId = tokenCredentialKey._audience,
VisualStudioTenantId = tokenCredentialKey._audience,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};

// Optionally set clientId when available
if (tokenCredentialKey._clientId is not null)
{
defaultAzureCredentialOptions.ManagedIdentityClientId = tokenCredentialKey._clientId;
defaultAzureCredentialOptions.SharedTokenCacheUsername = tokenCredentialKey._clientId;
defaultAzureCredentialOptions.WorkloadIdentityClientId = tokenCredentialKey._clientId;
}

return new TokenCredentialData(new DefaultAzureCredential(defaultAzureCredentialOptions), GetHash(secret));
}

TokenCredentialOptions tokenCredentialOptions = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) };

if (tokenCredentialKey._tokenCredentialType == typeof(ManagedIdentityCredential))
{
return new TokenCredentialData(new ManagedIdentityCredential(tokenCredentialKey._clientId, tokenCredentialOptions), GetHash(secret));
}
else if (tokenCredentialKey._tokenCredentialType == typeof(ClientSecretCredential))
{
return new TokenCredentialData(new ClientSecretCredential(tokenCredentialKey._audience, tokenCredentialKey._clientId, secret, tokenCredentialOptions), GetHash(secret));
}
else if (tokenCredentialKey._tokenCredentialType == typeof(WorkloadIdentityCredential))
{
// The WorkloadIdentityCredentialOptions object initialization populates its instance members
// from the environment variables AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE,
// and AZURE_ADDITIONALLY_ALLOWED_TENANTS. AZURE_CLIENT_ID may be overridden by the User Id.
WorkloadIdentityCredentialOptions options = new() { AuthorityHost = new Uri(tokenCredentialKey._authority) };

if (tokenCredentialKey._clientId is not null)
{
options.ClientId = tokenCredentialKey._clientId;
}

return new TokenCredentialData(new WorkloadIdentityCredential(options), GetHash(secret));
}

// This should never be reached, but if it is, throw an exception that will be noticed during development
throw new ArgumentException(nameof(ActiveDirectoryAuthenticationProvider));
}

internal class PublicClientAppKey
{
public readonly string _authority;
Expand Down Expand Up @@ -572,5 +693,52 @@ public override int GetHashCode() => Tuple.Create(_authority, _redirectUri, _app
#endif
).GetHashCode();
}

internal class TokenCredentialData
{
public TokenCredential _tokenCredential;
public byte[] _secretHash;

public TokenCredentialData(TokenCredential tokenCredential, byte[] secretHash)
{
_tokenCredential = tokenCredential;
_secretHash = secretHash;
}
}

internal class TokenCredentialKey
{
public readonly Type _tokenCredentialType;
public readonly string _authority;
public readonly string _scope;
public readonly string _audience;
public readonly string _clientId;

public TokenCredentialKey(Type tokenCredentialType, string authority, string scope, string audience, string clientId)
{
_tokenCredentialType = tokenCredentialType;
_authority = authority;
_scope = scope;
_audience = audience;
_clientId = clientId;
}

public override bool Equals(object obj)
{
if (obj != null && obj is TokenCredentialKey tcKey)
{
return string.CompareOrdinal(nameof(_tokenCredentialType), nameof(tcKey._tokenCredentialType)) == 0
&& string.CompareOrdinal(_authority, tcKey._authority) == 0
&& string.CompareOrdinal(_scope, tcKey._scope) == 0
&& string.CompareOrdinal(_audience, tcKey._audience) == 0
&& string.CompareOrdinal(_clientId, tcKey._clientId) == 0
;
}
return false;
}

public override int GetHashCode() => Tuple.Create(_tokenCredentialType, _authority, _scope, _audience, _clientId).GetHashCode();
}

}
}