Skip to content

Commit 5ef9bd3

Browse files
committed
add eq_bsatn
1 parent 134c752 commit 5ef9bd3

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

crates/sats/src/bsatn.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::Typespace;
55
use smallvec::SmallVec;
66

77
pub mod de;
8+
pub mod eq;
89
pub mod ser;
910

1011
pub use de::Deserializer;

crates/sats/src/bsatn/eq.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
//! Defines the function [`eq_bsatn`] which equates `lhs: &AlgebraicValue` to `rhs` defined in BSATN.
2+
//!
3+
//! The lifetime `'r` in `eq_bsatn` is the lifetime of `rhs`'s backing data, i.e., the BSATN itself.
4+
5+
use super::Deserializer;
6+
use crate::{buffer::BufReader, de::Deserialize, AlgebraicValue, ArrayValue, MapValue, ProductValue, SumValue};
7+
use core::{mem, slice};
8+
9+
/// Equates `lhs` to a BSATN-encoded `AlgebraicValue` of the same type.
10+
pub fn eq_bsatn<'r>(lhs: &AlgebraicValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
11+
match lhs {
12+
AlgebraicValue::Sum(lhs) => eq_bsatn_sum(lhs, rhs),
13+
AlgebraicValue::Product(lhs) => eq_bsatn_prod(lhs, rhs),
14+
AlgebraicValue::Array(lhs) => eq_bsatn_array(lhs, rhs),
15+
AlgebraicValue::Map(lhs) => eq_bsatn_map(lhs, rhs),
16+
AlgebraicValue::Bool(lhs) => eq_bsatn_de(lhs, rhs),
17+
AlgebraicValue::I8(lhs) => eq_bsatn_de(lhs, rhs),
18+
AlgebraicValue::U8(lhs) => eq_bsatn_de(lhs, rhs),
19+
AlgebraicValue::I16(lhs) => eq_bsatn_de(lhs, rhs),
20+
AlgebraicValue::U16(lhs) => eq_bsatn_de(lhs, rhs),
21+
AlgebraicValue::I32(lhs) => eq_bsatn_de(lhs, rhs),
22+
AlgebraicValue::U32(lhs) => eq_bsatn_de(lhs, rhs),
23+
AlgebraicValue::I64(lhs) => eq_bsatn_de(lhs, rhs),
24+
AlgebraicValue::U64(lhs) => eq_bsatn_de(lhs, rhs),
25+
AlgebraicValue::I128(lhs) => eq_bsatn_de(&{ lhs.0 }, rhs),
26+
AlgebraicValue::U128(lhs) => eq_bsatn_de(&{ lhs.0 }, rhs),
27+
AlgebraicValue::F32(lhs) => eq_bsatn_de(lhs, rhs),
28+
AlgebraicValue::F64(lhs) => eq_bsatn_de(lhs, rhs),
29+
AlgebraicValue::String(lhs) => eq_bsatn_str(lhs, rhs),
30+
}
31+
}
32+
33+
/// Equates the tag and payload to that of the BSATN-encoded sum value.
34+
fn eq_bsatn_sum<'r>(lhs: &SumValue, mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
35+
eq_bsatn_de(&lhs.tag, rhs.reborrow()) && eq_bsatn(&lhs.value, rhs)
36+
}
37+
38+
/// Equates every field `lhs` to those in the BSATN-encoded product value.
39+
fn eq_bsatn_prod<'r>(lhs: &ProductValue, mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
40+
lhs.elements.iter().all(|f| eq_bsatn(f, rhs.reborrow()))
41+
}
42+
43+
/// Equates `lhs` to the `(key, value)`s in the BSATN-encoded map value.
44+
fn eq_bsatn_map<'r>(lhs: &MapValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
45+
eq_bsatn_seq(lhs, rhs, |(key, value), mut rhs| {
46+
eq_bsatn(key, rhs.reborrow()) && eq_bsatn(value, rhs)
47+
})
48+
}
49+
50+
/// Equates every elem in `lhs` to those in the BSATN-encoded array value.
51+
fn eq_bsatn_array<'r>(lhs: &ArrayValue, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
52+
match lhs {
53+
ArrayValue::Sum(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_sum),
54+
ArrayValue::Product(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_prod),
55+
ArrayValue::Bool(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de),
56+
ArrayValue::F32(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de),
57+
ArrayValue::F64(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_de),
58+
ArrayValue::String(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_str),
59+
ArrayValue::Array(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_array),
60+
ArrayValue::Map(lhs) => eq_bsatn_seq(&**lhs, rhs, eq_bsatn_map),
61+
// SAFETY: For all of the below, the element types are integer types, as required.
62+
ArrayValue::I8(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
63+
ArrayValue::U8(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
64+
ArrayValue::I16(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
65+
ArrayValue::U16(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
66+
ArrayValue::I32(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
67+
ArrayValue::U32(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
68+
ArrayValue::I64(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
69+
ArrayValue::U64(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
70+
ArrayValue::I128(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
71+
ArrayValue::U128(lhs) => unsafe { eq_bsatn_int_seq(lhs, rhs) },
72+
}
73+
}
74+
75+
/// Equates the integer slice `lhs` to the BSATN-encoded one in `rhs`.
76+
///
77+
/// SAFETY: `T` must be an integer type.
78+
unsafe fn eq_bsatn_int_seq<'r, T>(lhs: &[T], mut rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
79+
// The BSATN is length-prefixed.
80+
let Ok(len) = rhs.reborrow().deserialize_len() else {
81+
return false;
82+
};
83+
84+
// Extract the rhs bytes.
85+
let Ok(rhs_bytes) = rhs.get_slice(len * mem::size_of::<T>()) else {
86+
return false;
87+
};
88+
89+
// Convert `lhs` to `&[u8]`.
90+
let ptr = lhs.as_ptr().cast::<u8>();
91+
// SAFETY: Caller promised that `T` is an integer type.
92+
// Thus it has no safety requirements and no padding,
93+
// so it is legal to convert `&[IntType] -> &[u8]`.
94+
let lhs_bytes = unsafe { slice::from_raw_parts(ptr, mem::size_of_val(lhs)) };
95+
96+
lhs_bytes == rhs_bytes
97+
}
98+
99+
/// Equates the string `lhs` to the BSATN-encoded one in `rhs`.
100+
#[allow(clippy::borrowed_box)]
101+
fn eq_bsatn_str<'r>(lhs: &Box<str>, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
102+
<&str>::deserialize(rhs).map(|rhs| &**lhs == rhs).unwrap_or(false)
103+
}
104+
105+
/// Equates elements in `lhs` to the BSATN-encoded element sequence in `rhs`.
106+
/// The sequence is prefixed with its length.
107+
fn eq_bsatn_seq<'r, T, I: ExactSizeIterator<Item = T>, R: BufReader<'r>>(
108+
lhs: impl IntoIterator<IntoIter = I>,
109+
mut rhs: Deserializer<'_, R>,
110+
elem_eq: impl Fn(T, Deserializer<'_, R>) -> bool,
111+
) -> bool {
112+
let mut lhs = lhs.into_iter();
113+
// The BSATN is length-prefixed.
114+
// Compare against length first.
115+
match rhs.reborrow().deserialize_len() {
116+
Ok(len) if lhs.len() == len => lhs.all(|e| elem_eq(e, rhs.reborrow())),
117+
_ => false,
118+
}
119+
}
120+
121+
/// Deserializes from `de` an `rhs: T` and then proceeds to `lhs == rhs`.
122+
fn eq_bsatn_de<'r, T: Eq + Deserialize<'r>>(lhs: &T, rhs: Deserializer<'_, impl BufReader<'r>>) -> bool {
123+
T::deserialize(rhs).map(|rhs| lhs == &rhs).unwrap_or(false)
124+
}
125+
126+
#[cfg(test)]
127+
mod tests {
128+
use super::eq_bsatn;
129+
use crate::{
130+
bsatn::{to_vec, Deserializer},
131+
proptest::generate_typed_value,
132+
};
133+
use proptest::prelude::*;
134+
135+
proptest! {
136+
#![proptest_config(ProptestConfig::with_cases(2048))]
137+
#[test]
138+
fn encoded_val_eq_to_self((_, val) in generate_typed_value()) {
139+
let mut bsatn = &*to_vec(&val).unwrap();
140+
let de = Deserializer::new(&mut bsatn);
141+
prop_assert!(eq_bsatn(&val, de));
142+
}
143+
}
144+
}

0 commit comments

Comments
 (0)