
Commit cd436dc

Auto merge of #146331 - RalfJung:copy-prov-repeat, r=<try>
interpret: copy_provenance: avoid large intermediate buffer for large repeat counts
2 parents be8de5d + d098505 commit cd436dc

File tree

5 files changed (+85, -82 lines)


compiler/rustc_const_eval/src/interpret/memory.rs

Lines changed: 2 additions & 2 deletions
@@ -1503,7 +1503,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         // This will also error if copying partial provenance is not supported.
         let provenance = src_alloc
             .provenance()
-            .prepare_copy(src_range, dest_offset, num_copies, self)
+            .prepare_copy(src_range, self)
             .map_err(|e| e.to_interp_error(src_alloc_id))?;
         // Prepare a copy of the initialization mask.
         let init = src_alloc.init_mask().prepare_copy(src_range);
@@ -1589,7 +1589,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             num_copies,
         );
         // copy the provenance to the destination
-        dest_alloc.provenance_apply_copy(provenance);
+        dest_alloc.provenance_apply_copy(provenance, alloc_range(dest_offset, size), num_copies);
 
         interp_ok(())
     }
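With this change, `prepare_copy` only extracts provenance for the source range; the destination offset and the repeat count are passed later, to `provenance_apply_copy`. A minimal sketch of that two-phase shape, using simplified stand-in types (plain integers for offsets, a char for the provenance) and hypothetical names `prepare`/`apply` rather than the rustc API:

struct PreparedCopy {
    // Offsets are relative to the start of the copied source range.
    entries: Vec<(u64, char)>,
}

// Phase 1: look only at the source range; no knowledge of destination or repeat count.
fn prepare(src: &[(u64, char)], src_start: u64, size: u64) -> PreparedCopy {
    let entries = src
        .iter()
        .filter(|&&(off, _)| off >= src_start && off < src_start + size)
        .map(|&(off, prov)| (off - src_start, prov))
        .collect();
    PreparedCopy { entries }
}

// Phase 2: the destination offset and the repeat count only matter here.
fn apply(dest: &mut Vec<(u64, char)>, copy: &PreparedCopy, dest_start: u64, size: u64, repeat: u64) {
    for i in 0..repeat {
        for &(off, prov) in &copy.entries {
            dest.push((dest_start + i * size + off, prov));
        }
    }
}

fn main() {
    let src = vec![(0, 'x'), (16, 'a'), (24, 'b')];
    let copy = prepare(&src, 16, 16); // copy bytes 16..32 of the source
    let mut dest = Vec::new();
    apply(&mut dest, &copy, 0, 16, 2); // two back-to-back copies starting at offset 0
    assert_eq!(dest, vec![(0, 'a'), (8, 'b'), (16, 'a'), (24, 'b')]);
}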

compiler/rustc_data_structures/src/sorted_map.rs

Lines changed: 18 additions & 15 deletions
@@ -215,36 +215,39 @@ impl<K: Ord, V> SortedMap<K, V> {
     /// It is up to the caller to make sure that the elements are sorted by key
     /// and that there are no duplicates.
     #[inline]
-    pub fn insert_presorted(&mut self, elements: Vec<(K, V)>) {
-        if elements.is_empty() {
+    pub fn insert_presorted(
+        &mut self,
+        mut elements: impl Iterator<Item = (K, V)> + DoubleEndedIterator,
+    ) {
+        let Some(first) = elements.next() else {
             return;
-        }
-
-        debug_assert!(elements.array_windows().all(|[fst, snd]| fst.0 < snd.0));
+        };
 
-        let start_index = self.lookup_index_for(&elements[0].0);
+        let start_index = self.lookup_index_for(&first.0);
 
         let elements = match start_index {
             Ok(index) => {
-                let mut elements = elements.into_iter();
-                self.data[index] = elements.next().unwrap();
-                elements
+                self.data[index] = first; // overwrite first element
+                elements.chain(None) // insert the rest below
             }
             Err(index) => {
-                if index == self.data.len() || elements.last().unwrap().0 < self.data[index].0 {
+                let last = elements.next_back();
+                if index == self.data.len()
+                    || last.as_ref().is_none_or(|l| l.0 < self.data[index].0)
+                {
                     // We can copy the whole range without having to mix with
                     // existing elements.
-                    self.data.splice(index..index, elements);
+                    self.data
+                        .splice(index..index, std::iter::once(first).chain(elements).chain(last));
                     return;
                 }
 
-                let mut elements = elements.into_iter();
-                self.data.insert(index, elements.next().unwrap());
-                elements
+                self.data.insert(index, first);
+                elements.chain(last) // insert the rest below
             }
         };
 
-        // Insert the rest
+        // Insert the rest. This is super inefficient since each insertion copies the entire tail.
         for (k, v) in elements {
             self.insert(k, v);
         }
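The new signature lets callers hand `insert_presorted` a lazily produced sequence; the `DoubleEndedIterator` bound is there so the method can peek at the last key via `next_back` (to decide whether the whole batch can be spliced in one go) without collecting the iterator first. A small standalone sketch of that splice path on a plain sorted `Vec`, using `partition_point` as a stand-in for `lookup_index_for` (not the rustc `SortedMap` itself):

fn main() {
    let mut data: Vec<(u32, &str)> = vec![(2, "two"), (8, "eight")];

    // Presorted batch, supplied lazily instead of as a Vec.
    let mut elements = vec![(3, "three"), (7, "seven")].into_iter();

    let Some(first) = elements.next() else { return };
    // Insertion point for the first new key.
    let index = data.partition_point(|(k, _)| *k < first.0);

    // Peek at the last new key (needs DoubleEndedIterator) to check whether
    // the whole batch fits before `data[index]`.
    let last = elements.next_back();
    if index == data.len() || last.as_ref().is_none_or(|l| l.0 < data[index].0) {
        // Single splice: the tail of `data` is shifted only once.
        data.splice(index..index, std::iter::once(first).chain(elements).chain(last));
    } else {
        // Mixed case: fall back to inserting one element at a time.
        for el in std::iter::once(first).chain(elements).chain(last) {
            let i = data.partition_point(|(k, _)| *k < el.0);
            data.insert(i, el);
        }
    }

    assert_eq!(data, vec![(2, "two"), (3, "three"), (7, "seven"), (8, "eight")]);
}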

compiler/rustc_data_structures/src/sorted_map/tests.rs

Lines changed: 5 additions & 5 deletions
@@ -171,7 +171,7 @@ fn test_insert_presorted_non_overlapping() {
     map.insert(2, 0);
     map.insert(8, 0);
 
-    map.insert_presorted(vec![(3, 0), (7, 0)]);
+    map.insert_presorted(vec![(3, 0), (7, 0)].into_iter());
 
     let expected = vec![2, 3, 7, 8];
     assert_eq!(keys(map), expected);
@@ -183,7 +183,7 @@ fn test_insert_presorted_first_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);
 
-    map.insert_presorted(vec![(2, 0), (7, 7)]);
+    map.insert_presorted(vec![(2, 0), (7, 7)].into_iter());
 
     let expected = vec![(2, 0), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -195,7 +195,7 @@ fn test_insert_presorted_last_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);
 
-    map.insert_presorted(vec![(3, 3), (8, 0)]);
+    map.insert_presorted(vec![(3, 3), (8, 0)].into_iter());
 
     let expected = vec![(2, 2), (3, 3), (8, 0)];
     assert_eq!(elements(map), expected);
@@ -207,7 +207,7 @@ fn test_insert_presorted_shuffle() {
     map.insert(2, 2);
     map.insert(7, 7);
 
-    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)]);
+    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)].into_iter());
 
     let expected = vec![(1, 1), (2, 2), (3, 3), (7, 7), (8, 8)];
    assert_eq!(elements(map), expected);
@@ -219,7 +219,7 @@ fn test_insert_presorted_at_end() {
     map.insert(1, 1);
     map.insert(2, 2);
 
-    map.insert_presorted(vec![(3, 3), (8, 8)]);
+    map.insert_presorted(vec![(3, 3), (8, 8)].into_iter());
 
     let expected = vec![(1, 1), (2, 2), (3, 3), (8, 8)];
     assert_eq!(elements(map), expected);

compiler/rustc_middle/src/mir/interpret/allocation.rs

Lines changed: 7 additions & 2 deletions
@@ -849,8 +849,13 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
     ///
     /// This is dangerous to use as it can violate internal `Allocation` invariants!
     /// It only exists to support an efficient implementation of `mem_copy_repeatedly`.
-    pub fn provenance_apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        self.provenance.apply_copy(copy)
+    pub fn provenance_apply_copy(
+        &mut self,
+        copy: ProvenanceCopy<Prov>,
+        range: AllocRange,
+        repeat: u64,
+    ) {
+        self.provenance.apply_copy(copy, range, repeat)
     }
 
     /// Applies a previously prepared copy of the init mask.

compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs

Lines changed: 53 additions & 58 deletions
@@ -278,90 +278,78 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
 
 /// A partial, owned list of provenance to transfer into another allocation.
 ///
-/// Offsets are already adjusted to the destination allocation.
+/// Offsets are relative to the beginning of the copied range.
 pub struct ProvenanceCopy<Prov> {
-    dest_ptrs: Option<Box<[(Size, Prov)]>>,
-    dest_bytes: Option<Box<[(Size, (Prov, u8))]>>,
+    ptrs: Box<[(Size, Prov)]>,
+    bytes: Box<[(Size, (Prov, u8))]>,
 }
 
 impl<Prov: Provenance> ProvenanceMap<Prov> {
     pub fn prepare_copy(
         &self,
-        src: AllocRange,
-        dest: Size,
-        count: u64,
+        range: AllocRange,
         cx: &impl HasDataLayout,
     ) -> AllocResult<ProvenanceCopy<Prov>> {
-        let shift_offset = move |idx, offset| {
-            // compute offset for current repetition
-            let dest_offset = dest + src.size * idx; // `Size` operations
-            // shift offsets from source allocation to destination allocation
-            (offset - src.start) + dest_offset // `Size` operations
-        };
+        let shift_offset = move |offset| offset - range.start;
         let ptr_size = cx.data_layout().pointer_size();
 
         // # Pointer-sized provenances
         // Get the provenances that are entirely within this range.
         // (Different from `range_get_ptrs` which asks if they overlap the range.)
         // Only makes sense if we are copying at least one pointer worth of bytes.
-        let mut dest_ptrs_box = None;
-        if src.size >= ptr_size {
-            let adjusted_end = Size::from_bytes(src.end().bytes() - (ptr_size.bytes() - 1));
-            let ptrs = self.ptrs.range(src.start..adjusted_end);
-            // If `count` is large, this is rather wasteful -- we are allocating a big array here, which
-            // is mostly filled with redundant information since it's just N copies of the same `Prov`s
-            // at slightly adjusted offsets. The reason we do this is so that in `mark_provenance_range`
-            // we can use `insert_presorted`. That wouldn't work with an `Iterator` that just produces
-            // the right sequence of provenance for all N copies.
-            // Basically, this large array would have to be created anyway in the target allocation.
-            let mut dest_ptrs = Vec::with_capacity(ptrs.len() * (count as usize));
-            for i in 0..count {
-                dest_ptrs
-                    .extend(ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_ptrs.len(), dest_ptrs.capacity());
-            dest_ptrs_box = Some(dest_ptrs.into_boxed_slice());
+        let mut ptrs_box: Box<[_]> = Box::new([]);
+        if range.size >= ptr_size {
+            let adjusted_end = Size::from_bytes(range.end().bytes() - (ptr_size.bytes() - 1));
+            let ptrs = self.ptrs.range(range.start..adjusted_end);
+            ptrs_box = ptrs.iter().map(|&(offset, reloc)| (shift_offset(offset), reloc)).collect();
         };
 
         // # Byte-sized provenances
         // This includes the existing bytewise provenance in the range, and ptr provenance
        // that overlaps with the begin/end of the range.
-        let mut dest_bytes_box = None;
-        let begin_overlap = self.range_ptrs_get(alloc_range(src.start, Size::ZERO), cx).first();
-        let end_overlap = self.range_ptrs_get(alloc_range(src.end(), Size::ZERO), cx).first();
+        let mut bytes_box: Box<[_]> = Box::new([]);
+        let begin_overlap = self.range_ptrs_get(alloc_range(range.start, Size::ZERO), cx).first();
+        let end_overlap = self.range_ptrs_get(alloc_range(range.end(), Size::ZERO), cx).first();
         // We only need to go here if there is some overlap or some bytewise provenance.
         if begin_overlap.is_some() || end_overlap.is_some() || self.bytes.is_some() {
             let mut bytes: Vec<(Size, (Prov, u8))> = Vec::new();
             // First, if there is a part of a pointer at the start, add that.
             if let Some(entry) = begin_overlap {
                 trace!("start overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't run off the end of the `src` range.
-                let entry_end = cmp::min(entry.0 + ptr_size, src.end());
-                for offset in src.start..entry_end {
-                    bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                // For really small copies, make sure we don't run off the end of the range.
+                let entry_end = cmp::min(entry.0 + ptr_size, range.end());
+                for offset in range.start..entry_end {
+                    bytes.push((shift_offset(offset), (entry.1, (offset - entry.0).bytes() as u8)));
                 }
             } else {
                 trace!("no start overlapping entry");
             }
 
             // Then the main part, bytewise provenance from `self.bytes`.
-            bytes.extend(self.range_bytes_get(src));
+            bytes.extend(
+                self.range_bytes_get(range)
+                    .iter()
+                    .map(|&(offset, reloc)| (shift_offset(offset), reloc)),
+            );
 
             // And finally possibly parts of a pointer at the end.
             if let Some(entry) = end_overlap {
                 trace!("end overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't start before `src` does.
-                let entry_start = cmp::max(entry.0, src.start);
-                for offset in entry_start..src.end() {
+                // For really small copies, make sure we don't start before `range` does.
+                let entry_start = cmp::max(entry.0, range.start);
+                for offset in entry_start..range.end() {
                     if bytes.last().is_none_or(|bytes_entry| bytes_entry.0 < offset) {
                         // The last entry, if it exists, has a lower offset than us, so we
                         // can add it at the end and remain sorted.
-                        bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                        bytes.push((
+                            shift_offset(offset),
+                            (entry.1, (offset - entry.0).bytes() as u8),
+                        ));
                     } else {
                         // There already is an entry for this offset in there! This can happen when the
                         // start and end range checks actually end up hitting the same pointer, so we
                         // already added this in the "pointer at the start" part above.
-                        assert!(entry.0 <= src.start);
+                        assert!(entry.0 <= range.start);
                     }
                 }
             } else {
@@ -372,33 +360,40 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
             if !bytes.is_empty() && !Prov::OFFSET_IS_ADDR {
                 // FIXME(#146291): We need to ensure that we don't mix different pointers with
                 // the same provenance.
-                return Err(AllocError::ReadPartialPointer(src.start));
+                return Err(AllocError::ReadPartialPointer(range.start));
             }
 
             // And again a buffer for the new list on the target side.
-            let mut dest_bytes = Vec::with_capacity(bytes.len() * (count as usize));
-            for i in 0..count {
-                dest_bytes
-                    .extend(bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_bytes.len(), dest_bytes.capacity());
-            dest_bytes_box = Some(dest_bytes.into_boxed_slice());
+            bytes_box = bytes.into_boxed_slice();
         }
 
-        Ok(ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box })
+        Ok(ProvenanceCopy { ptrs: ptrs_box, bytes: bytes_box })
     }
 
     /// Applies a provenance copy.
     /// The affected range, as defined in the parameters to `prepare_copy` is expected
     /// to be clear of provenance.
-    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        if let Some(dest_ptrs) = copy.dest_ptrs {
-            self.ptrs.insert_presorted(dest_ptrs.into());
+    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>, range: AllocRange, repeat: u64) {
+        let shift_offset = |idx: u64, offset: Size| offset + range.start + idx * range.size;
+        if !copy.ptrs.is_empty() {
+            // We want to call `insert_presorted` only once so that, if possible, the entries
+            // after the range we insert are moved back only once.
+            let chunk_len = copy.ptrs.len() as u64;
+            self.ptrs.insert_presorted((0..chunk_len * repeat).map(|i| {
+                let chunk = i / chunk_len;
+                let (offset, reloc) = copy.ptrs[(i % chunk_len) as usize];
+                (shift_offset(chunk, offset), reloc)
+            }));
         }
-        if let Some(dest_bytes) = copy.dest_bytes
-            && !dest_bytes.is_empty()
-        {
-            self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
+        if !copy.bytes.is_empty() {
+            let chunk_len = copy.bytes.len() as u64;
+            self.bytes.get_or_insert_with(Box::default).insert_presorted(
+                (0..chunk_len * repeat).map(|i| {
+                    let chunk = i / chunk_len;
+                    let (offset, reloc) = copy.bytes[(i % chunk_len) as usize];
+                    (shift_offset(chunk, offset), reloc)
                }),
            );
         }
     }
 }
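The heart of the fix is in the new `apply_copy`: the prepared chunk stays small and range-relative, and the `repeat` copies are synthesized on the fly by an iterator that is fed directly to `insert_presorted`, so no buffer of length `chunk_len * repeat` is ever materialized on the preparation side. A toy version of that index arithmetic, with u64 offsets and a char standing in for `Prov` (collected into a `Vec` here only so the produced sequence can be inspected; the real code hands the iterator straight to `insert_presorted`):

fn main() {
    // One prepared chunk: offsets relative to the copied range, already sorted.
    let chunk: Vec<(u64, char)> = vec![(0, 'a'), (8, 'b')];
    let (range_start, range_size, repeat) = (16u64, 16u64, 3u64);

    let chunk_len = chunk.len() as u64;
    let shift = |idx: u64, offset: u64| offset + range_start + idx * range_size;

    // Lazily produce chunk_len * repeat entries, still sorted by offset.
    let expanded: Vec<(u64, char)> = (0..chunk_len * repeat)
        .map(|i| {
            let (offset, prov) = chunk[(i % chunk_len) as usize];
            (shift(i / chunk_len, offset), prov)
        })
        .collect();

    assert_eq!(
        expanded,
        vec![(16, 'a'), (24, 'b'), (32, 'a'), (40, 'b'), (48, 'a'), (56, 'b')]
    );
}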
