Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 146 additions & 52 deletions src/libraries/Common/src/System/Security/Cryptography/CompositeMLDsa.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,15 @@ protected CompositeMLDsa(CompositeMLDsaAlgorithm algorithm)
/// <returns>
/// <see langword="true"/> if the algorithm is supported; otherwise, <see langword="false"/>.
/// </returns>
public static bool IsAlgorithmSupported(CompositeMLDsaAlgorithm algorithm) =>
CompositeMLDsaImplementation.IsAlgorithmSupportedImpl(algorithm);
/// <exception cref="ArgumentNullException">
/// <paramref name="algorithm"/> is <see langword="null"/>.
/// </exception>
public static bool IsAlgorithmSupported(CompositeMLDsaAlgorithm algorithm)
{
ArgumentNullException.ThrowIfNull(algorithm);

return CompositeMLDsaImplementation.IsAlgorithmSupportedImpl(algorithm);
}

/// <summary>
/// Signs the specified data.
Expand Down Expand Up @@ -136,24 +143,17 @@ public byte[] SignData(byte[] data, byte[]? context = default)

ThrowIfDisposed();

// TODO If we know exact size of signature, then we can allocate instead of renting and copying.
byte[] rented = CryptoPool.Rent(32 + Algorithm.MaxSignatureSizeInBytes);

try
if (Algorithm.SignatureSize.IsExact)
{
if (!TrySignDataCore(new ReadOnlySpan<byte>(data), new ReadOnlySpan<byte>(context), rented, out int written))
{
Debug.Fail($"Signature exceeds {nameof(Algorithm.MaxSignatureSizeInBytes)} ({Algorithm.MaxSignatureSizeInBytes}).");
throw new CryptographicException();
}

return rented.AsSpan(0, written).ToArray();
}
finally
{
// Signature does not contain sensitive information.
CryptoPool.Return(rented, clearSize: 0);
return ExportExactSize(
Algorithm.MaxSignatureSizeInBytes,
(key, dest, out written) => key.TrySignData(new ReadOnlySpan<byte>(data), dest, out written, new ReadOnlySpan<byte>(context)));
}

return ExportWithCallback(
Algorithm.MaxSignatureSizeInBytes,
(key, dest, out written) => key.TrySignData(new ReadOnlySpan<byte>(data), dest, out written, new ReadOnlySpan<byte>(context)),
key => key.ToArray());
}

/// <summary>
Expand Down Expand Up @@ -204,13 +204,28 @@ public bool TrySignData(ReadOnlySpan<byte> data, Span<byte> destination, out int

ThrowIfDisposed();

if (destination.Length < 32 + Algorithm.MLDsaAlgorithm.SignatureSizeInBytes)
if (Algorithm.SignatureSize.IsAlwaysLargerThan(destination.Length))
{
bytesWritten = 0;
return false;
}

return TrySignDataCore(data, context, destination, out bytesWritten);
if (TrySignDataCore(data, context, destination, out int written))
{
if (!Algorithm.SignatureSize.IsValidSize(written))
{
CryptographicOperations.ZeroMemory(destination);

bytesWritten = 0;
throw new CryptographicException(SR.Cryptography_UnexpectedExportBufferSize);
}

bytesWritten = written;
return true;
}

bytesWritten = 0;
return false;
}

/// <summary>
Expand Down Expand Up @@ -316,13 +331,7 @@ public bool VerifyData(ReadOnlySpan<byte> data, ReadOnlySpan<byte> signature, Re

ThrowIfDisposed();

// TODO change this to 32 + Algorithm.MLDsaAlgorithm.SignatureSizeInBytes. Check other places too.
if (signature.Length < 32 + Algorithm.MLDsaAlgorithm.SignatureSizeInBytes)
{
return false;
}

return VerifyDataCore(data, context, signature);
return Algorithm.SignatureSize.IsValidSize(signature.Length) && VerifyDataCore(data, context, signature);
}

/// <summary>
Expand Down Expand Up @@ -661,9 +670,9 @@ static void SubjectPublicKeyReader(ReadOnlyMemory<byte> key, in AlgorithmIdentif
{
CompositeMLDsaAlgorithm algorithm = GetAlgorithmIdentifier(in identifier);

if (key.Length < algorithm.MLDsaAlgorithm.PublicKeySizeInBytes)
if (!algorithm.PublicKeySize.IsValidSize(key.Length))
{
throw new CryptographicException(SR.Argument_PublicKeyTooShortForAlgorithm);
throw new CryptographicException(SR.Argument_PublicKeyWrongSizeForAlgorithm);
}

dsa = CompositeMLDsaImplementation.ImportCompositeMLDsaPublicKeyImpl(algorithm, key.Span);
Expand Down Expand Up @@ -860,15 +869,27 @@ static void PrivateKeyReader(
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
}

if (key.Length < algorithm.MLDsaAlgorithm.PrivateSeedSizeInBytes)
if (!algorithm.PrivateKeySize.IsValidSize(key.Length))
{
throw new CryptographicException(SR.Argument_PrivateKeyTooShortForAlgorithm);
throw new CryptographicException(SR.Argument_PrivateKeyWrongSizeForAlgorithm);
}

dsa = CompositeMLDsaImplementation.ImportCompositeMLDsaPrivateKeyImpl(algorithm, key);
}
}

/// <inheritdoc cref="ImportCompositeMLDsaPublicKey(CompositeMLDsaAlgorithm, ReadOnlySpan{byte})" />
/// <exception cref="ArgumentNullException">
/// <paramref name="algorithm"/> or <paramref name="source" /> is <see langword="null" />.
/// </exception>
public static CompositeMLDsa ImportCompositeMLDsaPublicKey(CompositeMLDsaAlgorithm algorithm, byte[] source)
{
ArgumentNullException.ThrowIfNull(algorithm);
ArgumentNullException.ThrowIfNull(source);

return ImportCompositeMLDsaPublicKey(algorithm, new ReadOnlySpan<byte>(source));
}

/// <summary>
/// Imports a Composite ML-DSA public key.
/// </summary>
Expand Down Expand Up @@ -896,15 +917,29 @@ static void PrivateKeyReader(
/// </exception>
public static CompositeMLDsa ImportCompositeMLDsaPublicKey(CompositeMLDsaAlgorithm algorithm, ReadOnlySpan<byte> source)
{
ArgumentNullException.ThrowIfNull(algorithm);
ThrowIfNotSupported(algorithm);

if (source.Length < algorithm.MLDsaAlgorithm.PublicKeySizeInBytes)
if (!algorithm.PublicKeySize.IsValidSize(source.Length))
{
throw new CryptographicException(SR.Argument_PublicKeyTooShortForAlgorithm);
throw new CryptographicException(SR.Argument_PublicKeyWrongSizeForAlgorithm);
}

return CompositeMLDsaImplementation.ImportCompositeMLDsaPublicKeyImpl(algorithm, source);
}

/// <inheritdoc cref="ImportCompositeMLDsaPrivateKey(CompositeMLDsaAlgorithm, ReadOnlySpan{byte})" />
/// <exception cref="ArgumentNullException">
/// <paramref name="algorithm"/> or <paramref name="source" /> is <see langword="null" />.
/// </exception>
public static CompositeMLDsa ImportCompositeMLDsaPrivateKey(CompositeMLDsaAlgorithm algorithm, byte[] source)
{
ArgumentNullException.ThrowIfNull(algorithm);
ArgumentNullException.ThrowIfNull(source);

return ImportCompositeMLDsaPrivateKey(algorithm, new ReadOnlySpan<byte>(source));
}

/// <summary>
/// Imports a Composite ML-DSA private key.
/// </summary>
Expand Down Expand Up @@ -932,11 +967,12 @@ public static CompositeMLDsa ImportCompositeMLDsaPublicKey(CompositeMLDsaAlgorit
/// </exception>
public static CompositeMLDsa ImportCompositeMLDsaPrivateKey(CompositeMLDsaAlgorithm algorithm, ReadOnlySpan<byte> source)
{
ArgumentNullException.ThrowIfNull(algorithm);
ThrowIfNotSupported(algorithm);

if (source.Length < algorithm.MLDsaAlgorithm.PrivateSeedSizeInBytes)
if (!algorithm.PrivateKeySize.IsValidSize(source.Length))
{
throw new CryptographicException(SR.Argument_PrivateKeyTooShortForAlgorithm);
throw new CryptographicException(SR.Argument_PrivateKeyWrongSizeForAlgorithm);
}

return CompositeMLDsaImplementation.ImportCompositeMLDsaPrivateKeyImpl(algorithm, source);
Expand Down Expand Up @@ -1347,7 +1383,7 @@ public bool TryExportPkcs8PrivateKey(Span<byte> destination, out int bytesWritte

// The bound can be tightened but private key length of some traditional algorithms,
// can vary and aren't worth the complex calculation.
int minimumPossiblePkcs8Key = Algorithm.MLDsaAlgorithm.PrivateSeedSizeInBytes;
int minimumPossiblePkcs8Key = Algorithm.PrivateKeySize.MinimumSizeInBytes;

if (destination.Length < minimumPossiblePkcs8Key)
{
Expand Down Expand Up @@ -1466,6 +1502,13 @@ public byte[] ExportCompositeMLDsaPublicKey()
{
ThrowIfDisposed();

if (Algorithm.PublicKeySize.IsExact)
{
return ExportExactSize(
Algorithm.PublicKeySize.MinimumSizeInBytes,
static (key, dest, out written) => key.TryExportCompositeMLDsaPublicKey(dest, out written));
}

return ExportPublicKeyCallback(static publicKey => publicKey.ToArray());
}

Expand Down Expand Up @@ -1493,9 +1536,28 @@ public bool TryExportCompositeMLDsaPublicKey(Span<byte> destination, out int byt
{
ThrowIfDisposed();

// TODO short-circuit based on known required length lower bounds
if (Algorithm.PublicKeySize.IsAlwaysLargerThan(destination.Length))
{
bytesWritten = 0;
return false;
}

return TryExportCompositeMLDsaPublicKeyCore(destination, out bytesWritten);
if (TryExportCompositeMLDsaPublicKeyCore(destination, out int written))
{
if (!Algorithm.PublicKeySize.IsValidSize(written))
{
CryptographicOperations.ZeroMemory(destination);

bytesWritten = 0;
throw new CryptographicException(SR.Cryptography_UnexpectedExportBufferSize);
}

bytesWritten = written;
return true;
}

bytesWritten = 0;
return false;
}

/// <summary>
Expand All @@ -1514,12 +1576,16 @@ public byte[] ExportCompositeMLDsaPrivateKey()
{
ThrowIfDisposed();

// TODO The private key has a max size so add it as CompositeMLDsaAlgorithm.MaxPrivateKeySize and use it here.
int initalSize = Algorithm.MLDsaAlgorithm.PrivateSeedSizeInBytes;
if (Algorithm.PrivateKeySize.IsExact)
{
return ExportExactSize(
Algorithm.PrivateKeySize.MinimumSizeInBytes,
static (key, dest, out written) => key.TryExportCompositeMLDsaPrivateKey(dest, out written));
}

return ExportWithCallback(
initalSize,
static (key, dest, out written) => key.TryExportCompositeMLDsaPrivateKeyCore(dest, out written),
Algorithm.PrivateKeySize.InitialExportBufferSizeInBytes,
static (key, dest, out written) => key.TryExportCompositeMLDsaPrivateKey(dest, out written),
static privateKey => privateKey.ToArray());
}

Expand Down Expand Up @@ -1547,9 +1613,28 @@ public bool TryExportCompositeMLDsaPrivateKey(Span<byte> destination, out int by
{
ThrowIfDisposed();

// TODO short-circuit based on known required length lower bounds
if (Algorithm.PrivateKeySize.IsAlwaysLargerThan(destination.Length))
{
bytesWritten = 0;
return false;
}

if (TryExportCompositeMLDsaPrivateKeyCore(destination, out int written))
{
if (!Algorithm.PrivateKeySize.IsValidSize(written))
{
CryptographicOperations.ZeroMemory(destination);

bytesWritten = 0;
throw new CryptographicException(SR.Cryptography_UnexpectedExportBufferSize);
}

bytesWritten = written;
return true;
}

return TryExportCompositeMLDsaPrivateKeyCore(destination, out bytesWritten);
bytesWritten = 0;
return false;
}

/// <summary>
Expand Down Expand Up @@ -1671,8 +1756,7 @@ private AsnWriter WritePkcs8ToAsnWriter()

private TResult ExportPkcs8PrivateKeyCallback<TResult>(ProcessExportedContent<TResult> func)
{
// TODO Pick a good estimate for the initial size of the buffer.
int initialSize = 1;
int initialSize = Algorithm.PrivateKeySize.InitialExportBufferSizeInBytes;

return ExportWithCallback(
initialSize,
Expand Down Expand Up @@ -1709,13 +1793,9 @@ private AsnWriter WriteSubjectPublicKeyToAsnWriter()

private TResult ExportPublicKeyCallback<TResult>(ProcessExportedContent<TResult> func)
{
// TODO RSA is the only algo without a strict max size. The exponent can be arbitrarily large,
// but in practice it is always 65537. Add an internal CompositeMLDsaAlgorithm.EstimatedMaxPublicKeySizeInBytes and use that here.
int initialSize = Algorithm.MLDsaAlgorithm.PublicKeySizeInBytes;

return ExportWithCallback(
initialSize,
static (key, dest, out written) => key.TryExportCompositeMLDsaPublicKeyCore(dest, out written),
Algorithm.PublicKeySize.InitialExportBufferSizeInBytes,
static (key, dest, out written) => key.TryExportCompositeMLDsaPublicKey(dest, out written),
func);
}

Expand Down Expand Up @@ -1757,6 +1837,20 @@ private TResult ExportWithCallback<TResult>(
}
}

private byte[] ExportExactSize(int exactSize, TryExportFunc tryExportFunc)
{
byte[] ret = new byte[exactSize];

if (!tryExportFunc(this, ret, out int written) || written != exactSize)
{
CryptographicOperations.ZeroMemory(ret);

throw new CryptographicException(SR.Cryptography_UnexpectedExportBufferSize);
}

return ret;
}

private static CompositeMLDsaAlgorithm GetAlgorithmIdentifier(ref readonly AlgorithmIdentifierAsn identifier)
{
CompositeMLDsaAlgorithm? algorithm = CompositeMLDsaAlgorithm.GetAlgorithmFromOid(identifier.Algorithm);
Expand Down
Loading
Loading