Commit 6215de8

Auto merge of #146331 - RalfJung:copy-prov-repeat, r=<try>
interpret: copy_provenance: avoid large intermediate buffer for large repeat counts
2 parents: beeb8e3 + 8a1774f

5 files changed, +77 −85 lines
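In short: the old `prepare_copy` materialized all `num_copies` shifted copies of the source provenance into one large intermediate buffer (the deleted comment in the provenance_map.rs diff below explains why: `insert_presorted` used to demand a presorted `Vec`). After this change, `prepare_copy` records the provenance once, relative to the copied range, and `apply_copy` shifts and inserts it lazily, once per repetition. A minimal standalone sketch of the difference, with `u64` offsets standing in for `Size` and `char` standing in for a provenance value (illustrative names only, not the compiler's API):

```rust
fn main() {
    // Provenance entries for one copy of the source range, with offsets
    // relative to the range start (what the new `ProvenanceCopy` stores).
    let prepared: Vec<(u64, char)> = vec![(0, 'a'), (8, 'b')];
    let (dest_start, range_size, repeat) = (16u64, 16u64, 3u64);

    // Old scheme: allocate one buffer holding all `repeat` shifted copies,
    // an O(entries * repeat) intermediate allocation.
    let eager: Vec<(u64, char)> = (0..repeat)
        .flat_map(|i| {
            prepared.iter().map(move |&(off, p)| (off + dest_start + i * range_size, p))
        })
        .collect();

    // New scheme: hand each repetition to the destination as a lazy iterator;
    // only the single-copy `prepared` list is ever materialized.
    let mut dest: Vec<(u64, char)> = Vec::new();
    for i in 0..repeat {
        // In the real code this iterator feeds `SortedMap::insert_presorted`.
        dest.extend(prepared.iter().map(|&(off, p)| (off + dest_start + i * range_size, p)));
    }

    // The destination sees the same sorted sequence either way; only the
    // intermediate buffer is gone.
    assert_eq!(eager, dest);
}
```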

compiler/rustc_const_eval/src/interpret/memory.rs
(2 additions, 3 deletions)

```diff
@@ -1501,8 +1501,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         // `get_bytes_mut` will clear the provenance, which is correct,
         // since we don't want to keep any provenance at the target.
         // This will also error if copying partial provenance is not supported.
-        let provenance =
-            src_alloc.provenance().prepare_copy(src_range, dest_offset, num_copies, self);
+        let provenance = src_alloc.provenance().prepare_copy(src_range, self);
         // Prepare a copy of the initialization mask.
         let init = src_alloc.init_mask().prepare_copy(src_range);

@@ -1587,7 +1586,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             num_copies,
         );
         // copy the provenance to the destination
-        dest_alloc.provenance_apply_copy(provenance);
+        dest_alloc.provenance_apply_copy(provenance, alloc_range(dest_offset, size), num_copies);

         interp_ok(())
     }
```
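This call-site split is the heart of the interface change: `prepare_copy` no longer takes `dest_offset` and `num_copies`, so the prepared copy is independent of where and how often it is applied; both parameters move to `provenance_apply_copy`, which performs the per-repetition shifting (see the provenance_map.rs diff below).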

compiler/rustc_data_structures/src/sorted_map.rs
(17 additions, 14 deletions)

```diff
@@ -215,32 +215,35 @@ impl<K: Ord, V> SortedMap<K, V> {
     /// It is up to the caller to make sure that the elements are sorted by key
     /// and that there are no duplicates.
     #[inline]
-    pub fn insert_presorted(&mut self, elements: Vec<(K, V)>) {
-        if elements.is_empty() {
+    pub fn insert_presorted(
+        &mut self,
+        mut elements: impl Iterator<Item = (K, V)> + DoubleEndedIterator,
+    ) {
+        let Some(first) = elements.next() else {
             return;
-        }
-
-        debug_assert!(elements.array_windows().all(|[fst, snd]| fst.0 < snd.0));
+        };

-        let start_index = self.lookup_index_for(&elements[0].0);
+        let start_index = self.lookup_index_for(&first.0);

         let elements = match start_index {
             Ok(index) => {
-                let mut elements = elements.into_iter();
-                self.data[index] = elements.next().unwrap();
-                elements
+                self.data[index] = first; // overwrite first element
+                elements.chain(None) // insert the rest below
             }
             Err(index) => {
-                if index == self.data.len() || elements.last().unwrap().0 < self.data[index].0 {
+                let last = elements.next_back();
+                if index == self.data.len()
+                    || last.as_ref().is_none_or(|l| l.0 < self.data[index].0)
+                {
                     // We can copy the whole range without having to mix with
                     // existing elements.
-                    self.data.splice(index..index, elements);
+                    self.data
+                        .splice(index..index, std::iter::once(first).chain(elements).chain(last));
                     return;
                 }

-                let mut elements = elements.into_iter();
-                self.data.insert(index, elements.next().unwrap());
-                elements
+                self.data.insert(index, first);
+                elements.chain(last) // insert the rest below
             }
         };

```

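The new `DoubleEndedIterator` bound exists so the fast path can inspect the batch's last key via `next_back()` without buffering the whole iterator; the peeled-off element is chained back on when the batch has to be merged element by element. A toy model of that fast-path check, runnable on its own (function and variable names here are hypothetical, not part of `SortedMap`):

```rust
// Toy model of the bulk-splice check in the new `insert_presorted`: peel off
// the last element with `next_back()` and compare it against the key at the
// insertion point, without collecting the iterator.
fn splices_cleanly(
    mut elements: impl DoubleEndedIterator<Item = (u32, &'static str)>,
    existing_key_at_insertion_point: Option<u32>,
) -> bool {
    let last = elements.next_back();
    match existing_key_at_insertion_point {
        // Insertion point is past the last existing element: always clean.
        None => true,
        // Otherwise the whole presorted batch must sort strictly before the
        // existing key for a bulk `splice` to preserve the order invariant.
        Some(key) => last.as_ref().is_none_or(|l| l.0 < key),
    }
}

fn main() {
    assert!(splices_cleanly(vec![(3, "a"), (7, "b")].into_iter(), Some(8)));
    assert!(!splices_cleanly(vec![(3, "a"), (9, "b")].into_iter(), Some(8)));
    assert!(splices_cleanly(vec![(3, "a")].into_iter(), None));
}
```

(`Option::is_none_or` is the same helper the real diff uses, so a batch whose only element was already consumed as `first` counts as always spliceable.)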
compiler/rustc_data_structures/src/sorted_map/tests.rs
(5 additions, 5 deletions)

```diff
@@ -171,7 +171,7 @@ fn test_insert_presorted_non_overlapping() {
     map.insert(2, 0);
     map.insert(8, 0);

-    map.insert_presorted(vec![(3, 0), (7, 0)]);
+    map.insert_presorted(vec![(3, 0), (7, 0)].into_iter());

     let expected = vec![2, 3, 7, 8];
     assert_eq!(keys(map), expected);
@@ -183,7 +183,7 @@ fn test_insert_presorted_first_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);

-    map.insert_presorted(vec![(2, 0), (7, 7)]);
+    map.insert_presorted(vec![(2, 0), (7, 7)].into_iter());

     let expected = vec![(2, 0), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -195,7 +195,7 @@ fn test_insert_presorted_last_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);

-    map.insert_presorted(vec![(3, 3), (8, 0)]);
+    map.insert_presorted(vec![(3, 3), (8, 0)].into_iter());

     let expected = vec![(2, 2), (3, 3), (8, 0)];
     assert_eq!(elements(map), expected);
@@ -207,7 +207,7 @@ fn test_insert_presorted_shuffle() {
     map.insert(2, 2);
     map.insert(7, 7);

-    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)]);
+    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)].into_iter());

     let expected = vec![(1, 1), (2, 2), (3, 3), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -219,7 +219,7 @@ fn test_insert_presorted_at_end() {
     map.insert(1, 1);
     map.insert(2, 2);

-    map.insert_presorted(vec![(3, 3), (8, 8)]);
+    map.insert_presorted(vec![(3, 3), (8, 8)].into_iter());

     let expected = vec![(1, 1), (2, 2), (3, 3), (8, 8)];
     assert_eq!(elements(map), expected);
```

compiler/rustc_middle/src/mir/interpret/allocation.rs
(7 additions, 2 deletions)

```diff
@@ -844,8 +844,13 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
     ///
     /// This is dangerous to use as it can violate internal `Allocation` invariants!
     /// It only exists to support an efficient implementation of `mem_copy_repeatedly`.
-    pub fn provenance_apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        self.provenance.apply_copy(copy)
+    pub fn provenance_apply_copy(
+        &mut self,
+        copy: ProvenanceCopy<Prov>,
+        range: AllocRange,
+        repeat: u64,
+    ) {
+        self.provenance.apply_copy(copy, range, repeat)
     }

     /// Applies a previously prepared copy of the init mask.
```

compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs
(46 additions, 61 deletions)

```diff
@@ -272,90 +272,74 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {

 /// A partial, owned list of provenance to transfer into another allocation.
 ///
-/// Offsets are already adjusted to the destination allocation.
+/// Offsets are relative to the beginning of the copied range.
 pub struct ProvenanceCopy<Prov> {
-    dest_ptrs: Option<Box<[(Size, Prov)]>>,
-    dest_bytes: Option<Box<[(Size, (Prov, u8))]>>,
+    ptrs: Box<[(Size, Prov)]>,
+    bytes: Box<[(Size, (Prov, u8))]>,
 }

 impl<Prov: Provenance> ProvenanceMap<Prov> {
-    pub fn prepare_copy(
-        &self,
-        src: AllocRange,
-        dest: Size,
-        count: u64,
-        cx: &impl HasDataLayout,
-    ) -> ProvenanceCopy<Prov> {
-        let shift_offset = move |idx, offset| {
-            // compute offset for current repetition
-            let dest_offset = dest + src.size * idx; // `Size` operations
-            // shift offsets from source allocation to destination allocation
-            (offset - src.start) + dest_offset // `Size` operations
-        };
+    pub fn prepare_copy(&self, range: AllocRange, cx: &impl HasDataLayout) -> ProvenanceCopy<Prov> {
+        let shift_offset = move |offset| offset - range.start;
         let ptr_size = cx.data_layout().pointer_size();

         // # Pointer-sized provenances
         // Get the provenances that are entirely within this range.
         // (Different from `range_get_ptrs` which asks if they overlap the range.)
         // Only makes sense if we are copying at least one pointer worth of bytes.
-        let mut dest_ptrs_box = None;
-        if src.size >= ptr_size {
-            let adjusted_end = Size::from_bytes(src.end().bytes() - (ptr_size.bytes() - 1));
-            let ptrs = self.ptrs.range(src.start..adjusted_end);
-            // If `count` is large, this is rather wasteful -- we are allocating a big array here, which
-            // is mostly filled with redundant information since it's just N copies of the same `Prov`s
-            // at slightly adjusted offsets. The reason we do this is so that in `mark_provenance_range`
-            // we can use `insert_presorted`. That wouldn't work with an `Iterator` that just produces
-            // the right sequence of provenance for all N copies.
-            // Basically, this large array would have to be created anyway in the target allocation.
-            let mut dest_ptrs = Vec::with_capacity(ptrs.len() * (count as usize));
-            for i in 0..count {
-                dest_ptrs
-                    .extend(ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_ptrs.len(), dest_ptrs.capacity());
-            dest_ptrs_box = Some(dest_ptrs.into_boxed_slice());
+        let mut ptrs_box: Box<[_]> = Box::new([]);
+        if range.size >= ptr_size {
+            let adjusted_end = Size::from_bytes(range.end().bytes() - (ptr_size.bytes() - 1));
+            let ptrs = self.ptrs.range(range.start..adjusted_end);
+            ptrs_box = ptrs.iter().map(|&(offset, reloc)| (shift_offset(offset), reloc)).collect();
         };

         // # Byte-sized provenances
         // This includes the existing bytewise provenance in the range, and ptr provenance
         // that overlaps with the begin/end of the range.
-        let mut dest_bytes_box = None;
-        let begin_overlap = self.range_ptrs_get(alloc_range(src.start, Size::ZERO), cx).first();
-        let end_overlap = self.range_ptrs_get(alloc_range(src.end(), Size::ZERO), cx).first();
+        let mut bytes_box: Box<[_]> = Box::new([]);
+        let begin_overlap = self.range_ptrs_get(alloc_range(range.start, Size::ZERO), cx).first();
+        let end_overlap = self.range_ptrs_get(alloc_range(range.end(), Size::ZERO), cx).first();
         // We only need to go here if there is some overlap or some bytewise provenance.
         if begin_overlap.is_some() || end_overlap.is_some() || self.bytes.is_some() {
             let mut bytes: Vec<(Size, (Prov, u8))> = Vec::new();
             // First, if there is a part of a pointer at the start, add that.
             if let Some(entry) = begin_overlap {
                 trace!("start overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't run off the end of the `src` range.
-                let entry_end = cmp::min(entry.0 + ptr_size, src.end());
-                for offset in src.start..entry_end {
-                    bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                // For really small copies, make sure we don't run off the end of the range.
+                let entry_end = cmp::min(entry.0 + ptr_size, range.end());
+                for offset in range.start..entry_end {
+                    bytes.push((shift_offset(offset), (entry.1, (offset - entry.0).bytes() as u8)));
                 }
             } else {
                 trace!("no start overlapping entry");
             }

             // Then the main part, bytewise provenance from `self.bytes`.
-            bytes.extend(self.range_bytes_get(src));
+            bytes.extend(
+                self.range_bytes_get(range)
+                    .iter()
+                    .map(|&(offset, reloc)| (shift_offset(offset), reloc)),
+            );

             // And finally possibly parts of a pointer at the end.
             if let Some(entry) = end_overlap {
                 trace!("end overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't start before `src` does.
-                let entry_start = cmp::max(entry.0, src.start);
-                for offset in entry_start..src.end() {
+                // For really small copies, make sure we don't start before `range` does.
+                let entry_start = cmp::max(entry.0, range.start);
+                for offset in entry_start..range.end() {
                     if bytes.last().is_none_or(|bytes_entry| bytes_entry.0 < offset) {
                         // The last entry, if it exists, has a lower offset than us, so we
                         // can add it at the end and remain sorted.
-                        bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                        bytes.push((
+                            shift_offset(offset),
+                            (entry.1, (offset - entry.0).bytes() as u8),
+                        ));
                     } else {
                         // There already is an entry for this offset in there! This can happen when the
                         // start and end range checks actually end up hitting the same pointer, so we
                         // already added this in the "pointer at the start" part above.
-                        assert!(entry.0 <= src.start);
+                        assert!(entry.0 <= range.start);
                     }
                 }
             } else {
@@ -364,29 +348,30 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
             trace!("byte provenances: {bytes:?}");

             // And again a buffer for the new list on the target side.
-            let mut dest_bytes = Vec::with_capacity(bytes.len() * (count as usize));
-            for i in 0..count {
-                dest_bytes
-                    .extend(bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_bytes.len(), dest_bytes.capacity());
-            dest_bytes_box = Some(dest_bytes.into_boxed_slice());
+            bytes_box = bytes.into_boxed_slice();
         }

-        ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box }
+        ProvenanceCopy { ptrs: ptrs_box, bytes: bytes_box }
     }

     /// Applies a provenance copy.
     /// The affected range, as defined in the parameters to `prepare_copy` is expected
     /// to be clear of provenance.
-    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        if let Some(dest_ptrs) = copy.dest_ptrs {
-            self.ptrs.insert_presorted(dest_ptrs.into());
+    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>, range: AllocRange, repeat: u64) {
+        let shift_offset = |idx: u64, offset: Size| offset + range.start + idx * range.size;
+        if !copy.ptrs.is_empty() {
+            for i in 0..repeat {
+                self.ptrs.insert_presorted(
+                    copy.ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)),
+                );
+            }
         }
-        if let Some(dest_bytes) = copy.dest_bytes
-            && !dest_bytes.is_empty()
-        {
-            self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
+        if !copy.bytes.is_empty() {
+            for i in 0..repeat {
+                self.bytes.get_or_insert_with(Box::default).insert_presorted(
+                    copy.bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)),
+                );
+            }
         }
     }
 }
```
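The offset bookkeeping is now split across the two phases: `prepare_copy` stores offsets relative to the source range (`offset - range.start`), and `apply_copy` re-anchors them with `offset + range.start + idx * range.size`, where `range` is the destination range. A quick standalone check that this composition matches the deleted one-step formula `(offset - src.start) + dest + src.size * idx`, with plain `u64`s standing in for `Size` (an illustrative sketch, not compiler code):

```rust
fn main() {
    // Source range starts at 4 and is 16 bytes long; destination starts at 32.
    let (src_start, size, dest_start) = (4u64, 16u64, 32u64);
    for offset in src_start..src_start + size {
        for idx in 0..5u64 {
            // Old: one shift, computed eagerly in `prepare_copy`.
            let old = (offset - src_start) + (dest_start + size * idx);
            // New: make the offset range-relative at prepare time...
            let relative = offset - src_start;
            // ...then re-anchor it at apply time, once per repetition.
            let new = relative + dest_start + idx * size;
            assert_eq!(old, new);
        }
    }
    println!("old and new shift arithmetic agree");
}
```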
