From 134c7526f3721ceac9037c835fa10080240536ff Mon Sep 17 00:00:00 2001 From: Mazdak Farrokhzad Date: Wed, 10 Apr 2024 21:23:23 +0200 Subject: [PATCH 1/2] add hash_bsatn + move proptest generators to sats crate --- Cargo.lock | 1 + crates/sats/Cargo.toml | 1 + .../algebraic_value_hash.txt | 7 + crates/sats/src/algebraic_value_hash.rs | 166 +++++++++++++++++- crates/sats/src/bsatn/de.rs | 44 ++--- crates/sats/src/lib.rs | 4 + .../proptest_sats.rs => sats/src/proptest.rs} | 27 +-- crates/table/Cargo.toml | 2 +- crates/table/src/bflatn_to.rs | 2 +- crates/table/src/bflatn_to_bsatn_fast_path.rs | 4 +- crates/table/src/btree_index.rs | 11 +- crates/table/src/layout.rs | 2 +- crates/table/src/lib.rs | 3 - crates/table/src/read_column.rs | 5 +- crates/table/src/table.rs | 2 +- 15 files changed, 231 insertions(+), 50 deletions(-) create mode 100644 crates/sats/proptest-regressions/algebraic_value_hash.txt rename crates/{table/src/proptest_sats.rs => sats/src/proptest.rs} (91%) diff --git a/Cargo.lock b/Cargo.lock index 9ba129daba8..f1e5286c07c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4641,6 +4641,7 @@ dependencies = [ name = "spacetimedb-sats" version = "0.8.2" dependencies = [ + "ahash 0.8.3", "arrayvec", "bitflags 2.4.1", "bytes", diff --git a/crates/sats/Cargo.toml b/crates/sats/Cargo.toml index 4200a084412..dbc7c101154 100644 --- a/crates/sats/Cargo.toml +++ b/crates/sats/Cargo.toml @@ -34,6 +34,7 @@ smallvec.workspace = true thiserror.workspace = true [dev-dependencies] +ahash.workspace = true bytes.workspace = true rand.workspace = true # Also as dev-dependencies for use in _this_ crate's tests. diff --git a/crates/sats/proptest-regressions/algebraic_value_hash.txt b/crates/sats/proptest-regressions/algebraic_value_hash.txt new file mode 100644 index 00000000000..29a270921a6 --- /dev/null +++ b/crates/sats/proptest-regressions/algebraic_value_hash.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. 
It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc aaa05e16925268348653cb8c1945d820f2f8da931fd7ff9a895178d443e0e64f # shrinks to (ty, val) = (Builtin(Array(ArrayType { elem_ty: Builtin(U8) })), Array([])) diff --git a/crates/sats/src/algebraic_value_hash.rs b/crates/sats/src/algebraic_value_hash.rs index 216056b262f..469866564ef 100644 --- a/crates/sats/src/algebraic_value_hash.rs +++ b/crates/sats/src/algebraic_value_hash.rs @@ -1,6 +1,11 @@ //! Defines hash functions for `AlgebraicValue` and friends. -use crate::{AlgebraicValue, ArrayValue, ProductValue}; +use crate::{ + bsatn::Deserializer, + buffer::{BufReader, DecodeError}, + de::{Deserialize, Deserializer as _}, + AlgebraicType, AlgebraicValue, ArrayValue, BuiltinType, MapType, ProductType, ProductValue, SumType, F32, F64, +}; use core::hash::{Hash, Hasher}; // We only manually implement those hash functions that cannot be `#[derive(Hash)]`ed. 
@@ -85,3 +90,162 @@ impl Hash for ArrayValue { } } } + +type HR = Result<(), DecodeError>; + +pub fn hash_bsatn<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR { + match ty { + AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"), + AlgebraicType::Sum(ty) => hash_bsatn_sum(state, ty, de), + AlgebraicType::Product(ty) => hash_bsatn_prod(state, ty, de), + AlgebraicType::Builtin(BuiltinType::Array(ty)) => hash_bsatn_array(state, &ty.elem_ty, de), + AlgebraicType::Builtin(BuiltinType::Map(ty)) => hash_bsatn_map(state, ty, de), + &AlgebraicType::Bool => hash_bsatn_de::(state, de), + &AlgebraicType::I8 => hash_bsatn_de::(state, de), + &AlgebraicType::U8 => hash_bsatn_de::(state, de), + &AlgebraicType::I16 => hash_bsatn_de::(state, de), + &AlgebraicType::U16 => hash_bsatn_de::(state, de), + &AlgebraicType::I32 => hash_bsatn_de::(state, de), + &AlgebraicType::U32 => hash_bsatn_de::(state, de), + &AlgebraicType::I64 => hash_bsatn_de::(state, de), + &AlgebraicType::U64 => hash_bsatn_de::(state, de), + &AlgebraicType::I128 => hash_bsatn_de::(state, de), + &AlgebraicType::U128 => hash_bsatn_de::(state, de), + &AlgebraicType::F32 => hash_bsatn_de::(state, de), + &AlgebraicType::F64 => hash_bsatn_de::(state, de), + &AlgebraicType::String => hash_bsatn_de::<&str>(state, de), + } +} + +/// Hashes the tag and payload of the BSATN-encoded sum value. +fn hash_bsatn_sum<'a>(state: &mut impl Hasher, ty: &SumType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR { + // Read + hash the tag. + let tag = de.reborrow().deserialize_u8()?; + tag.hash(state); + + // Hash the payload. + let data_ty = &ty.variants[tag as usize].algebraic_type; + hash_bsatn(state, data_ty, de) +} + +/// Hashes every field in the BSATN-encoded product value. 
+fn hash_bsatn_prod<'a>(state: &mut impl Hasher, ty: &ProductType, mut de: Deserializer<'_, impl BufReader<'a>>) -> HR { + ty.elements + .iter() + .try_for_each(|f| hash_bsatn(state, &f.algebraic_type, de.reborrow())) +} + +/// Hashes every elem in the BSATN-encoded array value. +fn hash_bsatn_array<'a>(state: &mut impl Hasher, ty: &AlgebraicType, de: Deserializer<'_, impl BufReader<'a>>) -> HR { + // The BSATN is length-prefixed. + // `Hash for &[T]` also does length-prefixing. + match ty { + AlgebraicType::Ref(_) => unreachable!("hash_bsatn does not have a typespace"), + AlgebraicType::Sum(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_sum(s, ty, d)), + AlgebraicType::Product(ty) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_prod(s, ty, d)), + AlgebraicType::Builtin(BuiltinType::Array(ty)) => { + hash_bsatn_seq(state, de, |s, d| hash_bsatn_array(s, &ty.elem_ty, d)) + } + AlgebraicType::Builtin(BuiltinType::Map(ty)) => hash_bsatn_seq(state, de, |s, d| hash_bsatn_map(s, ty, d)), + &AlgebraicType::Bool => hash_bsatn_seq(state, de, hash_bsatn_de::), + &AlgebraicType::I8 | &AlgebraicType::U8 => hash_bsatn_int_seq(state, de, 1), + &AlgebraicType::I16 | &AlgebraicType::U16 => hash_bsatn_int_seq(state, de, 2), + &AlgebraicType::I32 | &AlgebraicType::U32 => hash_bsatn_int_seq(state, de, 4), + &AlgebraicType::I64 | &AlgebraicType::U64 => hash_bsatn_int_seq(state, de, 8), + &AlgebraicType::I128 | &AlgebraicType::U128 => hash_bsatn_int_seq(state, de, 16), + &AlgebraicType::F32 => hash_bsatn_seq(state, de, hash_bsatn_de::), + &AlgebraicType::F64 => hash_bsatn_seq(state, de, hash_bsatn_de::), + &AlgebraicType::String => hash_bsatn_seq(state, de, hash_bsatn_de::<&str>), + } +} + +/// Hashes every (key, value) in the BSATN-encoded map value. +fn hash_bsatn_map<'a>(state: &mut impl Hasher, ty: &MapType, de: Deserializer<'_, impl BufReader<'a>>) -> HR { + // Hash each (key, value) pair but first length-prefix. 
+ // This is OK as BSATN serializes the map in order + // and `BTreeMap` will hash the elements in order, + // so everything stays consistent. + hash_bsatn_seq(state, de, |state, mut de| { + hash_bsatn(state, &ty.key_ty, de.reborrow())?; + hash_bsatn(state, &ty.ty, de)?; + Ok(()) + }) +} + +/// Hashes elements in the BSATN-encoded element sequence. +/// The sequence is prefixed with its length and the hash will as well. +fn hash_bsatn_seq<'a, H: Hasher, R: BufReader<'a>>( + state: &mut H, + mut de: Deserializer<'_, R>, + mut elem_hash: impl FnMut(&mut H, Deserializer<'_, R>) -> Result<(), DecodeError>, +) -> HR { + // The BSATN is length-prefixed. + // The Hash also needs to be length-prefixed. + let len = de.reborrow().deserialize_len()?; + state.write_usize(len); + + // Hash each element. + (0..len).try_for_each(|_| elem_hash(state, de.reborrow())) +} + +/// Hashes the BSATN-encoded integer sequence where each integer is `width` bytes wide. +/// The sequence is prefixed with its length and the hash will as well. +fn hash_bsatn_int_seq<'a, H: Hasher, R: BufReader<'a>>(state: &mut H, mut de: Deserializer<'_, R>, width: usize) -> HR { + // The BSATN is length-prefixed. + // The Hash also needs to be length-prefixed. + let len = de.reborrow().deserialize_len()?; + state.write_usize(len); + + // Extract and hash the bytes. + // This is consistent with what `<$int_primitive>::hash_slice` will do. + let bytes = de.get_slice(len * width)?; + state.write(bytes); + Ok(()) +} + +/// Deserializes from `de` an `x: T` and then proceeds to hash `x`. 
+fn hash_bsatn_de<'a, T: Hash + Deserialize<'a>>( + state: &mut impl Hasher, + de: Deserializer<'_, impl BufReader<'a>>, +) -> HR { + T::deserialize(de).map(|x| x.hash(state)) +} + +#[cfg(test)] +mod tests { + use crate::{ + bsatn::{to_vec, Deserializer}, + hash_bsatn, + proptest::generate_typed_value, + AlgebraicType, AlgebraicValue, + }; + use proptest::prelude::*; + use std::hash::{BuildHasher, Hasher as _}; + + fn hash_one_bsatn_av(bh: &impl BuildHasher, ty: &AlgebraicType, val: &AlgebraicValue) -> u64 { + let mut bsatn = &*to_vec(&val).unwrap(); + let de = Deserializer::new(&mut bsatn); + let mut hasher = bh.build_hasher(); + hash_bsatn(&mut hasher, ty, de).unwrap(); + hasher.finish() + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(2048))] + #[test] + fn av_bsatn_hash_same_std_random_state((ty, val) in generate_typed_value()) { + let rs = std::hash::RandomState::new(); + let hash_av = rs.hash_one(&val); + let hash_av_bsatn = hash_one_bsatn_av(&rs, &ty, &val); + prop_assert_eq!(hash_av, hash_av_bsatn); + } + + #[test] + fn av_bsatn_hash_same_ahash((ty, val) in generate_typed_value()) { + let rs = ahash::RandomState::new(); + let hash_av = rs.hash_one(&val); + let hash_av_bsatn = hash_one_bsatn_av(&rs, &ty, &val); + prop_assert_eq!(hash_av, hash_av_bsatn); + } + } +} diff --git a/crates/sats/src/bsatn/de.rs b/crates/sats/src/bsatn/de.rs index e8c2a2de88b..fdc3232ea97 100644 --- a/crates/sats/src/bsatn/de.rs +++ b/crates/sats/src/bsatn/de.rs @@ -1,6 +1,5 @@ use crate::buffer::{BufReader, DecodeError}; - -use crate::de::{self, SeqProductAccess, SumAccess, VariantAccess}; +use crate::de::{self, Deserializer as _, SeqProductAccess, SumAccess, VariantAccess}; /// Deserializer from the BSATN data format. pub struct Deserializer<'a, R> { @@ -16,9 +15,25 @@ impl<'a, 'de, R: BufReader<'de>> Deserializer<'a, R> { /// Reborrows the deserializer. 
#[inline] - fn reborrow(&mut self) -> Deserializer<'_, R> { + pub(crate) fn reborrow(&mut self) -> Deserializer<'_, R> { Deserializer { reader: self.reader } } + + /// Reads a length as a `u32` and converts it to `usize`. + pub(crate) fn deserialize_len(self) -> Result { + Ok(self.deserialize_u32()? as usize) + } + + /// Reads a slice of `len` bytes. + pub(crate) fn get_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> { + self.reader.get_slice(len) + } + + /// Reads a byte slice from the `reader`. + fn deserialize_bytes_inner(mut self) -> Result<&'de [u8], DecodeError> { + let len = self.reborrow().deserialize_len()?; + self.get_slice(len) + } } impl de::Error for DecodeError { @@ -31,17 +46,6 @@ impl de::Error for DecodeError { } } -/// Read a length as a `u32` then converted to `usize`. -fn get_len<'de>(reader: &mut impl BufReader<'de>) -> Result { - Ok(reader.get_u32()? as usize) -} - -/// Read a byte slice from the `reader`. -fn read_bytes<'a, 'de: 'a>(reader: &'a mut impl BufReader<'de>) -> Result<&'de [u8], DecodeError> { - let len = get_len(reader)?; - reader.get_slice(len) -} - impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> { type Error = DecodeError; @@ -94,22 +98,22 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> { } fn deserialize_str>(self, visitor: V) -> Result { - let slice = read_bytes(self.reader)?; + let slice = self.deserialize_bytes_inner()?; let slice = core::str::from_utf8(slice)?; visitor.visit_borrowed(slice) } fn deserialize_bytes>(self, visitor: V) -> Result { - let slice = read_bytes(self.reader)?; + let slice = self.deserialize_bytes_inner()?; visitor.visit_borrowed(slice) } fn deserialize_array_seed, T: de::DeserializeSeed<'de> + Clone>( - self, + mut self, visitor: V, seed: T, ) -> Result { - let len = get_len(self.reader)?; + let len = self.reborrow().deserialize_len()?; let seeds = itertools::repeat_n(seed, len); visitor.visit(ArrayAccess { de: self, 
seeds }) } @@ -119,12 +123,12 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> { K: de::DeserializeSeed<'de> + Clone, V: de::DeserializeSeed<'de> + Clone, >( - self, + mut self, visitor: Vi, kseed: K, vseed: V, ) -> Result { - let len = get_len(self.reader)?; + let len = self.reborrow().deserialize_len()?; let seeds = itertools::repeat_n((kseed, vseed), len); visitor.visit(MapAccess { de: self, seeds }) } diff --git a/crates/sats/src/lib.rs b/crates/sats/src/lib.rs index 2e473e55668..34fab7cb840 100644 --- a/crates/sats/src/lib.rs +++ b/crates/sats/src/lib.rs @@ -27,9 +27,13 @@ pub mod sum_type_variant; pub mod sum_value; pub mod typespace; +#[cfg(any(test, feature = "proptest"))] +pub mod proptest; + pub use algebraic_type::AlgebraicType; pub use algebraic_type_ref::AlgebraicTypeRef; pub use algebraic_value::{AlgebraicValue, F32, F64}; +pub use algebraic_value_hash::hash_bsatn; pub use array_type::ArrayType; pub use array_value::ArrayValue; pub use builtin_type::BuiltinType; diff --git a/crates/table/src/proptest_sats.rs b/crates/sats/src/proptest.rs similarity index 91% rename from crates/table/src/proptest_sats.rs rename to crates/sats/src/proptest.rs index c772c041f08..373f009ebbd 100644 --- a/crates/table/src/proptest_sats.rs +++ b/crates/sats/src/proptest.rs @@ -2,6 +2,10 @@ //! //! This notably excludes `Ref` types. +use crate::{ + AlgebraicType, AlgebraicValue, ArrayValue, BuiltinType, MapType, MapValue, ProductType, ProductValue, SumType, + SumValue, F32, F64, +}; use proptest::{ collection::{vec, SizeRange}, prelude::*, @@ -9,10 +13,8 @@ use proptest::{ strategy::Just, strategy::{BoxedStrategy, Strategy}, }; -use spacetimedb_sats::{ - AlgebraicType, AlgebraicValue, ArrayValue, BuiltinType, MapType, MapValue, ProductType, ProductValue, SumType, - SumValue, F32, F64, -}; + +const SIZE: usize = 16; /// Generates leaf (i.e. non-compound) `AlgebraicType`s. 
/// @@ -47,7 +49,7 @@ fn generate_non_compound_algebraic_type() -> impl Strategy impl Strategy { - generate_non_compound_algebraic_type().prop_recursive(4, 16, 16, |gen_element| { + generate_non_compound_algebraic_type().prop_recursive(4, SIZE as u32, SIZE as u32, |gen_element| { prop_oneof![ gen_element.clone().prop_map(AlgebraicType::array), (gen_element.clone(), gen_element.clone()).prop_map(|(key, val)| AlgebraicType::map(key, val)), @@ -55,11 +57,11 @@ pub fn generate_algebraic_type() -> impl Strategy { // No need to generate units here; // we already generate them in `generate_non_compound_algebraic_type`. - vec(gen_element.clone().prop_map_into(), 1..=16) + vec(gen_element.clone().prop_map_into(), 1..=SIZE) .prop_map(Vec::into_boxed_slice) .prop_map(AlgebraicType::product), // Do not generate nevers here; we can't store never in a page. - vec(gen_element.clone().prop_map_into(), 1..=16) + vec(gen_element.clone().prop_map_into(), 1..=SIZE) .prop_map(Vec::into_boxed_slice) .prop_map(AlgebraicType::sum), ] @@ -135,7 +137,7 @@ fn generate_sum_value(ty: SumType) -> impl Strategy { fn generate_map_value(ty: MapType) -> impl Strategy { vec( (generate_algebraic_value(ty.key_ty), generate_algebraic_value(ty.ty)), - 0..=16, + 0..=SIZE, ) .prop_map(|entries| entries.into_iter().collect()) } @@ -146,7 +148,7 @@ where S: Strategy + 'static, Box<[S::Value]>: 'static + Into, { - vec(gen_elem, 0..=16) + vec(gen_elem, 0..=SIZE) .prop_map(Vec::into_boxed_slice) .prop_map_into() .boxed() @@ -179,5 +181,10 @@ fn generate_array_value(ty: AlgebraicType) -> BoxedStrategy { /// Generates a row type `ty` and a row value typed at `ty`. pub fn generate_typed_row() -> impl Strategy { - generate_row_type(0..=16).prop_flat_map(|ty| (Just(ty.clone()), generate_product_value(ty))) + generate_row_type(0..=SIZE).prop_flat_map(|ty| (Just(ty.clone()), generate_product_value(ty))) +} + +/// Generates a type `ty` and a value typed at `ty`. 
+pub fn generate_typed_value() -> impl Strategy { + generate_algebraic_type().prop_flat_map(|ty| (Just(ty.clone()), generate_algebraic_value(ty))) } diff --git a/crates/table/Cargo.toml b/crates/table/Cargo.toml index 752a304cc5e..19bd93f7977 100644 --- a/crates/table/Cargo.toml +++ b/crates/table/Cargo.toml @@ -23,7 +23,7 @@ harness = false [features] # Allows using `Arbitrary` impls defined in this crate. -proptest = ["dep:proptest", "dep:proptest-derive"] +proptest = ["dep:proptest", "dep:proptest-derive", "spacetimedb-sats/proptest"] # Needed for miri blake3_pure = ["blake3/pure"] diff --git a/crates/table/src/bflatn_to.rs b/crates/table/src/bflatn_to.rs index 7941387a420..66cfabdbf10 100644 --- a/crates/table/src/bflatn_to.rs +++ b/crates/table/src/bflatn_to.rs @@ -450,12 +450,12 @@ impl BflatnSerializedRowBuffer<'_> { #[cfg(test)] pub mod test { use super::*; - use crate::proptest_sats::generate_typed_row; use crate::{ bflatn_from::serialize_row_from_page, blob_store::HashMapBlobStore, row_type_visitor::row_type_visitor, }; use proptest::{prelude::*, prop_assert_eq, proptest}; use spacetimedb_sats::algebraic_value::ser::ValueSerializer; + use spacetimedb_sats::proptest::generate_typed_row; proptest! 
{ #![proptest_config(ProptestConfig::with_cases(2048))] diff --git a/crates/table/src/bflatn_to_bsatn_fast_path.rs b/crates/table/src/bflatn_to_bsatn_fast_path.rs index c745e4c55fa..033736cfdb0 100644 --- a/crates/table/src/bflatn_to_bsatn_fast_path.rs +++ b/crates/table/src/bflatn_to_bsatn_fast_path.rs @@ -274,9 +274,9 @@ impl LayoutBuilder { #[cfg(test)] mod test { use super::*; - use crate::{blob_store::HashMapBlobStore, proptest_sats::generate_typed_row}; + use crate::blob_store::HashMapBlobStore; use proptest::prelude::*; - use spacetimedb_sats::{bsatn, AlgebraicType, ProductType}; + use spacetimedb_sats::{bsatn, proptest::generate_typed_row, AlgebraicType, ProductType}; fn assert_expected_layout(ty: ProductType, bsatn_length: u16, fields: &[(u16, u16, u16)]) { let expected_layout = StaticBsatnLayout { diff --git a/crates/table/src/btree_index.rs b/crates/table/src/btree_index.rs index 05d3cf70acf..ae391b301b5 100644 --- a/crates/table/src/btree_index.rs +++ b/crates/table/src/btree_index.rs @@ -441,12 +441,7 @@ impl BTreeIndex { #[cfg(test)] mod test { use super::*; - use crate::{ - blob_store::HashMapBlobStore, - indexes::SquashedOffset, - proptest_sats::{generate_product_value, generate_row_type}, - table::Table, - }; + use crate::{blob_store::HashMapBlobStore, indexes::SquashedOffset, table::Table}; use core::ops::Bound::*; use proptest::prelude::*; use proptest::{collection::vec, test_runner::TestCaseResult}; @@ -454,7 +449,9 @@ mod test { use spacetimedb_primitives::ColListBuilder; use spacetimedb_sats::{ db::def::{TableDef, TableSchema}, - product, AlgebraicType, ProductType, ProductValue, + product, + proptest::{generate_product_value, generate_row_type}, + AlgebraicType, ProductType, ProductValue, }; fn gen_cols(ty_len: usize) -> impl Strategy { diff --git a/crates/table/src/layout.rs b/crates/table/src/layout.rs index c463cf037db..2320b3d7a40 100644 --- a/crates/table/src/layout.rs +++ b/crates/table/src/layout.rs @@ -620,10 +620,10 @@ pub fn 
bsatn_len(val: &AlgebraicValue) -> usize { #[cfg(test)] mod test { use super::*; - use crate::proptest_sats::generate_algebraic_type; use itertools::Itertools; use proptest::collection::vec; use proptest::prelude::*; + use spacetimedb_sats::proptest::generate_algebraic_type; #[test] fn align_to_expected() { diff --git a/crates/table/src/lib.rs b/crates/table/src/lib.rs index f4b1be0db5c..9f6e4941c85 100644 --- a/crates/table/src/lib.rs +++ b/crates/table/src/lib.rs @@ -23,8 +23,5 @@ pub mod row_type_visitor; pub mod table; pub mod var_len; -#[cfg(test)] -mod proptest_sats; - #[doc(hidden)] // Used in tests and benchmarks. pub mod util; diff --git a/crates/table/src/read_column.rs b/crates/table/src/read_column.rs index bd673dc03ba..706054d2760 100644 --- a/crates/table/src/read_column.rs +++ b/crates/table/src/read_column.rs @@ -343,13 +343,12 @@ impl_read_column_via_from! { #[cfg(test)] mod test { use super::*; - use crate::{ - blob_store::HashMapBlobStore, indexes::SquashedOffset, proptest_sats::generate_typed_row, table::Table, - }; + use crate::{blob_store::HashMapBlobStore, indexes::SquashedOffset, table::Table}; use proptest::{prelude::*, prop_assert_eq, proptest, test_runner::TestCaseResult}; use spacetimedb_sats::{ db::def::{TableDef, TableSchema}, product, + proptest::generate_typed_row, }; fn table(ty: ProductType) -> Table { diff --git a/crates/table/src/table.rs b/crates/table/src/table.rs index 965c49c2da8..6ab0378e7a5 100644 --- a/crates/table/src/table.rs +++ b/crates/table/src/table.rs @@ -973,11 +973,11 @@ pub(crate) mod test { use super::*; use crate::blob_store::HashMapBlobStore; use crate::indexes::{PageIndex, PageOffset}; - use crate::proptest_sats::generate_typed_row; use proptest::prelude::*; use proptest::test_runner::TestCaseResult; use spacetimedb_sats::bsatn::to_vec; use spacetimedb_sats::db::def::{ColumnDef, IndexDef, IndexType, TableDef}; + use spacetimedb_sats::proptest::generate_typed_row; use spacetimedb_sats::{product, 
AlgebraicType, ArrayValue}; pub(crate) fn table(ty: ProductType) -> Table { From 5ef9bd34cd05ca8fb0f9a2f6503c52abbddf9ba8 Mon Sep 17 00:00:00 2001 From: Mazdak Farrokhzad Date: Thu, 18 Apr 2024 15:27:20 +0200 Subject: [PATCH 2/2] add eq_bsatn --- crates/sats/src/bsatn.rs | 1 + crates/sats/src/bsatn/eq.rs | 144 ++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 crates/sats/src/bsatn/eq.rs diff --git a/crates/sats/src/bsatn.rs b/crates/sats/src/bsatn.rs index 8a892a98e4e..577e57a0081 100644 --- a/crates/sats/src/bsatn.rs +++ b/crates/sats/src/bsatn.rs @@ -5,6 +5,7 @@ use crate::Typespace; use smallvec::SmallVec; pub mod de; +pub mod eq; pub mod ser; pub use de::Deserializer; diff --git a/crates/sats/src/bsatn/eq.rs b/crates/sats/src/bsatn/eq.rs new file mode 100644 index 00000000000..a9537d295b0 --- /dev/null +++ b/crates/sats/src/bsatn/eq.rs @@ -0,0 +1,144 @@ +//! Defines the function [`eq_bsatn`] which equates `lhs: &AlgebraicValue` to `rhs` defined in BSATN. +//! +//! The lifetime `'r` in `eq_bsatn` is the lifetime of `rhs`'s backing data, i.e., the BSATN itself. + +use super::Deserializer; +use crate::{buffer::BufReader, de::Deserialize, AlgebraicValue, ArrayValue, MapValue, ProductValue, SumValue}; +use core::{mem, slice}; + +/// Equates `lhs` to a BSATN-encoded `AlgebraicValue` of the same type. 
+pub fn eq_bsatn<'r>(lhs: &AlgebraicValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + match lhs { + AlgebraicValue::Sum(lhs) => eq_bsatn_sum(lhs, rhs), + AlgebraicValue::Product(lhs) => eq_bsatn_prod(lhs, rhs), + AlgebraicValue::Array(lhs) => eq_bsatn_array(lhs, rhs), + AlgebraicValue::Map(lhs) => eq_bsatn_map(lhs, rhs), + AlgebraicValue::Bool(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::I8(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::U8(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::I16(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::U16(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::I32(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::U32(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::I64(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::U64(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::I128(lhs) => eq_bsatn_de(&{ lhs.0 }, rhs), + AlgebraicValue::U128(lhs) => eq_bsatn_de(&{ lhs.0 }, rhs), + AlgebraicValue::F32(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::F64(lhs) => eq_bsatn_de(lhs, rhs), + AlgebraicValue::String(lhs) => eq_bsatn_str(lhs, rhs), + } +} + +/// Equates the tag and payload of `lhs` to those of the BSATN-encoded sum value. +fn eq_bsatn_sum<'r>(lhs: &SumValue, mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + eq_bsatn_de(&lhs.tag, rhs.reborrow()) && eq_bsatn(&lhs.value, rhs) +} + +/// Equates every field in `lhs` to those in the BSATN-encoded product value. +fn eq_bsatn_prod<'r>(lhs: &ProductValue, mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + lhs.elements.iter().all(|f| eq_bsatn(f, rhs.reborrow())) +} + +/// Equates `lhs` to the `(key, value)`s in the BSATN-encoded map value. +fn eq_bsatn_map<'r>(lhs: &MapValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + eq_bsatn_seq(lhs, rhs, |(key, value), mut rhs| { + eq_bsatn(key, rhs.reborrow()) && eq_bsatn(value, rhs) + }) +} + +/// Equates every elem in `lhs` to those in the BSATN-encoded array value. 
+fn eq_bsatn_array<'r>(lhs: &ArrayValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + match lhs { + ArrayValue::Sum(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_sum), + ArrayValue::Product(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_prod), + ArrayValue::Bool(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de), + ArrayValue::F32(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de), + ArrayValue::F64(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de), + ArrayValue::String(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_str), + ArrayValue::Array(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_array), + ArrayValue::Map(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_map), + // SAFETY: For all of the below, the element types are integer types, as required. + ArrayValue::I8(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::U8(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::I16(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::U16(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::I32(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::U32(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::I64(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::U64(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::I128(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + ArrayValue::U128(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) }, + } +} + +/// Equates the integer slice `lhs` to the BSATN-encoded one in `rhs`. +/// +/// SAFETY: `T` must be an integer type. +unsafe fn eq_bsatn_int_seq<'r, T>(lhs: &[T], mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + // The BSATN is length-prefixed. + let Ok(len) = rhs.reborrow().deserialize_len() else { + return false; + }; + + // Extract the rhs bytes. + let Ok(rhs_bytes) = rhs.get_slice(len * mem::size_of::()) else { + return false; + }; + + // Convert `lhs` to `&[u8]`. + let ptr = lhs.as_ptr().cast::(); + // SAFETY: Caller promised that `T` is an integer type. 
+ // Thus it has no safety requirements and no padding, + // so it is legal to convert `&[IntType] -> &[u8]`. + let lhs_bytes = unsafe { slice::from_raw_parts(ptr, mem::size_of_val(lhs)) }; + + lhs_bytes == rhs_bytes +} + +/// Equates the string `lhs` to the BSATN-encoded one in `rhs`. +#[allow(clippy::borrowed_box)] +fn eq_bsatn_str<'r>(lhs: &Box, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + <&str>::deserialize(rhs).map(|rhs| &**lhs == rhs).unwrap_or(false) +} + +/// Equates elements in `lhs` to the BSATN-encoded element sequence in `rhs`. +/// The sequence is prefixed with its length. +fn eq_bsatn_seq<'r, T, I: ExactSizeIterator, R: BufReader<'r>>( + lhs: impl IntoIterator, + mut rhs: Deserializer<'_, R>, + elem_eq: impl Fn(T, Deserializer<'_, R>) -> bool, +) -> bool { + let mut lhs = lhs.into_iter(); + // The BSATN is length-prefixed. + // Compare against length first. + match rhs.reborrow().deserialize_len() { + Ok(len) if lhs.len() == len => lhs.all(|e| elem_eq(e, rhs.reborrow())), + _ => false, + } +} + +/// Deserializes from `de` an `rhs: T` and then proceeds to `lhs == rhs`. +fn eq_bsatn_de<'r, T: Eq + Deserialize<'r>>(lhs: &T, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool { + T::deserialize(rhs).map(|rhs| lhs == &rhs).unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::eq_bsatn; + use crate::{ + bsatn::{to_vec, Deserializer}, + proptest::generate_typed_value, + }; + use proptest::prelude::*; + + proptest! { + #![proptest_config(ProptestConfig::with_cases(2048))] + #[test] + fn encoded_val_eq_to_self((_, val) in generate_typed_value()) { + let mut bsatn = &*to_vec(&val).unwrap(); + let de = Deserializer::new(&mut bsatn); + prop_assert!(eq_bsatn(&val, de)); + } + } +}