|
| 1 | +use second_stack::uninit_slice; |
| 2 | + |
1 | 3 | use crate::ser::{self, ForwardNamedToSeqProduct, Serialize}; |
2 | 4 | use crate::{AlgebraicType, AlgebraicValue, ArrayValue, MapValue, F32, F64}; |
3 | 5 | use core::convert::Infallible; |
| 6 | +use core::mem::MaybeUninit; |
4 | 7 | use core::ptr; |
5 | 8 | use std::alloc::{self, Layout}; |
6 | 9 |
|
@@ -96,10 +99,12 @@ impl ser::Serializer for ValueSerializer { |
96 | 99 | chunks: I, |
97 | 100 | ) -> Result<Self::Ok, Self::Error> { |
98 | 101 | // SAFETY: Caller promised `total_bsatn_len == chunks.map(|c| c.len()).sum() <= isize::MAX`. |
99 | | - let bsatn = unsafe { concat_byte_chunks(total_bsatn_len, chunks) }; |
100 | | - |
101 | | - // SAFETY: Caller promised `AlgebraicValue::decode(ty, &mut bytes).is_ok()`. |
102 | | - unsafe { self.serialize_bsatn(ty, &bsatn) } |
| 102 | + unsafe { |
| 103 | + concat_byte_chunks_buf(total_bsatn_len, chunks, |bsatn| { |
| 104 | + // SAFETY: Caller promised `AlgebraicValue::decode(ty, &mut bytes).is_ok()`. |
| 105 | + ValueSerializer.serialize_bsatn(ty, bsatn) |
| 106 | + }) |
| 107 | + } |
103 | 108 | } |
104 | 109 |
|
105 | 110 | unsafe fn serialize_str_in_chunks<'a, I: Iterator<Item = &'a [u8]>>( |
@@ -136,33 +141,88 @@ unsafe fn concat_byte_chunks<'a>(total_len: usize, chunks: impl Iterator<Item = |
136 | 141 | alloc::handle_alloc_error(layout); |
137 | 142 | } |
138 | 143 |
|
| 144 | + // Copy over each `chunk`. |
| 145 | + // SAFETY: |
| 146 | + // 1. `ptr` is valid for writes as we own it, |
| 147 | + // and caller promised that all `chunk`s will fit in `total_len`. |
| 148 | + // 2. `ptr` points to a new allocation so it cannot overlap with any in `chunks`. |
| 149 | + unsafe { write_byte_chunks(ptr, chunks) }; |
| 150 | + |
| 151 | + // Convert allocation to a `Vec<u8>`. |
| 152 | + // SAFETY: |
| 153 | + // - `ptr` was allocated using global allocator. |
| 154 | + // - `u8` and `ptr`'s allocation both have alignment of 1. |
| 155 | + // - `ptr`'s allocation is `total_len <= isize::MAX`. |
| 156 | + // - `total_len <= total_len` holds. |
| 157 | + // - `total_len` values were initialized at type `u8` |
| 158 | + // as we know `total_len == chunks.map(|c| c.len()).sum()`. |
| 159 | + unsafe { Vec::from_raw_parts(ptr, total_len, total_len) } |
| 160 | +} |
| 161 | + |
| 162 | +/// Concatenates `chunks`, which must total `total_len` bytes, into a temporary buffer and returns the result of calling `run` on it. |
| 163 | +/// |
| 164 | +/// # Safety |
| 165 | +/// |
| 166 | +/// - `total_len == chunks.map(|c| c.len()).sum() <= isize::MAX` |
| 167 | +pub unsafe fn concat_byte_chunks_buf<'a, R>( |
| 168 | + total_len: usize, |
| 169 | + chunks: impl Iterator<Item = &'a [u8]>, |
| 170 | + run: impl FnOnce(&[u8]) -> R, |
| 171 | +) -> R { |
| 172 | + uninit_slice(total_len, |buf: &mut [MaybeUninit<u8>]| { |
| 173 | + let dst = buf.as_mut_ptr().cast(); |
| 174 | + debug_assert_eq!(total_len, buf.len()); |
| 175 | + // SAFETY: |
| 176 | + // 1. `buf.len() == total_len` |
| 177 | + // 2. `buf` cannot overlap with anything yielded by `chunks`. |
| 178 | + unsafe { write_byte_chunks(dst, chunks) } |
| 179 | + // SAFETY: Every byte of `buf` was initialized in the previous call |
| 180 | + // as we know that `total_len == chunks.map(|c| c.len()).sum()`. |
| 181 | + let bytes = unsafe { slice_assume_init_ref(buf) }; |
| 182 | + run(bytes) |
| 183 | + }) |
| 184 | +} |
| 185 | + |
| 186 | +/// Copies over each `chunk` in `chunks` to `dst`, writing `total_len` bytes to `dst`. |
| 187 | +/// |
| 188 | +/// # Safety |
| 189 | +/// |
| 190 | +/// Let `total_len == chunks.map(|c| c.len()).sum()`. |
| 191 | +/// 1. `dst` must be valid for writes for `total_len` bytes. |
| 192 | +/// 2. `dst..(dst + total_len)` does not overlap with any slice yielded by `chunks`. |
| 193 | +unsafe fn write_byte_chunks<'a>(mut dst: *mut u8, chunks: impl Iterator<Item = &'a [u8]>) { |
139 | 194 | // Copy over each `chunk`, advancing `dst` by `chunk.len()` each time. |
140 | | - let mut dst = ptr; |
141 | 195 | for chunk in chunks { |
142 | 196 | let len = chunk.len(); |
143 | 197 | // SAFETY: |
144 | | - // - `chunk` is valid for reads for `len` bytes. |
145 | | - // - `dst` is valid for writes as we own it |
146 | | - // and as (1) caller promised that all `chunk`s will fit in `total_len`, |
147 | | - // this entails that `dst..dst + len` is always in bounds of the allocation. |
| 198 | + // - By line above, `chunk` is valid for reads for `len` bytes. |
| 199 | + // - By (1) `dst` is valid for writes as promised by caller |
| 200 | + // and that all `chunk`s will fit in `total_len`. |
| 201 | + // This entails that `dst..dst + len` is always in bounds of the allocation. |
148 | 202 | // - `chunk` and `dst` are trivially properly aligned (`align_of::<u8>() == 1`). |
149 | | - // - The allocation `ptr` points to is new so derived pointers cannot overlap with `chunk`. |
| 203 | + // - By (2) derived pointers of `dst` cannot overlap with `chunk`. |
150 | 204 | unsafe { |
151 | 205 | ptr::copy_nonoverlapping(chunk.as_ptr(), dst, len); |
152 | 206 | } |
153 | 207 | // SAFETY: Same as (1). |
154 | 208 | dst = unsafe { dst.add(len) }; |
155 | 209 | } |
| 210 | +} |
156 | 211 |
|
157 | | - // Convert allocation to a `Vec<u8>`. |
158 | | - // SAFETY: |
159 | | - // - `ptr` was allocated using global allocator. |
160 | | - // - `u8` and `ptr`'s allocation both have alignment of 1. |
161 | | - // - `ptr`'s allocation is `total_len <= isize::MAX`. |
162 | | - // - `total_len <= total_len` holds. |
163 | | - // - `total_len` values were initialized at type `u8` |
164 | | - // as we know `total_len == chunks.map(|c| c.len()).sum()`. |
165 | | - unsafe { Vec::from_raw_parts(ptr, total_len, total_len) } |
| 212 | +/// Convert a `[MaybeUninit<T>]` into a `[T]` by asserting all elements are initialized. |
| 213 | +/// |
| 214 | +/// Identical copy of the source of `MaybeUninit::slice_assume_init_ref`, which is not yet stabilized. |
| 215 | +/// https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.slice_assume_init_ref |
| 216 | +/// |
| 217 | +/// # Safety |
| 218 | +/// |
| 219 | +/// All elements of `slice` must be initialized. |
| 220 | +pub const unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] { |
| 221 | + // SAFETY: casting `slice` to a `*const [T]` is safe since the caller guarantees that |
| 222 | + // `slice` is initialized, and `MaybeUninit` is guaranteed to have the same layout as `T`. |
| 223 | + // The pointer obtained is valid since it refers to memory owned by `slice` which is a |
| 224 | + // reference and thus guaranteed to be valid for reads. |
| 225 | + unsafe { &*(slice as *const [MaybeUninit<T>] as *const [T]) } |
166 | 226 | } |
167 | 227 |
|
168 | 228 | /// Continuation for serializing an array. |
|
0 commit comments