Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 207 additions & 10 deletions aws-lc-rs/src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@
//! ```
use crate::aws_lc::{
EVP_PKEY_CTX_kem_set_params, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
EVP_PKEY_kem_new_raw_public_key, EVP_PKEY, EVP_PKEY_KEM,
EVP_PKEY_kem_new_raw_public_key, EVP_PKEY_kem_new_raw_secret_key, EVP_PKEY, EVP_PKEY_KEM,
};
use crate::buffer::Buffer;
use crate::encoding::generated_encodings;
use crate::encoding::{generated_encodings, AsDer, Pkcs8V1Der, PublicKeyX509Der};
use crate::error::{KeyRejected, Unspecified};
use crate::ptr::LcPtr;
use alloc::borrow::Cow;
use aws_lc::EVP_PKEY_get_raw_private_key;
use core::cmp::Ordering;
use std::ptr::null_mut;
use zeroize::Zeroize;

const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;
Expand Down Expand Up @@ -202,6 +204,66 @@ impl<Id> DecapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
/// Creates a `DecapsulationKey` from "raw" bytes, as those provided by `key_bytes`.
///
/// NOTE: The associated `EncapsulationKey` must be serialized separately. The `DecapsulationKey` returned by this
/// function will not provide the associated `EncapsulationKey`.
///
/// `alg` is the [`Algorithm`] to be associated with the generated `DecapsulationKey`.
///
/// `bytes` is a slice of raw bytes representing a `DecapsulationKey`.
///
/// # Errors
/// `error::Unspecified` when operation fails during key creation.
pub fn new(algorithm: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
match bytes.len().cmp(&algorithm.decapsulate_key_size()) {
Ordering::Less => Err(KeyRejected::too_small()),
Ordering::Greater => Err(KeyRejected::too_large()),
Ordering::Equal => Ok(()),
}?;
let evp_pkey = LcPtr::new(unsafe {
EVP_PKEY_kem_new_raw_secret_key(algorithm.id.nid(), bytes.as_ptr(), bytes.len())
})?;

Ok(DecapsulationKey {
algorithm,
evp_pkey,
})
}

/// Creates a `DecapsulationKey` from a PKCS#8 encoded KEM key.
///
/// NOTE: The associated `EncapsulationKey` might need to be serialized separately. Depending on the encoding, a
/// `DecapsulationKey` returned by this function might not provide its associated `EncapsulationKey`.
///
/// `alg` is the [`Algorithm`] to be associated with the generated `DecapsulationKey`.
///
/// `bytes` is a slice of raw bytes representing a `DecapsulationKey`.
///
/// # Errors
/// `error::Unspecified` when operation fails during key creation.
pub fn from_pkcs8(
algorithm: &'static Algorithm<Id>,
pkcs8: &[u8],
) -> Result<Self, KeyRejected> {
let evp_pkey = LcPtr::<EVP_PKEY>::parse_rfc5208_private_key(pkcs8, EVP_PKEY_KEM)?;

// TODO: A better way to verify the ML-KEM type
let mut size = 0;
if 1 != unsafe { EVP_PKEY_get_raw_private_key(*evp_pkey.as_const(), null_mut(), &mut size) }
{
return Err(KeyRejected::invalid_encoding());
}
if size != algorithm.decapsulate_key_size {
return Err(KeyRejected::invalid_encoding());
}

Ok(DecapsulationKey {
algorithm,
evp_pkey,
})
}

/// Generate a new KEM decapsulation key for the given algorithm.
///
/// # Errors
Expand All @@ -220,18 +282,27 @@ where
self.algorithm
}

/// Computes the KEM encapsulation key from the KEM decapsulation key.
/// If available, provides the associated `EncapsulationKey`. This will be available on a newly generated
/// `DescapsulationKey`. However, a `DescapsulationKey` constructed from deserialization might not have
/// the associated `EncapsulationKey` available.
///
/// # Errors
/// `error::Unspecified` when operation fails due to internal error.
/// `error::Unspecified` if the associated `EncapsulationKey` is not available.
#[allow(clippy::missing_panics_doc)]
pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
let evp_pkey = self.evp_pkey.clone();

Ok(EncapsulationKey {
let retval = EncapsulationKey {
algorithm: self.algorithm,
evp_pkey,
})
};

// TODO: A better way of validating that the `EncapsulationKey` is valid.
if retval.key_bytes().is_err() {
return Err(Unspecified);
}

Ok(retval)
}

/// Performs the decapsulate operation using this KEM decapsulation key on the given ciphertext.
Expand Down Expand Up @@ -272,12 +343,38 @@ where

Ok(SharedSecret(shared_secret.into_boxed_slice()))
}

/// Returns the `DecapsulationKey` bytes.
///
/// # Errors
/// * `Unspecified`: Any failure to retrieve the `DecapsulationKey` bytes.
pub fn key_bytes(&self) -> Result<DecapsulationKeyBytes<'static>, Unspecified> {
let decapsulation_key_bytes = self.evp_pkey.as_const().marshal_raw_private_key()?;
debug_assert_eq!(
decapsulation_key_bytes.len(),
self.algorithm.decapsulate_key_size()
);
Ok(DecapsulationKeyBytes::new(decapsulation_key_bytes))
}
}

unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}

unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}

impl<Id> AsDer<Pkcs8V1Der<'static>> for DecapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
fn as_der(&self) -> Result<Pkcs8V1Der<'static>, crate::error::Unspecified> {
Ok(Pkcs8V1Der::new(
self.evp_pkey
.as_const()
.marshal_rfc5208_private_key(crate::pkcs8::Version::V1)?,
))
}
}

impl<Id> Debug for DecapsulationKey<Id>
where
Id: AlgorithmIdentifier,
Expand All @@ -289,7 +386,10 @@ where
}
}

generated_encodings!((EncapsulationKeyBytes, EncapsulationKeyBytesType));
generated_encodings!(
(EncapsulationKeyBytes, EncapsulationKeyBytesType),
(DecapsulationKeyBytes, DecapsulationKeyBytesType)
);

/// A serializable encapsulation key usable with KEM algorithms. Constructed
/// from either a `DecapsulationKey` or raw bytes.
Expand Down Expand Up @@ -399,6 +499,17 @@ unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}

unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}

impl<Id> AsDer<PublicKeyX509Der<'static>> for EncapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
fn as_der(&self) -> Result<PublicKeyX509Der<'static>, crate::error::Unspecified> {
Ok(PublicKeyX509Der::new(
self.evp_pkey.as_const().marshal_rfc5280_public_key()?,
))
}
}

impl<Id> Debug for EncapsulationKey<Id>
where
Id: AlgorithmIdentifier,
Expand Down Expand Up @@ -508,7 +619,14 @@ mod tests {
fn test_kem_serialize() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let priv_key_raw_bytes = priv_key.key_bytes().unwrap();
let priv_key_from_bytes =
DecapsulationKey::new(algorithm, priv_key_raw_bytes.as_ref()).unwrap();

assert_eq!(
priv_key.key_bytes().unwrap().as_ref(),
priv_key_from_bytes.key_bytes().unwrap().as_ref()
);

let pub_key = priv_key.encapsulation_key().unwrap();
let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
Expand Down Expand Up @@ -539,6 +657,20 @@ mod tests {
short_pub_key_from_bytes.err(),
Some(KeyRejected::too_small())
);

let too_long_bytes = vec![0u8; algorithm.decapsulate_key_size() + 1];
let long_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_long_bytes);
assert_eq!(
long_priv_key_from_bytes.err(),
Some(KeyRejected::too_large())
);

let too_short_bytes = vec![0u8; algorithm.decapsulate_key_size() - 1];
let short_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_short_bytes);
assert_eq!(
short_priv_key_from_bytes.err(),
Some(KeyRejected::too_small())
);
}
}

Expand All @@ -547,13 +679,17 @@ mod tests {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let priv_key_raw_bytes = priv_key.key_bytes().unwrap();
let priv_key_from_bytes =
DecapsulationKey::new(algorithm, priv_key_raw_bytes.as_ref()).unwrap();

assert!(priv_key_from_bytes.encapsulation_key().is_err());
let pub_key = priv_key.encapsulation_key().unwrap();

let (alice_ciphertext, alice_secret) =
pub_key.encapsulate().expect("encapsulate successful");

let bob_secret = priv_key
let bob_secret = priv_key_from_bytes
.decapsulate(alice_ciphertext)
.expect("decapsulate successful");

Expand All @@ -572,23 +708,84 @@ mod tests {
// Generate public key bytes to send to bob
let pub_key_bytes = pub_key.key_bytes().unwrap();

// Generate private key bytes for alice to store securely
let priv_key_bytes = priv_key.key_bytes().unwrap();

// Test that priv_key's EVP_PKEY isn't entirely freed since we remove this pub_key's reference.
drop(pub_key);
drop(priv_key);

let retrieved_pub_key =
EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
let (ciphertext, bob_secret) = retrieved_pub_key
.encapsulate()
.expect("encapsulate successful");

let alice_secret = priv_key
// Alice reconstructs her private key from stored bytes
let retrieved_priv_key =
DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
let alice_secret = retrieved_priv_key
.decapsulate(ciphertext)
.expect("decapsulate successful");

assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}

#[test]
fn test_decapsulation_key_serialization_comprehensive() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
// Generate original key
let original_key = DecapsulationKey::generate(algorithm).unwrap();

// Test key_bytes() returns correct size
let key_bytes = original_key.key_bytes().unwrap();
assert_eq!(key_bytes.as_ref().len(), algorithm.decapsulate_key_size());

// Test round-trip serialization/deserialization
let reconstructed_key = DecapsulationKey::new(algorithm, key_bytes.as_ref()).unwrap();

// Verify algorithm consistency
assert_eq!(original_key.algorithm(), reconstructed_key.algorithm());
assert_eq!(original_key.algorithm(), algorithm);

// Test multiple serialization rounds produce identical results
let key_bytes_2 = reconstructed_key.key_bytes().unwrap();
assert_eq!(key_bytes.as_ref(), key_bytes_2.as_ref());

let reconstructed_key_2 =
DecapsulationKey::new(algorithm, key_bytes_2.as_ref()).unwrap();
let key_bytes_3 = reconstructed_key_2.key_bytes().unwrap();
assert_eq!(key_bytes.as_ref(), key_bytes_3.as_ref());

// Test functional equivalence - both keys should decrypt the same ciphertext identically
let pub_key = original_key.encapsulation_key().unwrap();
let (ciphertext, expected_secret) = pub_key.encapsulate().unwrap();

// Both the original and reconstructed keys should decrypt to the same secret
let secret_from_original = original_key
.decapsulate(Ciphertext::from(ciphertext.as_ref()))
.unwrap();
let secret_from_reconstructed = reconstructed_key
.decapsulate(Ciphertext::from(ciphertext.as_ref()))
.unwrap();

// All three secrets should be identical
assert_eq!(expected_secret.as_ref(), secret_from_original.as_ref());
assert_eq!(expected_secret.as_ref(), secret_from_reconstructed.as_ref());
assert_eq!(
secret_from_original.as_ref(),
secret_from_reconstructed.as_ref()
);

// Verify secret length is correct
assert_eq!(
expected_secret.as_ref().len(),
algorithm.shared_secret_size()
);
}
}

#[test]
fn test_debug_fmt() {
let private = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
Expand Down
Loading