Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
using Azure.Security.KeyVault.Keys.Cryptography;
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;
using System.Threading;
using static Azure.Security.KeyVault.Keys.Cryptography.SignatureAlgorithm;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
{
internal class AzureSqlKeyCryptographer
internal sealed class AzureSqlKeyCryptographer : IDisposable
{
/// <summary>
/// TokenCredential to be used with the KeyClient
Expand All @@ -25,16 +25,14 @@ internal class AzureSqlKeyCryptographer
private readonly ConcurrentDictionary<Uri, KeyClient> _keyClientDictionary = new();

/// <summary>
/// Holds references to the fetch key tasks and maps them to their corresponding Azure Key Vault Key Identifier (URI).
/// These tasks will be used for returning the key in the event that the fetch task has not finished depositing the
/// key into the key dictionary.
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
/// </summary>
private readonly ConcurrentDictionary<string, Task<Azure.Response<KeyVaultKey>>> _keyFetchTaskDictionary = new();
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new();

/// <summary>
/// Holds references to the Azure Key Vault keys and maps them to their corresponding Azure Key Vault Key Identifier (URI).
/// SemaphoreSlim to ensure thread safety when accessing the key dictionary or making network calls to Azure Key Vault to fetch keys.
/// </summary>
private readonly ConcurrentDictionary<string, KeyVaultKey> _keyDictionary = new();
private SemaphoreSlim _keyDictionarySemaphore = new(1, 1);

/// <summary>
/// Holds references to the Azure Key Vault CryptographyClient objects and maps them to their corresponding Azure Key Vault Key Identifier (URI).
Expand All @@ -50,20 +48,44 @@ internal AzureSqlKeyCryptographer(TokenCredential tokenCredential)
TokenCredential = tokenCredential;
}

/// <summary>
/// Disposes the SemaphoreSlim used for thread safety.
/// </summary>
public void Dispose()
{
_keyDictionarySemaphore.Dispose();
}

/// <summary>
/// Adds the key, specified by the Key Identifier URI, to the cache.
/// Validates the key type and fetches the key from Azure Key Vault if it is not already cached.
/// </summary>
/// <param name="keyIdentifierUri"></param>
internal void AddKey(string keyIdentifierUri)
{
if (TheKeyHasNotBeenCached(keyIdentifierUri))
// Allow only one thread to proceed to ensure thread safety
// as we will need to fetch key information from Azure Key Vault if the key is not found in cache.
_keyDictionarySemaphore.Wait();

try
{
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);
CreateKeyClient(vaultUri);
FetchKey(vaultUri, keyName, keyVersion, keyIdentifierUri);
}
if (!_keyDictionary.ContainsKey(keyIdentifierUri))
{
ParseAKVPath(keyIdentifierUri, out Uri vaultUri, out string keyName, out string keyVersion);

// Fetch the KeyClient for the Key vault URI.
KeyClient keyClient = GetOrCreateKeyClient(vaultUri);

// Fetch the key from Azure Key Vault.
KeyVaultKey key = FetchKeyFromKeyVault(keyClient, keyName, keyVersion);

bool TheKeyHasNotBeenCached(string k) => !_keyDictionary.ContainsKey(k) && !_keyFetchTaskDictionary.ContainsKey(k);
_keyDictionary.AddOrUpdate(keyIdentifierUri, key, (k, v) => key);
}
}
finally
{
_keyDictionarySemaphore.Release();
}
}

/// <summary>
Expand All @@ -75,18 +97,12 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
{
if (_keyDictionary.TryGetValue(keyIdentifierUri, out KeyVaultKey key))
{
AKVEventSource.Log.TryTraceEvent("Fetched master key from cache");
AKVEventSource.Log.TryTraceEvent("Fetched key name={0} from cache", key.Name);
return key;
}

if (_keyFetchTaskDictionary.TryGetValue(keyIdentifierUri, out Task<Azure.Response<KeyVaultKey>> task))
{
AKVEventSource.Log.TryTraceEvent("New Master key fetched.");
return Task.Run(() => task).GetAwaiter().GetResult();
}

// Not a public exception - not likely to occur.
AKVEventSource.Log.TryTraceEvent("Master key not found.");
AKVEventSource.Log.TryTraceEvent("Key not found; URI={0}", keyIdentifierUri);
throw ADP.MasterKeyNotFound(keyIdentifierUri);
}

Expand All @@ -95,10 +111,7 @@ internal KeyVaultKey GetKey(string keyIdentifierUri)
/// </summary>
/// <param name="keyIdentifierUri">The key vault key identifier URI</param>
/// <returns></returns>
internal int GetKeySize(string keyIdentifierUri)
{
return GetKey(keyIdentifierUri).Key.N.Length;
}
internal int GetKeySize(string keyIdentifierUri) => GetKey(keyIdentifierUri).Key.N.Length;

/// <summary>
/// Generates signature based on RSA PKCS#v1.5 scheme using a specified Azure Key Vault Key URL.
Expand Down Expand Up @@ -142,49 +155,58 @@ private CryptographyClient GetCryptographyClient(string keyIdentifierUri)

CryptographyClient cryptographyClient = new(GetKey(keyIdentifierUri).Id, TokenCredential);
_cryptoClientDictionary.TryAdd(keyIdentifierUri, cryptographyClient);

return cryptographyClient;
}

/// <summary>
///
/// Fetches the column encryption key from the Azure Key Vault.
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="keyClient">The KeyClient instance</param>
/// <param name="keyName">The name of the Azure Key Vault key</param>
/// <param name="keyVersion">The version of the Azure Key Vault key</param>
/// <param name="keyResourceUri">The Azure Key Vault key identifier</param>
private void FetchKey(Uri vaultUri, string keyName, string keyVersion, string keyResourceUri)
private KeyVaultKey FetchKeyFromKeyVault(KeyClient keyClient, string keyName, string keyVersion)
{
Task<Azure.Response<KeyVaultKey>> fetchKeyTask = FetchKeyFromKeyVault(vaultUri, keyName, keyVersion);
_keyFetchTaskDictionary.AddOrUpdate(keyResourceUri, fetchKeyTask, (k, v) => fetchKeyTask);
AKVEventSource.Log.TryTraceEvent("Fetching key name={0}", keyName);

fetchKeyTask
.ContinueWith(k => ValidateRsaKey(k.GetAwaiter().GetResult()))
.ContinueWith(k => _keyDictionary.AddOrUpdate(keyResourceUri, k.GetAwaiter().GetResult(), (key, v) => k.GetAwaiter().GetResult()));
Azure.Response<KeyVaultKey> keyResponse = keyClient?.GetKey(keyName, keyVersion);

Task.Run(() => fetchKeyTask);
// Handle the case where the key response is null or contains an error
// This can happen if the key does not exist or if there is an issue with the KeyClient.
// In such cases, we log the error and throw an exception.
if (keyResponse == null || keyResponse.Value == null || keyResponse.GetRawResponse().IsError)
{
AKVEventSource.Log.TryTraceEvent("Get Key failed to fetch Key from Azure Key Vault for key {0}, version {1}", keyName, keyVersion);
if (keyResponse?.GetRawResponse() is Azure.Response response)
{
AKVEventSource.Log.TryTraceEvent("Response status {0} : {1}", response.Status, response.ReasonPhrase);
}
throw ADP.GetKeyFailed(keyName);
}

KeyVaultKey key = keyResponse.Value;

// Validate that the key is of type RSA
key = ValidateRsaKey(key);
return key;
}

/// <summary>
/// Looks up the KeyClient object by it's URI and then fetches the key by name.
/// Gets or creates a KeyClient for the specified Azure Key Vault URI.
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="keyName">Then name of the key</param>
/// <param name="keyVersion">Then version of the key</param>
/// <param name="vaultUri">Key Identifier URL</param>
/// <returns></returns>
private Task<Azure.Response<KeyVaultKey>> FetchKeyFromKeyVault(Uri vaultUri, string keyName, string keyVersion)
private KeyClient GetOrCreateKeyClient(Uri vaultUri)
{
_keyClientDictionary.TryGetValue(vaultUri, out KeyClient keyClient);
AKVEventSource.Log.TryTraceEvent("Fetching requested master key: {0}", keyName);
return keyClient?.GetKeyAsync(keyName, keyVersion);
return _keyClientDictionary.GetOrAdd(
vaultUri, (_) => new KeyClient(vaultUri, TokenCredential));
}

/// <summary>
/// Validates that a key is of type RSA
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
private static KeyVaultKey ValidateRsaKey(KeyVaultKey key)
{
if (key.KeyType != KeyType.Rsa && key.KeyType != KeyType.RsaHsm)
{
Expand All @@ -195,26 +217,14 @@ private KeyVaultKey ValidateRsaKey(KeyVaultKey key)
return key;
}

/// <summary>
/// Instantiates and adds a KeyClient to the KeyClient dictionary
/// </summary>
/// <param name="vaultUri">The Azure Key Vault URI</param>
private void CreateKeyClient(Uri vaultUri)
{
if (!_keyClientDictionary.ContainsKey(vaultUri))
{
_keyClientDictionary.TryAdd(vaultUri, new KeyClient(vaultUri, TokenCredential));
}
}

/// <summary>
/// Validates and parses the Azure Key Vault URI and key name.
/// </summary>
/// <param name="masterKeyPath">The Azure Key Vault key identifier</param>
/// <param name="vaultUri">The Azure Key Vault URI</param>
/// <param name="masterKeyName">The name of the key</param>
/// <param name="masterKeyVersion">The version of the key</param>
private void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
private static void ParseAKVPath(string masterKeyPath, out Uri vaultUri, out string masterKeyName, out string masterKeyVersion)
{
Uri masterKeyPathUri = new(masterKeyPath);
vaultUri = new Uri(masterKeyPathUri.GetLeftPart(UriPartial.Authority));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Extensions.Caching.Memory;
using System;
using Microsoft.Extensions.Caching.Memory;
using static System.Math;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand Down Expand Up @@ -92,6 +92,7 @@ internal TValue GetOrCreate(TKey key, Func<TValue> createItem)

/// <summary>
/// Determines whether the <see cref="LocalCache{TKey, TValue}">LocalCache</see> contains the specified key.
/// Used in unit tests to verify that the cache contains the expected entries.
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Text;
using System.Threading;
using Azure.Core;
using Azure.Security.KeyVault.Keys.Cryptography;
using static Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider.Validator;
Expand Down Expand Up @@ -55,6 +56,8 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt

private readonly static KeyWrapAlgorithm s_keyWrapAlgorithm = KeyWrapAlgorithm.RsaOaep;

private SemaphoreSlim _cacheSemaphore = new(1, 1);

/// <summary>
/// List of Trusted Endpoints
///
Expand All @@ -69,7 +72,7 @@ public class SqlColumnEncryptionAzureKeyVaultProvider : SqlColumnEncryptionKeySt
/// <summary>
/// A cache for storing the results of signature verification of column master key metadata.
/// </summary>
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
private readonly LocalCache<Tuple<string, bool, string>, bool> _columnMasterKeyMetadataSignatureVerificationCache =
new(maxSizeLimit: 2000) { TimeToLive = TimeSpan.FromDays(10) };

/// <summary>
Expand Down Expand Up @@ -230,7 +233,7 @@ byte[] DecryptEncryptionKey()
// Get ciphertext
byte[] cipherText = new byte[cipherTextLength];
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);

currentIndex += cipherTextLength;

// Get signature
Expand Down Expand Up @@ -394,17 +397,10 @@ private byte[] CompileMasterKeyMetadata(string masterKeyPath, bool allowEnclaveC
/// <param name="source">An array of bytes to convert.</param>
/// <returns>A string of hexadecimal characters</returns>
/// <remarks>
/// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in value; for example, "0x7F2C4A00".
/// Produces a string of hexadecimal character pairs preceded with "0x", where each pair represents the corresponding element in source; for example, "0x7F2C4A00".
/// </remarks>
private string ToHexString(byte[] source)
{
if (source is null)
{
return null;
}

return "0x" + BitConverter.ToString(source).Replace("-", "");
}
=> source is null ? null : "0x" + BitConverter.ToString(source).Replace("-", "");

/// <summary>
/// Returns the cached decrypted column encryption key, or unwraps the encrypted column encryption key if not present.
Expand All @@ -415,8 +411,21 @@ private string ToHexString(byte[] source)
/// <remarks>
///
/// </remarks>
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
=> _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
private byte[] GetOrCreateColumnEncryptionKey(string encryptedColumnEncryptionKey, Func<byte[]> createItem)
{
// Allow only one thread to access the cache at a time.
_cacheSemaphore.Wait();

try
{
return _columnEncryptionKeyCache.GetOrCreate(encryptedColumnEncryptionKey, createItem);
}
finally
{
// Release the semaphore to allow other threads to access the cache.
_cacheSemaphore.Release();
}
}

/// <summary>
/// Returns the cached signature verification result, or proceeds to verify if not present.
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,16 @@
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
</resheader>
<data name="NullOrWhitespaceForEach" xml:space="preserve">
<value>One or more of the elements in {0} are null or empty or consist of only whitespace.</value>
<value>One or more of the elements in '{0}' are null or empty or consist of only whitespace.</value>
</data>
<data name="CipherTextLengthMismatch" xml:space="preserve">
<value>CipherText length does not match the RSA key size.</value>
</data>
<data name="EmptyArgumentInternal" xml:space="preserve">
<value>Internal error. Empty {0} specified.</value>
<value>Internal error. Empty '{0}' specified.</value>
</data>
<data name="GetKeyFailed" xml:space="preserve">
<value>Failed to fetch key from Azure Key Vault. Key: {0}.</value>
</data>
<data name="MasterKeyNotFound" xml:space="preserve">
<value>The key with identifier '{0}' was not found.</value>
Expand Down
Loading
Loading