Optimized BTreeMap::merge using CursorMut

This commit is contained in:
Mahdi Ali-Raihan
2026-02-09 21:44:34 -05:00
parent 1b50859d36
commit 24efac1063
3 changed files with 219 additions and 72 deletions
@@ -33,36 +33,6 @@ pub(super) fn append_from_sorted_iters<I, A: Allocator + Clone>(
self.bulk_push(iter, length, alloc)
}
/// Fills the tree with the sorted union of two ascending iterators,
/// incrementing the caller-owned `length` counter along the way. Keeping the
/// counter with the caller makes it easier for the caller to avoid a leak
/// when a drop handler panics.
///
/// When `left` and `right` produce the same key, a single pair is built for
/// it: the key comes from `left`, and the value is whatever the closure `f`
/// returns when given that key plus the conflicting left and right values.
///
/// For the tree to end up in strictly ascending order, as a `BTreeMap`
/// needs, both iterators should produce keys in strictly ascending order,
/// each greater than all keys in the tree, including any keys already in
/// the tree upon entry.
pub(super) fn merge_from_sorted_iters_with<I, A: Allocator + Clone>(
    &mut self,
    left: I,
    right: I,
    length: &mut usize,
    alloc: A,
    f: impl FnMut(&K, V, V) -> V,
) where
    K: Ord,
    I: Iterator<Item = (K, V)> + FusedIterator,
{
    // The adaptor merges `left` and `right` into one sorted sequence in
    // linear time; `bulk_push` builds the tree from that sequence, also in
    // linear time.
    self.bulk_push(MergeIterWith { inner: MergeIterInner::new(left, right), f }, length, alloc)
}
/// Pushes all key-value pairs to the end of the tree, incrementing a
/// `length` variable along the way. The latter makes it easier for the
/// caller to avoid a leak when the iterator panicks.
@@ -145,33 +115,3 @@ fn next(&mut self) -> Option<(K, V)> {
}
}
}
/// Iterator adaptor that merges two sorted sequences into one, deferring to
/// a callback to produce the value whenever both sequences hold the same key.
struct MergeIterWith<F, K, V, I: Iterator<Item = (K, V)>> {
    inner: MergeIterInner<I>,
    f: F,
}
impl<F, K: Ord, V, I> Iterator for MergeIterWith<F, K, V, I>
where
    F: FnMut(&K, V, V) -> V,
    I: Iterator<Item = (K, V)> + FusedIterator,
{
    type Item = (K, V);
    /// Yields the next pair of the merged sequence. When the two sides hold
    /// equal keys, the left key is kept and `f` is asked for the value, given
    /// that key and the conflicting left and right values.
    fn next(&mut self) -> Option<(K, V)> {
        let (from_left, from_right) =
            self.inner.nexts(|lhs: &(K, V), rhs: &(K, V)| K::cmp(&lhs.0, &rhs.0));
        match (from_left, from_right) {
            // Both sides produced a pair: resolve the conflict through `f`
            // and keep the key that came from the left.
            (Some((key, left_val)), Some((_, right_val))) => {
                let merged_val = (self.f)(&key, left_val, right_val);
                Some((key, merged_val))
            }
            // Right side produced nothing: pass the left result through
            // (which is `None` once both sides are exhausted).
            (only_left, None) => only_left,
            // Left side produced nothing: pass the right result through.
            (None, only_right) => only_right,
        }
    }
}
+124 -11
View File
@@ -1287,7 +1287,7 @@ pub fn append(&mut self, other: &mut Self)
/// assert_eq!(a[&5], "f");
/// ```
#[unstable(feature = "btree_merge", issue = "152152")]
pub fn merge(&mut self, mut other: Self, conflict: impl FnMut(&K, V, V) -> V)
pub fn merge(&mut self, mut other: Self, mut conflict: impl FnMut(&K, V, V) -> V)
where
K: Ord,
A: Clone,
@@ -1303,16 +1303,75 @@ pub fn merge(&mut self, mut other: Self, conflict: impl FnMut(&K, V, V) -> V)
return;
}
let self_iter = mem::replace(self, Self::new_in((*self.alloc).clone())).into_iter();
let other_iter = mem::replace(&mut other, Self::new_in((*self.alloc).clone())).into_iter();
let root = self.root.get_or_insert_with(|| Root::new((*self.alloc).clone()));
root.merge_from_sorted_iters_with(
self_iter,
other_iter,
&mut self.length,
(*self.alloc).clone(),
conflict,
)
// Walk `other` in iteration order and splice each entry into `self`
// through a single mutable cursor.
// NOTE(review): the `unwrap` below presumably relies on an emptiness check
// on `other` earlier in this function (outside this view) -- confirm.
let mut other_iter = other.into_iter();
let (first_other_key, first_other_val) = other_iter.next().unwrap();
// Position the cursor in the gap just before the smallest key in `self`
// that is greater than or equal to the first key from `other`.
let mut self_cursor = self.lower_bound_mut(Bound::Included(&first_other_key));
if let Some((self_key, _)) = self_cursor.peek_next() {
match K::cmp(&first_other_key, self_key) {
Ordering::Equal => {
// Same key on both sides: keep `self`'s key and let `conflict`
// pick the value from the two candidates.
self_cursor.with_next(|self_key, self_val| {
conflict(self_key, self_val, first_other_val)
});
}
Ordering::Less =>
// SAFETY: we know our other_key's ordering is less than self_key,
// so inserting before will guarantee sorted order
unsafe {
self_cursor.insert_before_unchecked(first_other_key, first_other_val);
},
Ordering::Greater => {
// The cursor was obtained as the lower bound of `first_other_key`,
// so the key after it is not expected to compare smaller.
unreachable!("Cursor's peek_next should return None.");
}
}
} else {
// SAFETY: reaching here means our cursor is at the end of
// self's BTreeMap, so we just insert other_key here
unsafe {
self_cursor.insert_before_unchecked(first_other_key, first_other_val);
}
}
// Merge the remaining entries of `other`. The cursor is only ever moved
// forward (`next`), so each entry resumes the scan where the previous one
// stopped.
for (other_key, other_val) in other_iter {
loop {
if let Some((self_key, _)) = self_cursor.peek_next() {
match K::cmp(&other_key, self_key) {
Ordering::Equal => {
// Same key on both sides: keep `self`'s key and let
// `conflict` pick the value from the two candidates.
self_cursor.with_next(|self_key, self_val| {
conflict(self_key, self_val, other_val)
});
break;
}
Ordering::Less => {
// SAFETY: we know our other_key's ordering is less than self_key,
// so inserting before will guarantee sorted order
unsafe {
self_cursor.insert_before_unchecked(other_key, other_val);
}
break;
}
Ordering::Greater => {
// FIXME: instead of doing a linear search here,
// this can be optimized to search the tree by starting
// from self_cursor and going towards the root and then
// back down to the proper node -- that should probably
// be a new method on Cursor*.
self_cursor.next();
}
}
} else {
// SAFETY: reaching here means our cursor is at the end of
// self's BTreeMap, so we just insert other_key here
unsafe {
self_cursor.insert_before_unchecked(other_key, other_val);
}
break;
}
}
}
}
/// Constructs a double-ended iterator over a sub-range of elements in the map.
@@ -3337,6 +3396,37 @@ pub fn as_cursor(&self) -> Cursor<'_, K, V> {
// Now the tree editing operations
impl<'a, K: Ord, V, A: Allocator + Clone> CursorMutKey<'a, K, V, A> {
/// Hands ownership of the next element's key and value to `f` and writes
/// the pair returned by `f` back as the next element. The cursor is not
/// advanced forward.
///
/// If the cursor is at the end of the map then `f` is not called and this
/// essentially does not do anything.
///
/// # Safety
///
/// You must ensure that the `BTreeMap` invariants are maintained.
/// Specifically:
///
/// * The next element's key must be unique in the tree.
/// * All keys in the tree must remain in sorted order.
#[allow(dead_code)] /* This function exists for consistency with CursorMut */
pub(super) fn with_next(&mut self, f: impl FnOnce(K, V) -> (K, V)) {
    // NOTE(review): the contract documented under `# Safety` is not backed
    // by an `unsafe fn` signature -- confirm whether that is intentional.
    //
    // The entry is taken out of the tree before `f` runs, so if `f` unwinds
    // the next entry is already removed, leaving the tree in a valid state.
    // FIXME: Once `MaybeDangling` is implemented, we can optimize
    // this through using a drop handler and transmutating CursorMutKey<K, V>
    // to CursorMutKey<ManuallyDrop<K>, ManuallyDrop<V>> (see PR #152418)
    let Some((old_key, old_val)) = self.remove_next() else {
        return;
    };
    let (new_key, new_val) = f(old_key, old_val);
    // SAFETY: the pair goes back into the gap the removed entry came from;
    // per this function's documented contract, the key returned by `f` must
    // be unique and keep the tree in sorted order.
    unsafe { self.insert_after_unchecked(new_key, new_val) };
}
/// Inserts a new key-value pair into the map in the gap that the
/// cursor is currently pointing to.
///
@@ -3542,6 +3632,29 @@ pub fn remove_prev(&mut self) -> Option<(K, V)> {
}
impl<'a, K: Ord, V, A: Allocator + Clone> CursorMut<'a, K, V, A> {
/// Lends the next element's key to `f` by reference while giving it
/// ownership of the element's value; the value returned by `f` is written
/// back as the next element's value. The cursor is not advanced forward.
///
/// If the cursor is at the end of the map then `f` is not called and this
/// essentially does not do anything.
pub(super) fn with_next(&mut self, f: impl FnOnce(&K, V) -> V) {
    // FIXME: This can be optimized to not do all the removing/reinserting
    // logic by using ptr::read, calling `f`, and then using ptr::write.
    // if `f` unwinds, then we need to remove the entry while being careful to
    // not cause UB by moving or dropping the already-dropped `V`
    // for the entry. Some implementation ideas:
    // https://github.com/rust-lang/rust/pull/152418#discussion_r2800232576
    let Some((key, old_val)) = self.remove_next() else {
        return;
    };
    // The entry is already out of the tree here, so the tree does not hold
    // a moved-from value while `f` runs.
    let new_val = f(&key, old_val);
    // SAFETY: the very same key that was just removed is put back into the
    // gap it was removed from, paired with the value produced by `f`.
    unsafe { self.insert_after_unchecked(key, new_val) };
}
/// Inserts a new key-value pair into the map in the gap that the
/// cursor is currently pointing to.
///
@@ -2128,6 +2128,16 @@ fn $name() {
#[cfg(not(miri))] // Miri is too slow
create_append_test!(test_append_1700, 1700);
// a inserts (0, 0)..(8, 8) to its own tree
// b inserts (5, 5 * 2)..($len, 2 * $len) to its own tree
// note that between a and b, there are duplicate keys
// between 5..min($len, 8), so on merge we add the values
// of these keys together
// we check that:
// - the merged tree 'a' has a length of max(8, $len)
// - all keys in 'a' have the correct value associated
// - removing and inserting an element into the merged
// tree 'a' still keeps it in valid tree form
macro_rules! create_merge_test {
($name:ident, $len:expr) => {
#[test]
@@ -2239,6 +2249,84 @@ fn test_append_ord_chaos() {
map2.check();
}
// Checks that `merge` does not leak keys when dropping one of the keys
// coming from `right` panics: every spawned key instance must end up
// dropped exactly once.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
fn test_merge_drop_leak() {
let a = CrashTestDummy::new(0);
let b = CrashTestDummy::new(1);
let c = CrashTestDummy::new(2);
let mut left = BTreeMap::new();
let mut right = BTreeMap::new();
left.insert(a.spawn(Panic::Never), ());
left.insert(b.spawn(Panic::Never), ());
left.insert(c.spawn(Panic::Never), ());
right.insert(b.spawn(Panic::InDrop), ()); // first duplicate key, dropped during merge
right.insert(c.spawn(Panic::Never), ());
// Both maps are moved into the closure, so nothing spawned above outlives it.
catch_unwind(move || left.merge(right, |_, _, _| ())).unwrap_err();
assert_eq!(a.dropped(), 1); // spawned once (only in `left`), dropped exactly once
assert_eq!(b.dropped(), 2); // spawned twice (left and right), both dropped despite the panic
assert_eq!(c.dropped(), 2); // spawned twice (left and right), both dropped despite the panic
}
// Checks which keys and values get dropped when the `conflict` closure
// passed to `merge` panics on a key that is present in both maps.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
fn test_merge_conflict_drop_leak() {
let a = CrashTestDummy::new(0);
let a_val_left = CrashTestDummy::new(0);
let b = CrashTestDummy::new(1);
let b_val_left = CrashTestDummy::new(1);
let b_val_right = CrashTestDummy::new(1);
let c = CrashTestDummy::new(2);
let c_val_left = CrashTestDummy::new(2);
let c_val_right = CrashTestDummy::new(2);
let mut left = BTreeMap::new();
let mut right = BTreeMap::new();
left.insert(a.spawn(Panic::Never), a_val_left.spawn(Panic::Never));
left.insert(b.spawn(Panic::Never), b_val_left.spawn(Panic::Never));
left.insert(c.spawn(Panic::Never), c_val_left.spawn(Panic::Never));
right.insert(b.spawn(Panic::Never), b_val_right.spawn(Panic::Never));
right.insert(c.spawn(Panic::Never), c_val_right.spawn(Panic::Never));
// The first key that conflicts makes the `conflict` closure panic, so
// `merge` unwinds out of the closure below.
catch_unwind(move || {
left.merge(right, |_, _, _| panic!("Panic in conflict function"));
// NOTE(review): this assertion is not reached when `merge` panics, and
// `unwrap_err` below requires that panic -- confirm whether it is meant
// to run (and whether the expected length is right).
assert_eq!(left.len(), 1); // only 1 entry should be left
})
.unwrap_err();
assert_eq!(a.dropped(), 1); // spawned once (only in `left`), dropped exactly once
assert_eq!(a_val_left.dropped(), 1); // spawned once (only in `left`), dropped exactly once
assert_eq!(b.dropped(), 2); // should drop from panic (conflict)
assert_eq!(b_val_left.dropped(), 1); // should be 2 were it not for Rust issue #47949
assert_eq!(b_val_right.dropped(), 1); // should be 2 were it not for Rust issue #47949
assert_eq!(c.dropped(), 2); // should drop from panic (conflict)
assert_eq!(c_val_left.dropped(), 1); // should be 2 were it not for Rust issue #47949
assert_eq!(c_val_right.dropped(), 1); // should be 2 were it not for Rust issue #47949
}
// Merges two maps keyed by `Cyclic3` and checks that the result holds five
// entries and still passes `check()`.
// NOTE(review): `Cyclic3` is defined outside this view; the inline comments
// below imply its `Ord` is deliberately not a consistent total order -- confirm.
#[test]
fn test_merge_ord_chaos() {
let mut map1 = BTreeMap::new();
map1.insert(Cyclic3::A, ());
map1.insert(Cyclic3::B, ());
let mut map2 = BTreeMap::new();
map2.insert(Cyclic3::A, ());
map2.insert(Cyclic3::B, ());
map2.insert(Cyclic3::C, ()); // lands first, before A
map2.insert(Cyclic3::B, ()); // lands first, before C
map1.check();
map2.check(); // keys are not unique but still strictly ascending
assert_eq!(map1.len(), 2);
assert_eq!(map2.len(), 4);
// The conflict closure only has `()` values to combine.
map1.merge(map2, |_, _, _| ());
assert_eq!(map1.len(), 5);
map1.check();
}
fn rand_data(len: usize) -> Vec<(u32, u32)> {
let mut rng = DeterministicRng::new();
Vec::from_iter((0..len).map(|_| (rng.next(), rng.next())))
@@ -2695,9 +2783,15 @@ fn test_id_based_merge() {
rhs.insert(IdBased { id: 0, name: "rhs_k".to_string() }, "2".to_string());
lhs.merge(rhs, |_, mut lhs_val, rhs_val| {
// confirming that lhs_val comes from lhs tree,
// rhs_val comes from rhs tree
assert_eq!(lhs_val, String::from("1"));
assert_eq!(rhs_val, String::from("2"));
lhs_val.push_str(&rhs_val);
lhs_val
});
assert_eq!(lhs.pop_first().unwrap().0.name, "lhs_k".to_string());
let merged_kv_pair = lhs.pop_first().unwrap();
assert_eq!(merged_kv_pair.0.id, 0);
assert_eq!(merged_kv_pair.0.name, "lhs_k".to_string());
}