mirror of
https://github.com/rust-lang/rust.git
synced 2026-04-27 18:57:42 +03:00
Rollup merge of #153776 - zetanumbers:curry-howard-dyn-thread-safe, r=JonathanBrouwer
Remove redundant `is_dyn_thread_safe` checks Refactor uses of `FromDyn` to reduce number of redundant `is_dyn_thread_safe` checks by replacing `FromDyn::from` with `check_dyn_thread_safe` in tandem with existing `FromDyn::derive` so that the users would avoid redundancy in the future. PR is split up into multiple commits for an easier review.
This commit is contained in:
@@ -188,53 +188,6 @@ pub fn assert_dyn_send<T: ?Sized + PointeeSized + DynSend>() {}
|
||||
pub fn assert_dyn_send_val<T: ?Sized + PointeeSized + DynSend>(_t: &T) {}
|
||||
pub fn assert_dyn_send_sync_val<T: ?Sized + PointeeSized + DynSync + DynSend>(_t: &T) {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct FromDyn<T>(T);
|
||||
|
||||
impl<T> FromDyn<T> {
|
||||
#[inline(always)]
|
||||
pub fn from(val: T) -> Self {
|
||||
// Check that `sync::is_dyn_thread_safe()` is true on creation so we can
|
||||
// implement `Send` and `Sync` for this structure when `T`
|
||||
// implements `DynSend` and `DynSync` respectively.
|
||||
assert!(crate::sync::is_dyn_thread_safe());
|
||||
FromDyn(val)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn derive<O>(&self, val: O) -> FromDyn<O> {
|
||||
// We already did the check for `sync::is_dyn_thread_safe()` when creating `Self`
|
||||
FromDyn(val)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// `FromDyn` is `Send` if `T` is `DynSend`, since it ensures that sync::is_dyn_thread_safe() is true.
|
||||
unsafe impl<T: DynSend> Send for FromDyn<T> {}
|
||||
|
||||
// `FromDyn` is `Sync` if `T` is `DynSync`, since it ensures that sync::is_dyn_thread_safe() is true.
|
||||
unsafe impl<T: DynSync> Sync for FromDyn<T> {}
|
||||
|
||||
impl<T> std::ops::Deref for FromDyn<T> {
|
||||
type Target = T;
|
||||
|
||||
#[inline(always)]
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for FromDyn<T> {
|
||||
#[inline(always)]
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
// A wrapper to convert a struct that is already a `Send` or `Sync` into
|
||||
// an instance of `DynSend` and `DynSync`, since the compiler cannot infer
|
||||
// it automatically in some cases. (e.g. Box<dyn Send / Sync>)
|
||||
|
||||
@@ -34,7 +34,9 @@
|
||||
pub use self::freeze::{FreezeLock, FreezeReadGuard, FreezeWriteGuard};
|
||||
#[doc(no_inline)]
|
||||
pub use self::lock::{Lock, LockGuard, Mode};
|
||||
pub use self::mode::{is_dyn_thread_safe, set_dyn_thread_safe_mode};
|
||||
pub use self::mode::{
|
||||
FromDyn, check_dyn_thread_safe, is_dyn_thread_safe, set_dyn_thread_safe_mode,
|
||||
};
|
||||
pub use self::parallel::{
|
||||
broadcast, par_fns, par_for_each_in, par_join, par_map, parallel_guard, spawn,
|
||||
try_par_for_each_in,
|
||||
@@ -64,12 +66,20 @@ mod atomic {
|
||||
mod mode {
|
||||
use std::sync::atomic::{AtomicU8, Ordering};
|
||||
|
||||
use crate::sync::{DynSend, DynSync};
|
||||
|
||||
const UNINITIALIZED: u8 = 0;
|
||||
const DYN_NOT_THREAD_SAFE: u8 = 1;
|
||||
const DYN_THREAD_SAFE: u8 = 2;
|
||||
|
||||
static DYN_THREAD_SAFE_MODE: AtomicU8 = AtomicU8::new(UNINITIALIZED);
|
||||
|
||||
// Whether thread safety is enabled (due to running under multiple threads).
|
||||
#[inline]
|
||||
pub fn check_dyn_thread_safe() -> Option<FromDyn<()>> {
|
||||
is_dyn_thread_safe().then_some(FromDyn(()))
|
||||
}
|
||||
|
||||
// Whether thread safety is enabled (due to running under multiple threads).
|
||||
#[inline]
|
||||
pub fn is_dyn_thread_safe() -> bool {
|
||||
@@ -99,6 +109,44 @@ pub fn set_dyn_thread_safe_mode(mode: bool) {
|
||||
// Check that the mode was either uninitialized or was already set to the requested mode.
|
||||
assert!(previous.is_ok() || previous == Err(set));
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct FromDyn<T>(T);
|
||||
|
||||
impl<T> FromDyn<T> {
|
||||
#[inline(always)]
|
||||
pub fn derive<O>(&self, val: O) -> FromDyn<O> {
|
||||
// We already did the check for `sync::is_dyn_thread_safe()` when creating `Self`
|
||||
FromDyn(val)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// `FromDyn` is `Send` if `T` is `DynSend`, since it ensures that sync::is_dyn_thread_safe() is true.
|
||||
unsafe impl<T: DynSend> Send for FromDyn<T> {}
|
||||
|
||||
// `FromDyn` is `Sync` if `T` is `DynSync`, since it ensures that sync::is_dyn_thread_safe() is true.
|
||||
unsafe impl<T: DynSync> Sync for FromDyn<T> {}
|
||||
|
||||
impl<T> std::ops::Deref for FromDyn<T> {
|
||||
type Target = T;
|
||||
|
||||
#[inline(always)]
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for FromDyn<T> {
|
||||
#[inline(always)]
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This makes locks panic if they are already held.
|
||||
|
||||
@@ -57,8 +57,8 @@ fn serial_join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
|
||||
}
|
||||
|
||||
pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
|
||||
if mode::is_dyn_thread_safe() {
|
||||
let func = FromDyn::from(func);
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let func = proof.derive(func);
|
||||
rustc_thread_pool::spawn(|| {
|
||||
(func.into_inner())();
|
||||
});
|
||||
@@ -73,8 +73,8 @@ pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
|
||||
/// Use that for the longest running function for better scheduling.
|
||||
pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) {
|
||||
parallel_guard(|guard: &ParallelGuard| {
|
||||
if mode::is_dyn_thread_safe() {
|
||||
let funcs = FromDyn::from(funcs);
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let funcs = proof.derive(funcs);
|
||||
rustc_thread_pool::scope(|s| {
|
||||
let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else {
|
||||
return;
|
||||
@@ -84,7 +84,7 @@ pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) {
|
||||
// order when using a single thread. This ensures the execution order matches
|
||||
// that of a single threaded rustc.
|
||||
for f in rest.iter_mut().rev() {
|
||||
let f = FromDyn::from(f);
|
||||
let f = proof.derive(f);
|
||||
s.spawn(|_| {
|
||||
guard.run(|| (f.into_inner())());
|
||||
});
|
||||
@@ -108,13 +108,13 @@ pub fn par_join<A, B, RA: DynSend, RB: DynSend>(oper_a: A, oper_b: B) -> (RA, RB
|
||||
A: FnOnce() -> RA + DynSend,
|
||||
B: FnOnce() -> RB + DynSend,
|
||||
{
|
||||
if mode::is_dyn_thread_safe() {
|
||||
let oper_a = FromDyn::from(oper_a);
|
||||
let oper_b = FromDyn::from(oper_b);
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let oper_a = proof.derive(oper_a);
|
||||
let oper_b = proof.derive(oper_b);
|
||||
let (a, b) = parallel_guard(|guard| {
|
||||
rustc_thread_pool::join(
|
||||
move || guard.run(move || FromDyn::from(oper_a.into_inner()())),
|
||||
move || guard.run(move || FromDyn::from(oper_b.into_inner()())),
|
||||
move || guard.run(move || proof.derive(oper_a.into_inner()())),
|
||||
move || guard.run(move || proof.derive(oper_b.into_inner()())),
|
||||
)
|
||||
});
|
||||
(a.unwrap().into_inner(), b.unwrap().into_inner())
|
||||
@@ -127,8 +127,9 @@ fn par_slice<I: DynSend>(
|
||||
items: &mut [I],
|
||||
guard: &ParallelGuard,
|
||||
for_each: impl Fn(&mut I) + DynSync + DynSend,
|
||||
proof: FromDyn<()>,
|
||||
) {
|
||||
let for_each = FromDyn::from(for_each);
|
||||
let for_each = proof.derive(for_each);
|
||||
let mut items = for_each.derive(items);
|
||||
rustc_thread_pool::scope(|s| {
|
||||
let proof = items.derive(());
|
||||
@@ -150,9 +151,9 @@ pub fn par_for_each_in<I: DynSend, T: IntoIterator<Item = I>>(
|
||||
for_each: impl Fn(&I) + DynSync + DynSend,
|
||||
) {
|
||||
parallel_guard(|guard| {
|
||||
if mode::is_dyn_thread_safe() {
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let mut items: Vec<_> = t.into_iter().collect();
|
||||
par_slice(&mut items, guard, |i| for_each(&*i))
|
||||
par_slice(&mut items, guard, |i| for_each(&*i), proof)
|
||||
} else {
|
||||
t.into_iter().for_each(|i| {
|
||||
guard.run(|| for_each(&i));
|
||||
@@ -173,16 +174,21 @@ pub fn try_par_for_each_in<T: IntoIterator, E: DynSend>(
|
||||
<T as IntoIterator>::Item: DynSend,
|
||||
{
|
||||
parallel_guard(|guard| {
|
||||
if mode::is_dyn_thread_safe() {
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let mut items: Vec<_> = t.into_iter().collect();
|
||||
|
||||
let error = Mutex::new(None);
|
||||
|
||||
par_slice(&mut items, guard, |i| {
|
||||
if let Err(err) = for_each(&*i) {
|
||||
*error.lock() = Some(err);
|
||||
}
|
||||
});
|
||||
par_slice(
|
||||
&mut items,
|
||||
guard,
|
||||
|i| {
|
||||
if let Err(err) = for_each(&*i) {
|
||||
*error.lock() = Some(err);
|
||||
}
|
||||
},
|
||||
proof,
|
||||
);
|
||||
|
||||
if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) }
|
||||
} else {
|
||||
@@ -196,15 +202,20 @@ pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterato
|
||||
map: impl Fn(I) -> R + DynSync + DynSend,
|
||||
) -> C {
|
||||
parallel_guard(|guard| {
|
||||
if mode::is_dyn_thread_safe() {
|
||||
let map = FromDyn::from(map);
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let map = proof.derive(map);
|
||||
|
||||
let mut items: Vec<(Option<I>, Option<R>)> =
|
||||
t.into_iter().map(|i| (Some(i), None)).collect();
|
||||
|
||||
par_slice(&mut items, guard, |i| {
|
||||
i.1 = Some(map(i.0.take().unwrap()));
|
||||
});
|
||||
par_slice(
|
||||
&mut items,
|
||||
guard,
|
||||
|i| {
|
||||
i.1 = Some(map(i.0.take().unwrap()));
|
||||
},
|
||||
proof,
|
||||
);
|
||||
|
||||
items.into_iter().filter_map(|i| i.1).collect()
|
||||
} else {
|
||||
@@ -214,8 +225,8 @@ pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterato
|
||||
}
|
||||
|
||||
pub fn broadcast<R: DynSend>(op: impl Fn(usize) -> R + DynSync) -> Vec<R> {
|
||||
if mode::is_dyn_thread_safe() {
|
||||
let op = FromDyn::from(op);
|
||||
if let Some(proof) = mode::check_dyn_thread_safe() {
|
||||
let op = proof.derive(op);
|
||||
let results = rustc_thread_pool::broadcast(|context| op.derive(op(context.index())));
|
||||
results.into_iter().map(|r| r.into_inner()).collect()
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user