Auto merge of #78681 - m-ou-se:binary-heap-retain, r=Amanieu

Improve rebuilding behaviour of BinaryHeap::retain.

This changes `BinaryHeap::retain` such that it doesn't always fully rebuild the heap, but only rebuilds the parts for which that's necessary.

This makes use of the fact that retain gives out `&T`s and not `&mut T`s.

Retaining every element or removing only elements at the end results in no rebuilding at all. Retaining most elements results in only reordering the elements that got moved (those after the first removed element), using the same logic as was already used for `append`.

cc `@KodrAus` `@sfackler` - We briefly discussed this possibility in the meeting last week while we talked about stabilization of this function (#71503).
This commit is contained in:
bors
2021-04-23 00:07:19 +00:00
2 changed files with 69 additions and 35 deletions
+53 -32
View File
@@ -652,6 +652,43 @@ unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
unsafe { self.sift_up(start, pos) };
}
/// Restores the heap invariant when only `data[start..]` may violate it.
///
/// The prefix `data[0..start]` must already be a proper heap. Depending on how
/// large the tail is relative to that prefix, this either rebuilds the whole
/// heap or sifts each tail element up individually.
fn rebuild_tail(&mut self, start: usize) {
    let len = self.len();
    if start == len {
        // Empty tail: nothing can be out of place.
        return;
    }

    let tail_len = len - start;

    #[inline(always)]
    fn log2_fast(x: usize) -> usize {
        (usize::BITS - x.leading_zeros() - 1) as usize
    }

    // Cost model:
    // - `rebuild` is O(len) and needs about 2 * len comparisons in the worst case.
    // - Repeated `sift_up` is O(tail_len * log(start)) and needs about
    //   1 * tail_len * log_2(start) comparisons in the worst case, which only
    //   holds while start >= tail_len.
    // For heaps above 2048 elements the crossover no longer follows this
    // reasoning; the constant 11 was determined empirically.
    let better_to_rebuild = start < tail_len || {
        // Here start >= tail_len >= 1, so `log2_fast(start)` cannot underflow.
        let per_element_cost = if len <= 2048 { log2_fast(start) } else { 11 };
        2 * len < tail_len * per_element_cost
    };

    if better_to_rebuild {
        self.rebuild();
        return;
    }

    for pos in start..len {
        // SAFETY: `pos` comes from `start..len`, so it is always less than self.len().
        unsafe { self.sift_up(0, pos) };
    }
}
fn rebuild(&mut self) {
let mut n = self.len() / 2;
while n > 0 {
@@ -689,37 +726,11 @@ pub fn append(&mut self, other: &mut Self) {
swap(self, other);
}
if other.is_empty() {
return;
}
let start = self.data.len();
#[inline(always)]
fn log2_fast(x: usize) -> usize {
(usize::BITS - x.leading_zeros() - 1) as usize
}
self.data.append(&mut other.data);
// `rebuild` takes O(len1 + len2) operations
// and about 2 * (len1 + len2) comparisons in the worst case
// while `extend` takes O(len2 * log(len1)) operations
// and about 1 * len2 * log_2(len1) comparisons in the worst case,
// assuming len1 >= len2. For larger heaps, the crossover point
// no longer follows this reasoning and was determined empirically.
#[inline]
fn better_to_rebuild(len1: usize, len2: usize) -> bool {
let tot_len = len1 + len2;
if tot_len <= 2048 {
2 * tot_len < len2 * log2_fast(len1)
} else {
2 * tot_len < len2 * 11
}
}
if better_to_rebuild(self.len(), other.len()) {
self.data.append(&mut other.data);
self.rebuild();
} else {
self.extend(other.drain());
}
self.rebuild_tail(start);
}
/// Returns an iterator which retrieves elements in heap order.
@@ -770,12 +781,22 @@ pub fn drain_sorted(&mut self) -> DrainSorted<'_, T> {
/// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4])
/// ```
#[unstable(feature = "binary_heap_retain", issue = "71503")]
// NOTE(review): this listing is a diff whose +/- markers are not visible. The
// signature directly below, and the `self.data.retain(f); self.rebuild();`
// pair inside the body, look like the pre-change lines that the newer lines
// after them replace — confirm against the actual file.
pub fn retain<F>(&mut self, f: F)
// `f` is bound as `mut` because it is called from inside the wrapping closure below.
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> bool,
{
self.data.retain(f);
self.rebuild();
// Index of the first element rejected by `f`. It starts at the current length,
// so it stays equal to the length of the untouched prefix if nothing is removed
// before it.
let mut first_removed = self.len();
// Counts the elements handed to the closure so far, i.e. the pre-removal index
// of the element being examined (assumes `Vec::retain` visits elements once
// each, in order).
let mut i = 0;
self.data.retain(|e| {
let keep = f(e);
// Record only the earliest rejected index; later rejections do not lower it.
if !keep && i < first_removed {
first_removed = i;
}
i += 1;
keep
});
// data[0..first_removed] is untouched, so we only need to rebuild the tail:
self.rebuild_tail(first_removed);
}
}
+16 -3
View File
@@ -386,10 +386,23 @@ fn drain<'new>(d: Drain<'static, &'static str>) -> Drain<'new, &'new str> {
#[test]
fn test_retain() {
// NOTE(review): this listing is a diff whose +/- markers are not visible. The
// next two lines and the `into_sorted_vec` assertion further down look like the
// pre-change test body that the remaining lines replace — confirm against the
// actual file.
let mut a = BinaryHeap::from(vec![-10, -5, 1, 2, 4, 13]);
a.retain(|x| x % 2 == 0);
let mut a = BinaryHeap::from(vec![100, 10, 50, 1, 2, 20, 30]);
// Remove exactly one element (2) and inspect the resulting internal layout.
a.retain(|&x| x != 2);
assert_eq!(a.into_sorted_vec(), [-10, 2, 4])
// Check that 20 moved into 10's place.
assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]);
// Retaining every element: the expected layout is the same as before the call.
a.retain(|_| true);
assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]);
// Remove every element that is not below 50 (here 100 and 50), including the
// element at the front of the vector.
a.retain(|&x| x < 50);
assert_eq!(a.clone().into_vec(), [30, 20, 10, 1]);
// Rejecting every element must leave an empty heap.
a.retain(|_| false);
assert!(a.is_empty());
}
// old binaryheap failed this test