Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/sats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ smallvec.workspace = true
thiserror.workspace = true

[dev-dependencies]
ahash.workspace = true
bytes.workspace = true
rand.workspace = true
# Also as dev-dependencies for use in _this_ crate's tests.
Expand Down
7 changes: 7 additions & 0 deletions crates/sats/proptest-regressions/algebraic_value_hash.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc aaa05e16925268348653cb8c1945d820f2f8da931fd7ff9a895178d443e0e64f # shrinks to (ty, val) = (Builtin(Array(ArrayType { elem_ty: Builtin(U8) })), Array([]))
166 changes: 165 additions & 1 deletion crates/sats/src/algebraic_value_hash.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
//! Defines hash functions for `AlgebraicValue` and friends.

use crate::{AlgebraicValue, ArrayValue, ProductValue};
use crate::{
bsatn::Deserializer,
buffer::{BufReader, DecodeError},
de::{Deserialize, Deserializer as _},
AlgebraicType, AlgebraicValue, ArrayValue, BuiltinType, MapType, ProductType, ProductValue, SumType, F32, F64,
};
use core::hash::{Hash, Hasher};

// We only manually implement those hash functions that cannot be `#[derive(Hash)]`ed.
Expand Down Expand Up @@ -85,3 +90,162 @@ impl Hash for ArrayValue {
}
}
}

/// Result type shared by the BSATN hashing helpers in this file:
/// `Ok(())` on success, or the `DecodeError` encountered while reading the input.
type HR = Result<(), DecodeError>;

pub fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
match ty {
AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
AlgebraicType::Sum(ty) => hash_bsatn_sum(state, ty, de),
AlgebraicType::Product(ty) => hash_bsatn_prod(state, ty, de),
AlgebraicType::Builtin(BuiltinType::Array(ty)) => hash_bsatn_array(state, &ty.elem_ty, de),
AlgebraicType::Builtin(BuiltinType::Map(ty)) => hash_bsatn_map(state, ty, de),
&AlgebraicType::Bool => hash_bsatn_de::<bool>(state, de),
&AlgebraicType::I8 => hash_bsatn_de::<i8>(state, de),
&AlgebraicType::U8 => hash_bsatn_de::<u8>(state, de),
&AlgebraicType::I16 => hash_bsatn_de::<i16>(state, de),
&AlgebraicType::U16 => hash_bsatn_de::<u16>(state, de),
&AlgebraicType::I32 => hash_bsatn_de::<i32>(state, de),
&AlgebraicType::U32 => hash_bsatn_de::<u32>(state, de),
&AlgebraicType::I64 => hash_bsatn_de::<i64>(state, de),
&AlgebraicType::U64 => hash_bsatn_de::<u64>(state, de),
&AlgebraicType::I128 => hash_bsatn_de::<i128>(state, de),
&AlgebraicType::U128 => hash_bsatn_de::<u128>(state, de),
&AlgebraicType::F32 => hash_bsatn_de::<F32>(state, de),
&AlgebraicType::F64 => hash_bsatn_de::<F64>(state, de),
&AlgebraicType::String => hash_bsatn_de::<&str>(state, de),
}
}

/// Hashes the tag and payload of the BSATN-encoded sum value.
fn hash_bsatn_sum<'a>(state: &mut impl Hasher, ty: &SumType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR {
// Read + hash the tag.
let tag = de.reborrow().deserialize_u8()?;
tag.hash(state);

// Hash the payload.
let data_ty = &ty.variants[tag as usize].algebraic_type;
hash_bsatn(state, data_ty, de)
}

/// Hashes the BSATN-encoded product value field by field,
/// in the order the fields are declared in `ty`.
///
/// Stops at, and returns, the first `DecodeError` encountered.
fn hash_bsatn_prod<'a>(state: &mut impl Hasher, ty: &ProductType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR {
    for field in ty.elements.iter() {
        hash_bsatn(state, &field.algebraic_type, de.reborrow())?;
    }
    Ok(())
}

/// Hashes the BSATN-encoded array value whose elements have type `ty`.
///
/// Integer element types are hashed as one contiguous byte run via `hash_bsatn_int_seq`;
/// every other element type is hashed one element at a time via `hash_bsatn_seq`.
///
/// # Panics
///
/// Panics when the element type is an `AlgebraicType::Ref`,
/// as no typespace is available here to resolve it.
fn hash_bsatn_array<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
    // The BSATN is length-prefixed.
    // `Hash for &[T]` also does length-prefixing.
    match ty {
        // Integers: hash all the bytes at once, keyed only by the element width.
        &AlgebraicType::U8 | &AlgebraicType::I8 => hash_bsatn_int_seq(state, de, 1),
        &AlgebraicType::U16 | &AlgebraicType::I16 => hash_bsatn_int_seq(state, de, 2),
        &AlgebraicType::U32 | &AlgebraicType::I32 => hash_bsatn_int_seq(state, de, 4),
        &AlgebraicType::U64 | &AlgebraicType::I64 => hash_bsatn_int_seq(state, de, 8),
        &AlgebraicType::U128 | &AlgebraicType::I128 => hash_bsatn_int_seq(state, de, 16),
        // Other primitives: decode and hash each element.
        &AlgebraicType::Bool => hash_bsatn_seq(state, de, hash_bsatn_de::<bool>),
        &AlgebraicType::F32 => hash_bsatn_seq(state, de, hash_bsatn_de::<F32>),
        &AlgebraicType::F64 => hash_bsatn_seq(state, de, hash_bsatn_de::<F64>),
        &AlgebraicType::String => hash_bsatn_seq(state, de, hash_bsatn_de::<&str>),
        // Compound elements: recurse into the matching helper for each element.
        AlgebraicType::Sum(sum_ty) => hash_bsatn_seq(state, de, |hasher, elem_de| {
            hash_bsatn_sum(hasher, sum_ty, elem_de)
        }),
        AlgebraicType::Product(prod_ty) => hash_bsatn_seq(state, de, |hasher, elem_de| {
            hash_bsatn_prod(hasher, prod_ty, elem_de)
        }),
        AlgebraicType::Builtin(BuiltinType::Array(inner_ty)) => hash_bsatn_seq(state, de, |hasher, elem_de| {
            hash_bsatn_array(hasher, &inner_ty.elem_ty, elem_de)
        }),
        AlgebraicType::Builtin(BuiltinType::Map(map_ty)) => hash_bsatn_seq(state, de, |hasher, elem_de| {
            hash_bsatn_map(hasher, map_ty, elem_de)
        }),
        // References cannot be resolved without a typespace.
        AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"),
    }
}

/// Hashes the BSATN-encoded map value as a length-prefixed sequence of (key, value) pairs.
fn hash_bsatn_map<'a>(state: &mut impl Hasher, ty: &MapType, de: Deserializer<'_, impl BufReader<'a>>) -> HR {
    // Length-prefix first, then hash each entry as key followed by value.
    // This is OK as BSATN serializes the map in order
    // and `BTreeMap` will hash the elements in order,
    // so everything stays consistent.
    hash_bsatn_seq(state, de, |hasher, mut entry_de| {
        // The key is read first, through a reborrow, so the value can consume `entry_de`.
        hash_bsatn(hasher, &ty.key_ty, entry_de.reborrow())?;
        hash_bsatn(hasher, &ty.ty, entry_de)
    })
}

/// Hashes a BSATN-encoded sequence of elements.
///
/// The length prefix is read from `de` and written to `state`,
/// then `elem_hash` is invoked once per element.
/// The first error returned by `elem_hash` aborts the loop and is returned.
fn hash_bsatn_seq<'a, H: Hasher, R: BufReader<'a>>(
    state: &mut H,
    mut de: Deserializer<'_, R>,
    mut elem_hash: impl FnMut(&mut H, Deserializer<'_, R>) -> Result<(), DecodeError>,
) -> HR {
    // Both the BSATN encoding and the hash are length-prefixed.
    let count = de.reborrow().deserialize_len()?;
    state.write_usize(count);

    // Hash the elements one after another.
    for _ in 0..count {
        elem_hash(state, de.reborrow())?;
    }
    Ok(())
}

/// Hashes the BSATN-encoded integer sequence where each integer is `width` bytes wide.
/// The sequence is prefixed with its length and the hash will as well.
fn hash_bsatn_int_seq<'a, H: Hasher, R: BufReader<'a>>(state: &mut H, mut de: Deserializer<'_, R>, width: usize) -> HR {
// The BSATN is length-prefixed.
// The Hash also needs to be length-prefixed.
let len = de.reborrow().deserialize_len()?;
state.write_usize(len);

// Extract and hash the bytes.
//´This is consistent with what `<$int_primitive>::hash_slice` will do.
let bytes = de.get_slice(len * width)?;
state.write(bytes);
Ok(())
}

/// Decodes a single `T` from `de` and feeds it to `state` through `T`'s `Hash` impl.
///
/// Nothing is written to `state` when decoding fails; the error is returned instead.
fn hash_bsatn_de<'a, T: Hash + Deserialize<'a>>(
    state: &mut impl Hasher,
    de: Deserializer<'_, impl BufReader<'a>>,
) -> HR {
    let value = T::deserialize(de)?;
    value.hash(state);
    Ok(())
}

#[cfg(test)]
mod tests {
    use crate::{
        bsatn::{to_vec, Deserializer},
        hash_bsatn,
        proptest::generate_typed_value,
        AlgebraicType, AlgebraicValue,
    };
    use proptest::prelude::*;
    use std::hash::{BuildHasher, Hasher as _};

    /// Serializes `val` to BSATN and hashes the encoded bytes with `hash_bsatn`,
    /// using a fresh hasher built from `bh`.
    fn hash_one_bsatn_av(bh: &impl BuildHasher, ty: &AlgebraicType, val: &AlgebraicValue) -> u64 {
        // Keep the encoded buffer alive separately from the slice the reader advances.
        let encoded = to_vec(&val).unwrap();
        let mut reader = &*encoded;
        let mut hasher = bh.build_hasher();
        hash_bsatn(&mut hasher, ty, Deserializer::new(&mut reader)).unwrap();
        hasher.finish()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(2048))]

        /// Hashing the BSATN encoding must agree with hashing the value itself (std hasher).
        #[test]
        fn av_bsatn_hash_same_std_random_state((ty, val) in generate_typed_value()) {
            let build_hasher = std::hash::RandomState::new();
            let from_value = build_hasher.hash_one(&val);
            let from_bsatn = hash_one_bsatn_av(&build_hasher, &ty, &val);
            prop_assert_eq!(from_value, from_bsatn);
        }

        /// Hashing the BSATN encoding must agree with hashing the value itself (ahash).
        #[test]
        fn av_bsatn_hash_same_ahash((ty, val) in generate_typed_value()) {
            let build_hasher = ahash::RandomState::new();
            let from_value = build_hasher.hash_one(&val);
            let from_bsatn = hash_one_bsatn_av(&build_hasher, &ty, &val);
            prop_assert_eq!(from_value, from_bsatn);
        }
    }
}
1 change: 1 addition & 0 deletions crates/sats/src/bsatn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::Typespace;
use smallvec::SmallVec;

pub mod de;
pub mod eq;
pub mod ser;

pub use de::Deserializer;
Expand Down
44 changes: 24 additions & 20 deletions crates/sats/src/bsatn/de.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::buffer::{BufReader, DecodeError};

use crate::de::{self, SeqProductAccess, SumAccess, VariantAccess};
use crate::de::{self, Deserializer as _, SeqProductAccess, SumAccess, VariantAccess};

/// Deserializer from the BSATN data format.
pub struct Deserializer<'a, R> {
Expand All @@ -16,9 +15,25 @@ impl<'a, 'de, R: BufReader<'de>> Deserializer<'a, R> {

/// Reborrows the deserializer.
#[inline]
fn reborrow(&mut self) -> Deserializer<'_, R> {
pub(crate) fn reborrow(&mut self) -> Deserializer<'_, R> {
Deserializer { reader: self.reader }
}

/// Reads a `u32` length prefix and converts it to `usize`.
pub(crate) fn deserialize_len(self) -> Result<usize, DecodeError> {
    self.deserialize_u32().map(|len| len as usize)
}

/// Reads the next `len` bytes from the underlying reader as a borrowed byte slice.
///
/// Simply forwards to the reader's `get_slice`, returning its `DecodeError` on failure.
pub(crate) fn get_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> {
self.reader.get_slice(len)
}

/// Reads a length-prefixed byte slice from the `reader`:
/// first the length, then that many bytes.
fn deserialize_bytes_inner(mut self) -> Result<&'de [u8], DecodeError> {
    self.reborrow()
        .deserialize_len()
        .and_then(|len| self.get_slice(len))
}
}

impl de::Error for DecodeError {
Expand All @@ -31,17 +46,6 @@ impl de::Error for DecodeError {
}
}

/// Read a length as a `u32` then converted to `usize`.
fn get_len<'de>(reader: &mut impl BufReader<'de>) -> Result<usize, DecodeError> {
Ok(reader.get_u32()? as usize)
}

/// Read a byte slice from the `reader`.
fn read_bytes<'a, 'de: 'a>(reader: &'a mut impl BufReader<'de>) -> Result<&'de [u8], DecodeError> {
let len = get_len(reader)?;
reader.get_slice(len)
}

impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> {
type Error = DecodeError;

Expand Down Expand Up @@ -94,22 +98,22 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> {
}

fn deserialize_str<V: de::SliceVisitor<'de, str>>(self, visitor: V) -> Result<V::Output, Self::Error> {
let slice = read_bytes(self.reader)?;
let slice = self.deserialize_bytes_inner()?;
let slice = core::str::from_utf8(slice)?;
visitor.visit_borrowed(slice)
}

fn deserialize_bytes<V: de::SliceVisitor<'de, [u8]>>(self, visitor: V) -> Result<V::Output, Self::Error> {
let slice = read_bytes(self.reader)?;
let slice = self.deserialize_bytes_inner()?;
visitor.visit_borrowed(slice)
}

fn deserialize_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
self,
mut self,
visitor: V,
seed: T,
) -> Result<V::Output, Self::Error> {
let len = get_len(self.reader)?;
let len = self.reborrow().deserialize_len()?;
let seeds = itertools::repeat_n(seed, len);
visitor.visit(ArrayAccess { de: self, seeds })
}
Expand All @@ -119,12 +123,12 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> {
K: de::DeserializeSeed<'de> + Clone,
V: de::DeserializeSeed<'de> + Clone,
>(
self,
mut self,
visitor: Vi,
kseed: K,
vseed: V,
) -> Result<Vi::Output, Self::Error> {
let len = get_len(self.reader)?;
let len = self.reborrow().deserialize_len()?;
let seeds = itertools::repeat_n((kseed, vseed), len);
visitor.visit(MapAccess { de: self, seeds })
}
Expand Down
Loading