Skip to content

Commit ad560a8

Browse files
Back port dotnet#1925
1 parent acfdeca commit ad560a8

File tree

1 file changed

+136
-54
lines changed

1 file changed

+136
-54
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 136 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
using System.Collections.Concurrent;
77
using System.Linq;
88
using System.Security;
9+
using System.Runtime.Caching;
10+
using System.Security.Cryptography;
11+
using System.Text;
912
using System.Threading;
1013
using System.Threading.Tasks;
1114
using Microsoft.Identity.Client;
@@ -23,6 +26,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2326
/// </summary>
2427
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
2528
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
29+
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
30+
private static readonly int s_accountPwCacheTtlInHours = 2;
2631
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
2732
private static readonly string s_defaultScopeSuffix = "/.default";
2833
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
@@ -101,7 +106,9 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
101106
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/AcquireTokenAsync/*'/>
102107
public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters) => Task.Run(async () =>
103108
{
104-
AuthenticationResult result;
109+
CancellationTokenSource cts = new();
110+
111+
AuthenticationResult result = null;
105112
string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
106113
string[] scopes = new string[] { scope };
107114

@@ -147,69 +154,84 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
147154

148155
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
149156
{
150-
if (!string.IsNullOrEmpty(parameters.UserId))
151-
{
152-
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
153-
.WithCorrelationId(parameters.ConnectionId)
154-
.WithUsername(parameters.UserId)
155-
.ExecuteAsync().Result;
156-
}
157-
else
157+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
158+
159+
if (result == null)
158160
{
159-
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
160-
.WithCorrelationId(parameters.ConnectionId)
161-
.ExecuteAsync().Result;
161+
if (!string.IsNullOrEmpty(parameters.UserId))
162+
{
163+
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
164+
.WithCorrelationId(parameters.ConnectionId)
165+
.WithUsername(parameters.UserId)
166+
.ExecuteAsync(cancellationToken: cts.Token).Result;
167+
}
168+
else
169+
{
170+
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
171+
.WithCorrelationId(parameters.ConnectionId)
172+
.ExecuteAsync(cancellationToken: cts.Token).Result;
173+
}
174+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn);
162175
}
163-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn);
164176
}
165177
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
166178
{
167-
SecureString password = new SecureString();
168-
foreach (char c in parameters.Password)
169-
password.AppendChar(c);
170-
password.MakeReadOnly();
171-
result = app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
172-
.WithCorrelationId(parameters.ConnectionId)
173-
.ExecuteAsync().Result;
174-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn);
179+
string pwCacheKey = GetAccountPwCacheKey(parameters);
180+
object previousPw = s_accountPwCache.Get(pwCacheKey);
181+
byte[] currPwHash = GetHash(parameters.Password);
182+
183+
if (null != previousPw &&
184+
previousPw is byte[] previousPwBytes &&
185+
// Only get the cached token if the current password hash matches the previously used password hash
186+
currPwHash.SequenceEqual(previousPwBytes))
187+
{
188+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
189+
}
190+
191+
if (result == null)
192+
{
193+
SecureString password = new SecureString();
194+
foreach (char c in parameters.Password)
195+
password.AppendChar(c);
196+
password.MakeReadOnly();
197+
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
198+
.WithCorrelationId(parameters.ConnectionId)
199+
.ExecuteAsync()
200+
.ConfigureAwait(false);
201+
202+
// We cache the password hash to ensure future connection requests include a validated password
203+
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
204+
// against the same tenant could succeed with an invalid password when we re-use the cached token.
205+
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
206+
{
207+
s_accountPwCache.Remove(pwCacheKey);
208+
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
209+
}
210+
211+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn);
212+
}
175213
}
176214
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
177215
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
178216
{
179-
// Fetch available accounts from 'app' instance
180-
System.Collections.Generic.IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
181-
IAccount account;
182-
if (!string.IsNullOrEmpty(parameters.UserId))
217+
try
183218
{
184-
account = accounts.FirstOrDefault(a => parameters.UserId.Equals(a.Username, System.StringComparison.InvariantCultureIgnoreCase));
219+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
220+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
185221
}
186-
else
222+
catch (MsalUiRequiredException)
187223
{
188-
account = accounts.FirstOrDefault();
224+
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
225+
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
226+
// or the user needs to perform two factor authentication.
227+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback);
228+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
189229
}
190230

191-
if (null != account)
192-
{
193-
try
194-
{
195-
// If 'account' is available in 'app', we use the same to acquire token silently.
196-
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
197-
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
198-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
199-
}
200-
catch (MsalUiRequiredException)
201-
{
202-
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
203-
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
204-
// or the user needs to perform two factor authentication.
205-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
206-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
207-
}
208-
}
209-
else
231+
if (result == null)
210232
{
211233
// If no existing 'account' is found, we request user to sign in interactively.
212-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
234+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback);
213235
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
214236
}
215237
}
@@ -222,11 +244,58 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
222244
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
223245
});
224246

247+
private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app,
248+
SqlAuthenticationParameters parameters,
249+
string[] scopes,
250+
CancellationTokenSource cts)
251+
{
252+
AuthenticationResult result = null;
253+
254+
// Fetch available accounts from 'app' instance
255+
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
256+
257+
IAccount account = default;
258+
if (accounts.MoveNext())
259+
{
260+
if (!string.IsNullOrEmpty(parameters.UserId))
261+
{
262+
do
263+
{
264+
IAccount currentVal = accounts.Current;
265+
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
266+
{
267+
account = currentVal;
268+
break;
269+
}
270+
}
271+
while (accounts.MoveNext());
272+
}
273+
else
274+
{
275+
account = accounts.Current;
276+
}
277+
}
278+
279+
if (null != account)
280+
{
281+
// If 'account' is available in 'app', we use the same to acquire token silently.
282+
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
283+
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
284+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
285+
}
286+
287+
return result;
288+
}
225289

226-
private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
227-
SqlAuthenticationMethod authenticationMethod)
290+
private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app,
291+
string[] scopes,
292+
Guid connectionId,
293+
string userId,
294+
SqlAuthenticationMethod authenticationMethod,
295+
CancellationTokenSource cts,
296+
ICustomWebUi customWebUI,
297+
Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
228298
{
229-
CancellationTokenSource cts = new CancellationTokenSource();
230299
#if NETCOREAPP
231300
/*
232301
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
@@ -243,11 +312,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
243312
{
244313
if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive)
245314
{
246-
if (_customWebUI != null)
315+
if (customWebUI != null)
247316
{
248317
return await app.AcquireTokenInteractive(scopes)
249318
.WithCorrelationId(connectionId)
250-
.WithCustomWebUi(_customWebUI)
319+
.WithCustomWebUi(customWebUI)
251320
.WithLoginHint(userId)
252321
.ExecuteAsync(cts.Token);
253322
}
@@ -279,7 +348,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
279348
else
280349
{
281350
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
282-
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync();
351+
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync();
283352
return result;
284353
}
285354
}
@@ -329,6 +398,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
329398
return clientApplicationInstance;
330399
}
331400

401+
private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
402+
{
403+
return parameters.Authority + "+" + parameters.UserId;
404+
}
405+
406+
private static byte[] GetHash(string input)
407+
{
408+
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
409+
SHA256 sha256 = SHA256.Create();
410+
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
411+
return hashedBytes;
412+
}
413+
332414
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
333415
{
334416
IPublicClientApplication publicClientApplication;

0 commit comments

Comments
 (0)