mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Optimized BTreeMap::merge using CursorMut
This commit is contained in:
@@ -33,36 +33,6 @@ pub(super) fn append_from_sorted_iters<I, A: Allocator + Clone>(
|
||||
self.bulk_push(iter, length, alloc)
|
||||
}
|
||||
|
||||
/// Merges all key-value pairs from the union of two ascending iterators,
|
||||
/// incrementing a `length` variable along the way. The latter makes it
|
||||
/// easier for the caller to avoid a leak when a drop handler panicks.
|
||||
///
|
||||
/// If both iterators produce the same key, this method constructs a pair using the
|
||||
/// key from the left iterator and calls on a closure `f` to return a value given
|
||||
/// the conflicting key and value from left and right iterators.
|
||||
///
|
||||
/// If you want the tree to end up in a strictly ascending order, like for
|
||||
/// a `BTreeMap`, both iterators should produce keys in strictly ascending
|
||||
/// order, each greater than all keys in the tree, including any keys
|
||||
/// already in the tree upon entry.
|
||||
pub(super) fn merge_from_sorted_iters_with<I, A: Allocator + Clone>(
|
||||
&mut self,
|
||||
left: I,
|
||||
right: I,
|
||||
length: &mut usize,
|
||||
alloc: A,
|
||||
f: impl FnMut(&K, V, V) -> V,
|
||||
) where
|
||||
K: Ord,
|
||||
I: Iterator<Item = (K, V)> + FusedIterator,
|
||||
{
|
||||
// We prepare to merge `left` and `right` into a sorted sequence in linear time.
|
||||
let iter = MergeIterWith { inner: MergeIterInner::new(left, right), f };
|
||||
|
||||
// Meanwhile, we build a tree from the sorted sequence in linear time.
|
||||
self.bulk_push(iter, length, alloc)
|
||||
}
|
||||
|
||||
/// Pushes all key-value pairs to the end of the tree, incrementing a
|
||||
/// `length` variable along the way. The latter makes it easier for the
|
||||
/// caller to avoid a leak when the iterator panicks.
|
||||
@@ -145,33 +115,3 @@ fn next(&mut self) -> Option<(K, V)> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator for merging two sorted sequences into one with
|
||||
/// a callback function to return a value on conflicting keys
|
||||
struct MergeIterWith<F, K, V, I: Iterator<Item = (K, V)>> {
|
||||
inner: MergeIterInner<I>,
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<F, K: Ord, V, I> Iterator for MergeIterWith<F, K, V, I>
|
||||
where
|
||||
F: FnMut(&K, V, V) -> V,
|
||||
I: Iterator<Item = (K, V)> + FusedIterator,
|
||||
{
|
||||
type Item = (K, V);
|
||||
|
||||
/// If two keys are equal, returns the key from the left and uses `f` to return
|
||||
/// a value given the conflicting key and values from left and right
|
||||
fn next(&mut self) -> Option<(K, V)> {
|
||||
let (a_next, b_next) = self.inner.nexts(|a: &(K, V), b: &(K, V)| K::cmp(&a.0, &b.0));
|
||||
match (a_next, b_next) {
|
||||
(Some((a_k, a_v)), Some((_, b_v))) => Some({
|
||||
let next_val = (self.f)(&a_k, a_v, b_v);
|
||||
(a_k, next_val)
|
||||
}),
|
||||
(Some(a), None) => Some(a),
|
||||
(None, Some(b)) => Some(b),
|
||||
(None, None) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1287,7 +1287,7 @@ pub fn append(&mut self, other: &mut Self)
|
||||
/// assert_eq!(a[&5], "f");
|
||||
/// ```
|
||||
#[unstable(feature = "btree_merge", issue = "152152")]
|
||||
pub fn merge(&mut self, mut other: Self, conflict: impl FnMut(&K, V, V) -> V)
|
||||
pub fn merge(&mut self, mut other: Self, mut conflict: impl FnMut(&K, V, V) -> V)
|
||||
where
|
||||
K: Ord,
|
||||
A: Clone,
|
||||
@@ -1303,16 +1303,75 @@ pub fn merge(&mut self, mut other: Self, conflict: impl FnMut(&K, V, V) -> V)
|
||||
return;
|
||||
}
|
||||
|
||||
let self_iter = mem::replace(self, Self::new_in((*self.alloc).clone())).into_iter();
|
||||
let other_iter = mem::replace(&mut other, Self::new_in((*self.alloc).clone())).into_iter();
|
||||
let root = self.root.get_or_insert_with(|| Root::new((*self.alloc).clone()));
|
||||
root.merge_from_sorted_iters_with(
|
||||
self_iter,
|
||||
other_iter,
|
||||
&mut self.length,
|
||||
(*self.alloc).clone(),
|
||||
conflict,
|
||||
)
|
||||
let mut other_iter = other.into_iter();
|
||||
let (first_other_key, first_other_val) = other_iter.next().unwrap();
|
||||
|
||||
// find the first gap that has the smallest key greater than or equal to
|
||||
// the first key from other
|
||||
let mut self_cursor = self.lower_bound_mut(Bound::Included(&first_other_key));
|
||||
|
||||
if let Some((self_key, _)) = self_cursor.peek_next() {
|
||||
match K::cmp(&first_other_key, self_key) {
|
||||
Ordering::Equal => {
|
||||
self_cursor.with_next(|self_key, self_val| {
|
||||
conflict(self_key, self_val, first_other_val)
|
||||
});
|
||||
}
|
||||
Ordering::Less =>
|
||||
// SAFETY: we know our other_key's ordering is less than self_key,
|
||||
// so inserting before will guarantee sorted order
|
||||
unsafe {
|
||||
self_cursor.insert_before_unchecked(first_other_key, first_other_val);
|
||||
},
|
||||
Ordering::Greater => {
|
||||
unreachable!("Cursor's peek_next should return None.");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// SAFETY: reaching here means our cursor is at the end
|
||||
// self BTreeMap so we just insert other_key here
|
||||
unsafe {
|
||||
self_cursor.insert_before_unchecked(first_other_key, first_other_val);
|
||||
}
|
||||
}
|
||||
|
||||
for (other_key, other_val) in other_iter {
|
||||
loop {
|
||||
if let Some((self_key, _)) = self_cursor.peek_next() {
|
||||
match K::cmp(&other_key, self_key) {
|
||||
Ordering::Equal => {
|
||||
self_cursor.with_next(|self_key, self_val| {
|
||||
conflict(self_key, self_val, other_val)
|
||||
});
|
||||
break;
|
||||
}
|
||||
Ordering::Less => {
|
||||
// SAFETY: we know our other_key's ordering is less than self_key,
|
||||
// so inserting before will guarantee sorted order
|
||||
unsafe {
|
||||
self_cursor.insert_before_unchecked(other_key, other_val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ordering::Greater => {
|
||||
// FIXME: instead of doing a linear search here,
|
||||
// this can be optimized to search the tree by starting
|
||||
// from self_cursor and going towards the root and then
|
||||
// back down to the proper node -- that should probably
|
||||
// be a new method on Cursor*.
|
||||
self_cursor.next();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// SAFETY: reaching here means our cursor is at the end
|
||||
// self BTreeMap so we just insert other_key here
|
||||
unsafe {
|
||||
self_cursor.insert_before_unchecked(other_key, other_val);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs a double-ended iterator over a sub-range of elements in the map.
|
||||
@@ -3337,6 +3396,37 @@ pub fn as_cursor(&self) -> Cursor<'_, K, V> {
|
||||
|
||||
// Now the tree editing operations
|
||||
impl<'a, K: Ord, V, A: Allocator + Clone> CursorMutKey<'a, K, V, A> {
|
||||
/// Calls a function with ownership of the next element's key and
|
||||
/// and value and expects it to return a value to write
|
||||
/// back to the next element's key and value. The cursor is not
|
||||
/// advanced forward.
|
||||
///
|
||||
/// If the cursor is at the end of the map then the function is not called
|
||||
/// and this essentially does not do anything.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// You must ensure that the `BTreeMap` invariants are maintained.
|
||||
/// Specifically:
|
||||
///
|
||||
/// * The next element's key must be unique in the tree.
|
||||
/// * All keys in the tree must remain in sorted order.
|
||||
#[allow(dead_code)] /* This function exists for consistency with CursorMut */
|
||||
pub(super) fn with_next(&mut self, f: impl FnOnce(K, V) -> (K, V)) {
|
||||
// if `f` unwinds, the next entry is already removed leaving
|
||||
// the tree in valid state.
|
||||
// FIXME: Once `MaybeDangling` is implemented, we can optimize
|
||||
// this through using a drop handler and transmutating CursorMutKey<K, V>
|
||||
// to CursorMutKey<ManuallyDrop<K>, ManuallyDrop<V>> (see PR #152418)
|
||||
if let Some((k, v)) = self.remove_next() {
|
||||
// SAFETY: we remove the K, V out of the next entry,
|
||||
// apply 'f' to get a new (K, V), and insert it back
|
||||
// into the next entry that the cursor is pointing at
|
||||
let (k, v) = f(k, v);
|
||||
unsafe { self.insert_after_unchecked(k, v) };
|
||||
}
|
||||
}
|
||||
|
||||
/// Inserts a new key-value pair into the map in the gap that the
|
||||
/// cursor is currently pointing to.
|
||||
///
|
||||
@@ -3542,6 +3632,29 @@ pub fn remove_prev(&mut self) -> Option<(K, V)> {
|
||||
}
|
||||
|
||||
impl<'a, K: Ord, V, A: Allocator + Clone> CursorMut<'a, K, V, A> {
|
||||
/// Calls a function with a reference to the next element's key and
|
||||
/// ownership of its value. The function is expected to return a value
|
||||
/// to write back to the next element's value. The cursor is not
|
||||
/// advanced forward.
|
||||
///
|
||||
/// If the cursor is at the end of the map then the function is not called
|
||||
/// and this essentially does not do anything.
|
||||
pub(super) fn with_next(&mut self, f: impl FnOnce(&K, V) -> V) {
|
||||
// FIXME: This can be optimized to not do all the removing/reinserting
|
||||
// logic by using ptr::read, calling `f`, and then using ptr::write.
|
||||
// if `f` unwinds, then we need to remove the entry while being careful to
|
||||
// not cause UB by moving or dropping the already-dropped `V`
|
||||
// for the entry. Some implementation ideas:
|
||||
// https://github.com/rust-lang/rust/pull/152418#discussion_r2800232576
|
||||
if let Some((k, v)) = self.remove_next() {
|
||||
// SAFETY: we remove the K, V out of the next entry,
|
||||
// apply 'f' to get a new V, and insert (K, V) back
|
||||
// into the next entry that the cursor is pointing at
|
||||
let v = f(&k, v);
|
||||
unsafe { self.insert_after_unchecked(k, v) };
|
||||
}
|
||||
}
|
||||
|
||||
/// Inserts a new key-value pair into the map in the gap that the
|
||||
/// cursor is currently pointing to.
|
||||
///
|
||||
|
||||
@@ -2128,6 +2128,16 @@ fn $name() {
|
||||
#[cfg(not(miri))] // Miri is too slow
|
||||
create_append_test!(test_append_1700, 1700);
|
||||
|
||||
// a inserts (0, 0)..(8, 8) to its own tree
|
||||
// b inserts (5, 5 * 2)..($len, 2 * $len) to its own tree
|
||||
// note that between a and b, there are duplicate keys
|
||||
// between 5..min($len, 8), so on merge we add the values
|
||||
// of these keys together
|
||||
// we check that:
|
||||
// - the merged tree 'a' has a length of max(8, $len)
|
||||
// - all keys in 'a' have the correct value associated
|
||||
// - removing and inserting an element into the merged
|
||||
// tree 'a' still keeps it in valid tree form
|
||||
macro_rules! create_merge_test {
|
||||
($name:ident, $len:expr) => {
|
||||
#[test]
|
||||
@@ -2239,6 +2249,84 @@ fn test_append_ord_chaos() {
|
||||
map2.check();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
|
||||
fn test_merge_drop_leak() {
|
||||
let a = CrashTestDummy::new(0);
|
||||
let b = CrashTestDummy::new(1);
|
||||
let c = CrashTestDummy::new(2);
|
||||
let mut left = BTreeMap::new();
|
||||
let mut right = BTreeMap::new();
|
||||
left.insert(a.spawn(Panic::Never), ());
|
||||
left.insert(b.spawn(Panic::Never), ());
|
||||
left.insert(c.spawn(Panic::Never), ());
|
||||
right.insert(b.spawn(Panic::InDrop), ()); // first duplicate key, dropped during merge
|
||||
right.insert(c.spawn(Panic::Never), ());
|
||||
|
||||
catch_unwind(move || left.merge(right, |_, _, _| ())).unwrap_err();
|
||||
assert_eq!(a.dropped(), 1); // this should not be dropped
|
||||
assert_eq!(b.dropped(), 2); // key is dropped on panic
|
||||
assert_eq!(c.dropped(), 2); // key is dropped on panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
|
||||
fn test_merge_conflict_drop_leak() {
|
||||
let a = CrashTestDummy::new(0);
|
||||
let a_val_left = CrashTestDummy::new(0);
|
||||
|
||||
let b = CrashTestDummy::new(1);
|
||||
let b_val_left = CrashTestDummy::new(1);
|
||||
let b_val_right = CrashTestDummy::new(1);
|
||||
|
||||
let c = CrashTestDummy::new(2);
|
||||
let c_val_left = CrashTestDummy::new(2);
|
||||
let c_val_right = CrashTestDummy::new(2);
|
||||
|
||||
let mut left = BTreeMap::new();
|
||||
let mut right = BTreeMap::new();
|
||||
|
||||
left.insert(a.spawn(Panic::Never), a_val_left.spawn(Panic::Never));
|
||||
left.insert(b.spawn(Panic::Never), b_val_left.spawn(Panic::Never));
|
||||
left.insert(c.spawn(Panic::Never), c_val_left.spawn(Panic::Never));
|
||||
right.insert(b.spawn(Panic::Never), b_val_right.spawn(Panic::Never));
|
||||
right.insert(c.spawn(Panic::Never), c_val_right.spawn(Panic::Never));
|
||||
|
||||
// First key that conflicts should
|
||||
catch_unwind(move || {
|
||||
left.merge(right, |_, _, _| panic!("Panic in conflict function"));
|
||||
assert_eq!(left.len(), 1); // only 1 entry should be left
|
||||
})
|
||||
.unwrap_err();
|
||||
assert_eq!(a.dropped(), 1); // should not panic
|
||||
assert_eq!(a_val_left.dropped(), 1); // should not panic
|
||||
assert_eq!(b.dropped(), 2); // should drop from panic (conflict)
|
||||
assert_eq!(b_val_left.dropped(), 1); // should be 2 were it not for Rust issue #47949
|
||||
assert_eq!(b_val_right.dropped(), 1); // should be 2 were it not for Rust issue #47949
|
||||
assert_eq!(c.dropped(), 2); // should drop from panic (conflict)
|
||||
assert_eq!(c_val_left.dropped(), 1); // should be 2 were it not for Rust issue #47949
|
||||
assert_eq!(c_val_right.dropped(), 1); // should be 2 were it not for Rust issue #47949
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_ord_chaos() {
|
||||
let mut map1 = BTreeMap::new();
|
||||
map1.insert(Cyclic3::A, ());
|
||||
map1.insert(Cyclic3::B, ());
|
||||
let mut map2 = BTreeMap::new();
|
||||
map2.insert(Cyclic3::A, ());
|
||||
map2.insert(Cyclic3::B, ());
|
||||
map2.insert(Cyclic3::C, ()); // lands first, before A
|
||||
map2.insert(Cyclic3::B, ()); // lands first, before C
|
||||
map1.check();
|
||||
map2.check(); // keys are not unique but still strictly ascending
|
||||
assert_eq!(map1.len(), 2);
|
||||
assert_eq!(map2.len(), 4);
|
||||
map1.merge(map2, |_, _, _| ());
|
||||
assert_eq!(map1.len(), 5);
|
||||
map1.check();
|
||||
}
|
||||
|
||||
fn rand_data(len: usize) -> Vec<(u32, u32)> {
|
||||
let mut rng = DeterministicRng::new();
|
||||
Vec::from_iter((0..len).map(|_| (rng.next(), rng.next())))
|
||||
@@ -2695,9 +2783,15 @@ fn test_id_based_merge() {
|
||||
rhs.insert(IdBased { id: 0, name: "rhs_k".to_string() }, "2".to_string());
|
||||
|
||||
lhs.merge(rhs, |_, mut lhs_val, rhs_val| {
|
||||
// confirming that lhs_val comes from lhs tree,
|
||||
// rhs_val comes from rhs tree
|
||||
assert_eq!(lhs_val, String::from("1"));
|
||||
assert_eq!(rhs_val, String::from("2"));
|
||||
lhs_val.push_str(&rhs_val);
|
||||
lhs_val
|
||||
});
|
||||
|
||||
assert_eq!(lhs.pop_first().unwrap().0.name, "lhs_k".to_string());
|
||||
let merged_kv_pair = lhs.pop_first().unwrap();
|
||||
assert_eq!(merged_kv_pair.0.id, 0);
|
||||
assert_eq!(merged_kv_pair.0.name, "lhs_k".to_string());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user