Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 230 additions & 43 deletions src/libraries/Common/src/System/Security/Cryptography/MLDsa.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,25 @@ internal static MLDsaImplementation DuplicatePrivateKey(MLDsa key)
Debug.Assert(key is not MLDsaImplementation);

MLDsaAlgorithm alg = key.Algorithm;
Debug.Assert(alg.SecretKeySizeInBytes > alg.PrivateSeedSizeInBytes);
byte[] rented = CryptoPool.Rent(alg.SecretKeySizeInBytes);
int written = 0;

try
{
written = key.ExportMLDsaPrivateSeed(rented);
return ImportSeed(alg, new ReadOnlySpan<byte>(rented, 0, written));
Span<byte> seedSpan = rented.AsSpan(0, alg.PrivateSeedSizeInBytes);
key.ExportMLDsaPrivateSeed(seedSpan);
return ImportSeed(alg, seedSpan);
}
catch (CryptographicException)
{
written = key.ExportMLDsaSecretKey(rented);
return ImportSecretKey(alg, new ReadOnlySpan<byte>(rented, 0, written));
// Rented array may still be larger but we expect exact length
Span<byte> skSpan = rented.AsSpan(0, alg.SecretKeySizeInBytes);
key.ExportMLDsaSecretKey(skSpan);
return ImportSecretKey(alg, skSpan);
}
finally
{
CryptoPool.Return(rented, written);
CryptoPool.Return(rented);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal static SlhDsaImplementation DuplicatePrivateKey(SlhDsa key)

Span<byte> secretKey = (stackalloc byte[128])[..key.Algorithm.SecretKeySizeInBytes];
key.ExportSlhDsaSecretKey(secretKey);

try
{
return ImportSecretKey(key.Algorithm, secretKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,21 @@ public class MLDsaImplementationTests : MLDsaTestsBase
public static void GenerateImport_NullAlgorithm()
{
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.GenerateKey(null));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPrivateSeed(null, default));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPublicKey(null, default));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaSecretKey(null, default));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPrivateSeed(null, default(ReadOnlySpan<byte>)));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPublicKey(null, default(ReadOnlySpan<byte>)));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaSecretKey(null, default(ReadOnlySpan<byte>)));

AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPrivateSeed(null, (byte[]?)null));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPublicKey(null, (byte[]?)null));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaSecretKey(null, (byte[]?)null));
}

[Fact]
public static void Import_NullSource()
{
AssertExtensions.Throws<ArgumentNullException>("source", static () => MLDsa.ImportMLDsaPrivateSeed(MLDsaAlgorithm.MLDsa44, (byte[]?)null));
AssertExtensions.Throws<ArgumentNullException>("source", static () => MLDsa.ImportMLDsaPublicKey(MLDsaAlgorithm.MLDsa44, (byte[]?)null));
AssertExtensions.Throws<ArgumentNullException>("source", static () => MLDsa.ImportMLDsaSecretKey(MLDsaAlgorithm.MLDsa44, (byte[]?)null));
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ internal static void AssertImportPublicKey(Action<Func<MLDsa>> testDirectCall, A
{
testDirectCall(() => MLDsa.ImportMLDsaPublicKey(algorithm, Array.Empty<byte>().AsSpan()));
testDirectCall(() => MLDsa.ImportMLDsaPublicKey(algorithm, ReadOnlySpan<byte>.Empty));
testDirectCall(() => MLDsa.ImportMLDsaPublicKey(algorithm, default(ReadOnlySpan<byte>)));
}
else
{
testDirectCall(() => MLDsa.ImportMLDsaPublicKey(algorithm, publicKey));
testDirectCall(() => MLDsa.ImportMLDsaPublicKey(algorithm, publicKey.AsSpan()));
}

Expand Down Expand Up @@ -109,9 +111,11 @@ internal static void AssertImportSecretKey(Action<Func<MLDsa>> testDirectCall, A
{
testDirectCall(() => MLDsa.ImportMLDsaSecretKey(algorithm, Array.Empty<byte>().AsSpan()));
testDirectCall(() => MLDsa.ImportMLDsaSecretKey(algorithm, ReadOnlySpan<byte>.Empty));
testDirectCall(() => MLDsa.ImportMLDsaSecretKey(algorithm, default(ReadOnlySpan<byte>)));
}
else
{
testDirectCall(() => MLDsa.ImportMLDsaSecretKey(algorithm, secretKey));
testDirectCall(() => MLDsa.ImportMLDsaSecretKey(algorithm, secretKey.AsSpan()));
}

Expand Down Expand Up @@ -147,9 +151,11 @@ internal static void AssertImportPrivateSeed(Action<Func<MLDsa>> testDirectCall,
{
testDirectCall(() => MLDsa.ImportMLDsaPrivateSeed(algorithm, Array.Empty<byte>().AsSpan()));
testDirectCall(() => MLDsa.ImportMLDsaPrivateSeed(algorithm, ReadOnlySpan<byte>.Empty));
testDirectCall(() => MLDsa.ImportMLDsaPrivateSeed(algorithm, default(ReadOnlySpan<byte>)));
}
else
{
testDirectCall(() => MLDsa.ImportMLDsaPrivateSeed(algorithm, privateSeed));
testDirectCall(() => MLDsa.ImportMLDsaPrivateSeed(algorithm, privateSeed.AsSpan()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ public static void NullArgumentValidation(MLDsaAlgorithm algorithm, bool shouldD

PbeParameters pbeParameters = new PbeParameters(PbeEncryptionAlgorithm.TripleDes3KeyPkcs12, HashAlgorithmName.SHA1, 42);

AssertExtensions.Throws<ArgumentNullException>("data", () => mldsa.SignData(null));
AssertExtensions.Throws<ArgumentNullException>("data", () => mldsa.VerifyData(null, null));
AssertExtensions.Throws<ArgumentNullException>("signature", () => mldsa.VerifyData(Array.Empty<byte>(), null));

AssertExtensions.Throws<ArgumentNullException>("password", () => mldsa.ExportEncryptedPkcs8PrivateKey((string)null, pbeParameters));
AssertExtensions.Throws<ArgumentNullException>("password", () => mldsa.ExportEncryptedPkcs8PrivateKeyPem((string)null, pbeParameters));
AssertExtensions.Throws<ArgumentNullException>("password", () => mldsa.TryExportEncryptedPkcs8PrivateKey((string)null, pbeParameters, Span<byte>.Empty, out _));
Expand Down Expand Up @@ -107,9 +111,13 @@ public static void ArgumentValidation(MLDsaAlgorithm algorithm, bool shouldDispo
}

AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.ExportMLDsaPublicKey(new byte[publicKeySize - 1]));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.ExportMLDsaPublicKey(new byte[publicKeySize + 1]));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.ExportMLDsaSecretKey(new byte[secretKeySize - 1]));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.ExportMLDsaSecretKey(new byte[secretKeySize + 1]));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.ExportMLDsaPrivateSeed(new byte[privateSeedSize - 1]));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.ExportMLDsaPrivateSeed(new byte[privateSeedSize + 1]));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.SignData(ReadOnlySpan<byte>.Empty, new byte[signatureSize - 1], ReadOnlySpan<byte>.Empty));
AssertExtensions.Throws<ArgumentException>("destination", () => mldsa.SignData(ReadOnlySpan<byte>.Empty, new byte[signatureSize + 1], ReadOnlySpan<byte>.Empty));

// Context length must be less than 256
AssertExtensions.Throws<ArgumentOutOfRangeException>("context", () => mldsa.SignData(ReadOnlySpan<byte>.Empty, new byte[signatureSize], new byte[256]));
Expand Down Expand Up @@ -165,14 +173,22 @@ public static void ExportMLDsaPublicKey_CallsCore(MLDsaAlgorithm algorithm)
mldsa.AddFillDestination(1);

int publicKeySize = algorithm.PublicKeySizeInBytes;

// Array overload
byte[] exported = mldsa.ExportMLDsaPublicKey();
Assert.Equal(1, mldsa.ExportMLDsaPublicKeyCoreCallCount);
Assert.Equal(publicKeySize, exported.Length);
AssertExpectedFill(exported, fillElement: 1);

// Span overload
byte[] publicKey = CreatePaddedFilledArray(publicKeySize, 42);

// Extra bytes in destination buffer should not be touched
Memory<byte> destination = publicKey.AsMemory(PaddingSize, publicKeySize);
mldsa.AddDestinationBufferIsSameAssertion(destination);

mldsa.ExportMLDsaPublicKey(destination.Span);
Assert.Equal(1, mldsa.ExportMLDsaPublicKeyCoreCallCount);
Assert.Equal(2, mldsa.ExportMLDsaPublicKeyCoreCallCount);
AssertExpectedFill(publicKey, fillElement: 1, paddingElement: 42, PaddingSize, publicKeySize);
}

Expand All @@ -185,17 +201,53 @@ public static void ExportMLDsaSecretKey_CallsCore(MLDsaAlgorithm algorithm)
mldsa.AddFillDestination(1);

int secretKeySize = algorithm.SecretKeySizeInBytes;

// Array overload
byte[] exported = mldsa.ExportMLDsaSecretKey();
Assert.Equal(1, mldsa.ExportMLDsaSecretKeyCoreCallCount);
Assert.Equal(secretKeySize, exported.Length);
AssertExpectedFill(exported, fillElement: 1);

// Span overload
byte[] secretKey = CreatePaddedFilledArray(secretKeySize, 42);

// Extra bytes in destination buffer should not be touched
Memory<byte> destination = secretKey.AsMemory(PaddingSize, secretKeySize);
mldsa.AddDestinationBufferIsSameAssertion(destination);

mldsa.ExportMLDsaSecretKey(destination.Span);
Assert.Equal(1, mldsa.ExportMLDsaSecretKeyCoreCallCount);
Assert.Equal(2, mldsa.ExportMLDsaSecretKeyCoreCallCount);
AssertExpectedFill(secretKey, fillElement: 1, paddingElement: 42, PaddingSize, secretKeySize);
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public static void ExportMLDsaPrivateSeed_CallsCore(MLDsaAlgorithm algorithm)
{
using MLDsaTestImplementation mldsa = MLDsaTestImplementation.CreateOverriddenCoreMethodsFail(algorithm);
mldsa.ExportMLDsaPrivateSeedHook = _ => { };
mldsa.AddFillDestination(1);

int privateSeedSize = algorithm.PrivateSeedSizeInBytes;

// Array overload
byte[] exported = mldsa.ExportMLDsaPrivateSeed();
Assert.Equal(1, mldsa.ExportMLDsaPrivateSeedCoreCallCount);
Assert.Equal(privateSeedSize, exported.Length);
AssertExpectedFill(exported, fillElement: 1);

// Span overload
byte[] privateSeed = CreatePaddedFilledArray(privateSeedSize, 42);

// Extra bytes in destination buffer should not be touched
Memory<byte> destination = privateSeed.AsMemory(PaddingSize, privateSeedSize);
mldsa.AddDestinationBufferIsSameAssertion(destination);

mldsa.ExportMLDsaPrivateSeed(destination.Span);
Assert.Equal(2, mldsa.ExportMLDsaPrivateSeedCoreCallCount);
AssertExpectedFill(privateSeed, fillElement: 1, paddingElement: 42, PaddingSize, privateSeedSize);
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public static void SignData_CallsCore(MLDsaAlgorithm algorithm)
Expand All @@ -210,14 +262,22 @@ public static void SignData_CallsCore(MLDsaAlgorithm algorithm)
mldsa.AddFillDestination(1);

int signatureSize = algorithm.SignatureSizeInBytes;

// Array overload
byte[] exported = mldsa.SignData(testData, testContext);
Assert.Equal(1, mldsa.SignDataCoreCallCount);
Assert.Equal(signatureSize, exported.Length);
AssertExpectedFill(exported, fillElement: 1);

// Span overload
byte[] signature = CreatePaddedFilledArray(signatureSize, 42);

// Extra bytes in destination buffer should not be touched
Memory<byte> destination = signature.AsMemory(PaddingSize, signatureSize);
mldsa.AddDestinationBufferIsSameAssertion(destination);

mldsa.SignData(testData, destination.Span, testContext);
Assert.Equal(1, mldsa.SignDataCoreCallCount);
Assert.Equal(2, mldsa.SignDataCoreCallCount);
AssertExpectedFill(signature, fillElement: 1, paddingElement: 42, PaddingSize, signatureSize);
}

Expand Down
Loading
Loading