66using System . Collections . Concurrent ;
77using System . Linq ;
88using System . Security ;
9+ using System . Runtime . Caching ;
10+ using System . Security . Cryptography ;
11+ using System . Text ;
912using System . Threading ;
1013using System . Threading . Tasks ;
1114using Microsoft . Identity . Client ;
@@ -23,6 +26,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2326 /// </summary>
2427 private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
2528 = new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
29+ private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
30+ private static readonly int s_accountPwCacheTtlInHours = 2 ;
2631 private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
2732 private static readonly string s_defaultScopeSuffix = "/.default" ;
2833 private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
@@ -101,7 +106,9 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
101106 /// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/AcquireTokenAsync/*'/>
102107 public override Task < SqlAuthenticationToken > AcquireTokenAsync ( SqlAuthenticationParameters parameters ) => Task . Run ( async ( ) =>
103108 {
104- AuthenticationResult result ;
109+ CancellationTokenSource cts = new ( ) ;
110+
111+ AuthenticationResult result = null ;
105112 string scope = parameters . Resource . EndsWith ( s_defaultScopeSuffix ) ? parameters . Resource : parameters . Resource + s_defaultScopeSuffix ;
106113 string [ ] scopes = new string [ ] { scope } ;
107114
@@ -147,69 +154,84 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
147154
148155 if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
149156 {
150- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
151- {
152- result = app . AcquireTokenByIntegratedWindowsAuth ( scopes )
153- . WithCorrelationId ( parameters . ConnectionId )
154- . WithUsername ( parameters . UserId )
155- . ExecuteAsync ( ) . Result ;
156- }
157- else
157+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
158+
159+ if ( result == null )
158160 {
159- result = app . AcquireTokenByIntegratedWindowsAuth ( scopes )
160- . WithCorrelationId ( parameters . ConnectionId )
161- . ExecuteAsync ( ) . Result ;
161+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
162+ {
163+ result = app . AcquireTokenByIntegratedWindowsAuth ( scopes )
164+ . WithCorrelationId ( parameters . ConnectionId )
165+ . WithUsername ( parameters . UserId )
166+ . ExecuteAsync ( cancellationToken : cts . Token ) . Result ;
167+ }
168+ else
169+ {
170+ result = app . AcquireTokenByIntegratedWindowsAuth ( scopes )
171+ . WithCorrelationId ( parameters . ConnectionId )
172+ . ExecuteAsync ( cancellationToken : cts . Token ) . Result ;
173+ }
174+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result . ExpiresOn ) ;
162175 }
163- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result . ExpiresOn ) ;
164176 }
165177 else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
166178 {
167- SecureString password = new SecureString ( ) ;
168- foreach ( char c in parameters . Password )
169- password . AppendChar ( c ) ;
170- password . MakeReadOnly ( ) ;
171- result = app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
172- . WithCorrelationId ( parameters . ConnectionId )
173- . ExecuteAsync ( ) . Result ;
174- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result . ExpiresOn ) ;
179+ string pwCacheKey = GetAccountPwCacheKey ( parameters ) ;
180+ object previousPw = s_accountPwCache . Get ( pwCacheKey ) ;
181+ byte [ ] currPwHash = GetHash ( parameters . Password ) ;
182+
183+ if ( null != previousPw &&
184+ previousPw is byte [ ] previousPwBytes &&
185+ // Only get the cached token if the current password hash matches the previously used password hash
186+ currPwHash . SequenceEqual ( previousPwBytes ) )
187+ {
188+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
189+ }
190+
191+ if ( result == null )
192+ {
193+ SecureString password = new SecureString ( ) ;
194+ foreach ( char c in parameters . Password )
195+ password . AppendChar ( c ) ;
196+ password . MakeReadOnly ( ) ;
197+ result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
198+ . WithCorrelationId ( parameters . ConnectionId )
199+ . ExecuteAsync ( )
200+ . ConfigureAwait ( false ) ;
201+
202+ // We cache the password hash to ensure future connection requests include a validated password
203+ // when we check for a cached MSAL account. Otherwise, a connection request with the same username
204+ // against the same tenant could succeed with an invalid password when we re-use the cached token.
205+ if ( ! s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) )
206+ {
207+ s_accountPwCache . Remove ( pwCacheKey ) ;
208+ s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) ;
209+ }
210+
211+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result . ExpiresOn ) ;
212+ }
175213 }
176214 else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
177215 parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
178216 {
179- // Fetch available accounts from 'app' instance
180- System . Collections . Generic . IEnumerable < IAccount > accounts = await app . GetAccountsAsync ( ) ;
181- IAccount account ;
182- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
217+ try
183218 {
184- account = accounts . FirstOrDefault ( a => parameters . UserId . Equals ( a . Username , System . StringComparison . InvariantCultureIgnoreCase ) ) ;
219+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
220+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result . ExpiresOn ) ;
185221 }
186- else
222+ catch ( MsalUiRequiredException )
187223 {
188- account = accounts . FirstOrDefault ( ) ;
224+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
225+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
226+ // or the user needs to perform two factor authentication.
227+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) ;
228+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result . ExpiresOn ) ;
189229 }
190230
191- if ( null != account )
192- {
193- try
194- {
195- // If 'account' is available in 'app', we use the same to acquire token silently.
196- // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
197- result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( ) ;
198- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result . ExpiresOn ) ;
199- }
200- catch ( MsalUiRequiredException )
201- {
202- // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
203- // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
204- // or the user needs to perform two factor authentication.
205- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod ) ;
206- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result . ExpiresOn ) ;
207- }
208- }
209- else
231+ if ( result == null )
210232 {
211233 // If no existing 'account' is found, we request user to sign in interactively.
212- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod ) ;
234+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) ;
213235 SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result . ExpiresOn ) ;
214236 }
215237 }
@@ -222,11 +244,58 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
222244 return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
223245 } ) ;
224246
247+ private static async Task < AuthenticationResult > TryAcquireTokenSilent ( IPublicClientApplication app ,
248+ SqlAuthenticationParameters parameters ,
249+ string [ ] scopes ,
250+ CancellationTokenSource cts )
251+ {
252+ AuthenticationResult result = null ;
253+
254+ // Fetch available accounts from 'app' instance
255+ System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
256+
257+ IAccount account = default ;
258+ if ( accounts . MoveNext ( ) )
259+ {
260+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
261+ {
262+ do
263+ {
264+ IAccount currentVal = accounts . Current ;
265+ if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
266+ {
267+ account = currentVal ;
268+ break ;
269+ }
270+ }
271+ while ( accounts . MoveNext ( ) ) ;
272+ }
273+ else
274+ {
275+ account = accounts . Current ;
276+ }
277+ }
278+
279+ if ( null != account )
280+ {
281+ // If 'account' is available in 'app', we use the same to acquire token silently.
282+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
283+ result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
284+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
285+ }
286+
287+ return result ;
288+ }
225289
226- private async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
227- SqlAuthenticationMethod authenticationMethod )
290+ private static async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app ,
291+ string [ ] scopes ,
292+ Guid connectionId ,
293+ string userId ,
294+ SqlAuthenticationMethod authenticationMethod ,
295+ CancellationTokenSource cts ,
296+ ICustomWebUi customWebUI ,
297+ Func < DeviceCodeResult , Task > deviceCodeFlowCallback )
228298 {
229- CancellationTokenSource cts = new CancellationTokenSource ( ) ;
230299#if NETCOREAPP
231300 /*
232301 * On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
@@ -243,11 +312,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
243312 {
244313 if ( authenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive )
245314 {
246- if ( _customWebUI != null )
315+ if ( customWebUI != null )
247316 {
248317 return await app . AcquireTokenInteractive ( scopes )
249318 . WithCorrelationId ( connectionId )
250- . WithCustomWebUi ( _customWebUI )
319+ . WithCustomWebUi ( customWebUI )
251320 . WithLoginHint ( userId )
252321 . ExecuteAsync ( cts . Token ) ;
253322 }
@@ -279,7 +348,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
279348 else
280349 {
281350 AuthenticationResult result = await app . AcquireTokenWithDeviceCode ( scopes ,
282- deviceCodeResult => _deviceCodeFlowCallback ( deviceCodeResult ) ) . ExecuteAsync ( ) ;
351+ deviceCodeResult => deviceCodeFlowCallback ( deviceCodeResult ) ) . ExecuteAsync ( ) ;
283352 return result ;
284353 }
285354 }
@@ -329,6 +398,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
329398 return clientApplicationInstance ;
330399 }
331400
401+ private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
402+ {
403+ return parameters . Authority + "+" + parameters . UserId ;
404+ }
405+
406+ private static byte [ ] GetHash ( string input )
407+ {
408+ byte [ ] unhashedBytes = Encoding . Unicode . GetBytes ( input ) ;
409+ SHA256 sha256 = SHA256 . Create ( ) ;
410+ byte [ ] hashedBytes = sha256 . ComputeHash ( unhashedBytes ) ;
411+ return hashedBytes ;
412+ }
413+
332414 private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
333415 {
334416 IPublicClientApplication publicClientApplication ;
0 commit comments