Skip to content

Commit 2e3f39e

Browse files
Fix access token behavior in connection pool (#443)
* Initial test changes * Fix access token behavior in connection pool * Compare ordinals for strings * Access token only
1 parent a9dfe48 commit 2e3f39e

File tree

6 files changed

+171
-72
lines changed

6 files changed

+171
-72
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ internal string AccessToken
6363
public override bool Equals(object obj)
6464
{
6565
SqlConnectionPoolKey key = obj as SqlConnectionPoolKey;
66-
return (key != null && _credential == key._credential && ConnectionString == key.ConnectionString && Object.ReferenceEquals(_accessToken, key._accessToken));
66+
return (key != null
67+
&& _credential == key._credential
68+
&& ConnectionString == key.ConnectionString
69+
&& string.CompareOrdinal(_accessToken, key._accessToken) == 0);
6770
}
6871

6972
public override int GetHashCode()

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnectionPoolKey.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ public override bool Equals(object obj)
110110
return (key != null &&
111111
_credential == key._credential &&
112112
ConnectionString == key.ConnectionString &&
113-
Object.ReferenceEquals(_accessToken, key._accessToken) &&
113+
string.CompareOrdinal(_accessToken, key._accessToken) == 0 &&
114114
_serverCertificateValidationCallback == key._serverCertificateValidationCallback &&
115115
_clientCertificateRetrievalCallback == key._clientCertificateRetrievalCallback &&
116116
_originalNetworkAddressInfo == key._originalNetworkAddressInfo);

src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
using System.IO;
1212
using System.Linq;
1313
using System.Reflection;
14+
using System.Security;
1415
using System.Text;
1516
using System.Threading;
1617
using System.Threading.Tasks;
18+
using Microsoft.Identity.Client;
1719
using Newtonsoft.Json;
1820
using Xunit;
1921

@@ -26,8 +28,9 @@ public static class DataTestUtility
2628
public static readonly string TCPConnectionStringHGSVBS = null;
2729
public static readonly string TCPConnectionStringAASVBS = null;
2830
public static readonly string TCPConnectionStringAASSGX = null;
29-
public static readonly string AADAccessToken = null;
31+
public static readonly string AADAuthorityURL = null;
3032
public static readonly string AADPasswordConnectionString = null;
33+
public static readonly string AADAccessToken = null;
3134
public static readonly string AKVBaseUrl = null;
3235
public static readonly string AKVUrl = null;
3336
public static readonly string AKVClientId = null;
@@ -60,7 +63,7 @@ private class Config
6063
public string TCPConnectionStringHGSVBS = null;
6164
public string TCPConnectionStringAASVBS = null;
6265
public string TCPConnectionStringAASSGX = null;
63-
public string AADAccessToken = null;
66+
public string AADAuthorityURL = null;
6467
public string AADPasswordConnectionString = null;
6568
public string AzureKeyVaultURL = null;
6669
public string AzureKeyVaultClientId = null;
@@ -83,13 +86,20 @@ static DataTestUtility()
8386
TCPConnectionStringHGSVBS = c.TCPConnectionStringHGSVBS;
8487
TCPConnectionStringAASVBS = c.TCPConnectionStringAASVBS;
8588
TCPConnectionStringAASSGX = c.TCPConnectionStringAASSGX;
86-
AADAccessToken = c.AADAccessToken;
89+
AADAuthorityURL = c.AADAuthorityURL;
8790
AADPasswordConnectionString = c.AADPasswordConnectionString;
8891
SupportsLocalDb = c.SupportsLocalDb;
8992
SupportsIntegratedSecurity = c.SupportsIntegratedSecurity;
9093
SupportsFileStream = c.SupportsFileStream;
9194
EnclaveEnabled = c.EnclaveEnabled;
9295

96+
if (IsAADPasswordConnStrSetup() && IsAADAuthorityURLSetup())
97+
{
98+
string username = RetrieveValueFromConnStr(AADPasswordConnectionString, new string[] { "User ID", "UID" });
99+
string password = RetrieveValueFromConnStr(AADPasswordConnectionString, new string[] { "Password", "PWD" });
100+
AADAccessToken = GenerateAccessToken(AADAuthorityURL, username, password);
101+
}
102+
93103
string url = c.AzureKeyVaultURL;
94104
Uri AKVBaseUri = null;
95105
if (!string.IsNullOrEmpty(url) && Uri.TryCreate(url, UriKind.Absolute, out AKVBaseUri))
@@ -134,6 +144,41 @@ static DataTestUtility()
134144
}
135145
}
136146

147+
private static string GenerateAccessToken(string authorityURL, string aADAuthUserID, string aADAuthPassword)
148+
{
149+
return AcquireTokenAsync(authorityURL, aADAuthUserID, aADAuthPassword).Result;
150+
}
151+
152+
private static Task<string> AcquireTokenAsync(string authorityURL, string userID, string password) => Task.Run(() =>
153+
{
154+
// The below properties are set specific to test configurations.
155+
string scope = "https://database.windows.net//.default";
156+
string applicationName = "Microsoft Data SqlClient Manual Tests";
157+
string clientVersion = "1.0.0.0";
158+
string adoClientId = "4d079b4c-cab7-4b7c-a115-8fd51b6f8239";
159+
160+
IPublicClientApplication app = PublicClientApplicationBuilder.Create(adoClientId)
161+
.WithAuthority(authorityURL)
162+
.WithClientName(applicationName)
163+
.WithClientVersion(clientVersion)
164+
.Build();
165+
AuthenticationResult result;
166+
string[] scopes = new string[] { scope };
167+
168+
// Note: CorrelationId, which existed in ADAL, can not be set in MSAL (yet?).
169+
// parameter.ConnectionId was passed as the CorrelationId in ADAL to aid support in troubleshooting.
170+
// If/When MSAL adds CorrelationId support, it should be passed from parameters here, too.
171+
172+
SecureString securePassword = new SecureString();
173+
174+
foreach (char c in password)
175+
securePassword.AppendChar(c);
176+
securePassword.MakeReadOnly();
177+
result = app.AcquireTokenByUsernamePassword(scopes, userID, securePassword).ExecuteAsync().Result;
178+
179+
return result.AccessToken;
180+
});
181+
137182
public static bool IsDatabasePresent(string name)
138183
{
139184
AvailableDatabases = AvailableDatabases ?? new Dictionary<string, bool>();
@@ -171,6 +216,11 @@ public static bool IsAADPasswordConnStrSetup()
171216
return !string.IsNullOrEmpty(AADPasswordConnectionString);
172217
}
173218

219+
public static bool IsAADAuthorityURLSetup()
220+
{
221+
return !string.IsNullOrEmpty(AADAuthorityURL);
222+
}
223+
174224
public static bool IsNotAzureServer()
175225
{
176226
return AreConnStringsSetup() ? !DataTestUtility.IsAzureSqlServer(new SqlConnectionStringBuilder((DataTestUtility.TCPConnectionString)).DataSource) : true;
@@ -248,10 +298,11 @@ public static string GetUniqueNameForSqlServer(string prefix)
248298

249299
public static string GetAccessToken()
250300
{
251-
return AADAccessToken;
301+
// Creates a new Object Reference of Access Token - See GitHub Issue 438
302+
return (null != AADAccessToken) ? new string(AADAccessToken.ToCharArray()) : null;
252303
}
253304

254-
public static bool IsAccessTokenSetup() => string.IsNullOrEmpty(GetAccessToken()) ? false : true;
305+
public static bool IsAccessTokenSetup() => !string.IsNullOrEmpty(GetAccessToken());
255306

256307
public static bool IsFileStreamSetup() => SupportsFileStream;
257308

@@ -519,5 +570,54 @@ public static string GetValueString(object paramValue)
519570

520571
return paramValue.ToString();
521572
}
573+
574+
public static string RemoveKeysInConnStr(string connStr, string[] keysToRemove)
575+
{
576+
// tokenize connection string and remove input keys.
577+
string res = "";
578+
string[] keys = connStr.Split(';');
579+
foreach (var key in keys)
580+
{
581+
if (!string.IsNullOrEmpty(key.Trim()))
582+
{
583+
bool removeKey = false;
584+
foreach (var keyToRemove in keysToRemove)
585+
{
586+
if (key.Trim().ToLower().StartsWith(keyToRemove.Trim().ToLower()))
587+
{
588+
removeKey = true;
589+
break;
590+
}
591+
}
592+
if (!removeKey)
593+
{
594+
res += key + ";";
595+
}
596+
}
597+
}
598+
return res;
599+
}
600+
601+
public static string RetrieveValueFromConnStr(string connStr, string[] keywords)
602+
{
603+
// tokenize connection string and retrieve value for a specific key.
604+
string res = "";
605+
string[] keys = connStr.Split(';');
606+
foreach (var key in keys)
607+
{
608+
foreach (var keyword in keywords)
609+
{
610+
if (!string.IsNullOrEmpty(key.Trim()))
611+
{
612+
if (key.Trim().ToLower().StartsWith(keyword.Trim().ToLower()))
613+
{
614+
res = key.Substring(key.IndexOf('=') + 1).Trim();
615+
break;
616+
}
617+
}
618+
}
619+
}
620+
return res;
621+
}
522622
}
523623
}

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionPoolTest/ConnectionPoolTest.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,46 @@ private static void BasicConnectionPoolingTest(string connectionString)
6464
connection3.Close();
6565

6666
connectionPool.Cleanup();
67+
6768
SqlConnection connection4 = new SqlConnection(connectionString);
69+
connection4.Open();
70+
Assert.True(internalConnection.IsInternalConnectionOf(connection4), "New connection does not use same internal connection");
71+
Assert.True(connectionPool.ContainsConnection(connection4), "New connection is in a different pool");
72+
connection4.Close();
73+
}
74+
75+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup), nameof(DataTestUtility.IsAADAuthorityURLSetup))]
76+
public static void AccessTokenConnectionPoolingTest()
77+
{
78+
// Remove cred info and add invalid token
79+
string[] credKeys = { "User ID", "Password", "UID", "PWD", "Authentication" };
80+
string connectionString = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys);
6881

82+
SqlConnection connection = new SqlConnection(connectionString);
83+
connection.AccessToken = DataTestUtility.GetAccessToken();
84+
connection.Open();
85+
InternalConnectionWrapper internalConnection = new InternalConnectionWrapper(connection);
86+
ConnectionPoolWrapper connectionPool = new ConnectionPoolWrapper(connection);
87+
connection.Close();
88+
89+
SqlConnection connection2 = new SqlConnection(connectionString);
90+
connection2.AccessToken = DataTestUtility.GetAccessToken();
91+
connection2.Open();
92+
Assert.True(internalConnection.IsInternalConnectionOf(connection2), "New connection does not use same internal connection");
93+
Assert.True(connectionPool.ContainsConnection(connection2), "New connection is in a different pool");
94+
connection2.Close();
95+
96+
SqlConnection connection3 = new SqlConnection(connectionString + ";App=SqlConnectionPoolUnitTest;");
97+
connection3.AccessToken = DataTestUtility.GetAccessToken();
98+
connection3.Open();
99+
Assert.False(internalConnection.IsInternalConnectionOf(connection3), "Connection with different connection string uses same internal connection");
100+
Assert.False(connectionPool.ContainsConnection(connection3), "Connection with different connection string uses same connection pool");
101+
connection3.Close();
102+
103+
connectionPool.Cleanup();
104+
105+
SqlConnection connection4 = new SqlConnection(connectionString);
106+
connection4.AccessToken = DataTestUtility.GetAccessToken();
69107
connection4.Open();
70108
Assert.True(internalConnection.IsInternalConnectionOf(connection4), "New connection does not use same internal connection");
71109
Assert.True(connectionPool.ContainsConnection(connection4), "New connection is in a different pool");

0 commit comments

Comments
 (0)