Auto merge of #141581 - lcnr:fold-clauses, r=compiler-errors

add additional `TypeFlags` fast paths

Some crates, e.g. `diesel`, have items with a lot of where-clauses (more than 150). In these cases checking the `TypeFlags` of the whole `param_env` can be very beneficial.

This adds `fn fold_clauses` to mirror the existing `fn visit_clauses` and then uses this in folders which fold `ParamEnv`s.

Split out from rust-lang/rust#141451, depends on rust-lang/rust#141442.

r? `@compiler-errors`
This commit is contained in:
bors
2025-05-29 02:29:01 +00:00
16 changed files with 138 additions and 18 deletions
@@ -497,6 +497,10 @@ fn fold_const(&mut self, mut ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
}
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c }
}
}
impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
@@ -55,6 +55,14 @@ fn fold_const(&mut self, ct: Const<'tcx>) -> Const<'tcx> {
ct.super_fold_with(self)
}
}
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if !p.has_non_region_infer() { p } else { p.super_fold_with(self) }
}
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if !c.has_non_region_infer() { c } else { c.super_fold_with(self) }
}
}
/// The opportunistic region resolver opportunistically resolves regions
@@ -86,4 +86,12 @@ fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
p
}
}
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) {
c.super_fold_with(self)
} else {
c
}
}
}
+4
View File
@@ -177,6 +177,10 @@ fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
}
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c }
}
}
impl<'tcx> TyCtxt<'tcx> {
@@ -238,6 +238,8 @@ pub fn as_region_outlives_clause(
}
}
impl<'tcx> rustc_type_ir::inherent::Clauses<TyCtxt<'tcx>> for ty::Clauses<'tcx> {}
#[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)]
impl<'tcx> ExistentialPredicate<'tcx> {
/// Compares via an ordering that will not change if modules are reordered or other changes are
@@ -570,6 +570,19 @@ fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
}
}
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
folder.try_fold_clauses(self)
}
fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
folder.fold_clauses(self)
}
}
impl<'tcx> TypeVisitable<TyCtxt<'tcx>> for ty::Predicate<'tcx> {
fn visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::Result {
visitor.visit_predicate(*self)
@@ -615,6 +628,19 @@ fn super_visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::
}
}
impl<'tcx> TypeSuperFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
fn try_super_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
}
fn super_fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
}
}
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Const<'tcx> {
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
@@ -775,7 +801,6 @@ fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(
}
list_fold! {
ty::Clauses<'tcx> : mk_clauses,
&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
&'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
&'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,
@@ -572,4 +572,15 @@ fn fold_const(&mut self, c: I::Const) -> I::Const {
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
}
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
match self.canonicalize_mode {
CanonicalizeMode::Input { keep_static: true }
| CanonicalizeMode::Response { max_input_universe: _ } => {}
CanonicalizeMode::Input { keep_static: false } => {
panic!("erasing 'static in env")
}
}
if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }
}
}
@@ -11,7 +11,7 @@
// EAGER RESOLUTION
/// Resolves ty, region, and const vars to their inferred values or their root vars.
pub struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
where
D: SolverDelegate<Interner = I>,
I: Interner,
@@ -22,8 +22,20 @@ pub struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
cache: DelayedMap<I::Ty, I::Ty>,
}
pub fn eager_resolve_vars<D: SolverDelegate, T: TypeFoldable<D::Interner>>(
delegate: &D,
value: T,
) -> T {
if value.has_infer() {
let mut folder = EagerResolver::new(delegate);
value.fold_with(&mut folder)
} else {
value
}
}
impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
pub fn new(delegate: &'a D) -> Self {
fn new(delegate: &'a D) -> Self {
EagerResolver { delegate, cache: Default::default() }
}
}
@@ -90,4 +102,8 @@ fn fold_const(&mut self, c: I::Const) -> I::Const {
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.has_infer() { p.super_fold_with(self) } else { p }
}
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
if c.has_infer() { c.super_fold_with(self) } else { c }
}
}
@@ -22,7 +22,7 @@
use crate::canonicalizer::Canonicalizer;
use crate::delegate::SolverDelegate;
use crate::resolve::EagerResolver;
use crate::resolve::eager_resolve_vars;
use crate::solve::eval_ctxt::CurrentGoalKind;
use crate::solve::{
CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal,
@@ -61,8 +61,7 @@ pub(super) fn canonicalize_goal(
// so we only canonicalize the lookup table and ignore
// duplicate entries.
let opaque_types = self.delegate.clone_opaque_types_lookup_table();
let (goal, opaque_types) =
(goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate));
let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types));
let mut orig_values = Default::default();
let canonical = Canonicalizer::canonicalize_input(
@@ -162,8 +161,8 @@ pub(in crate::solve) fn evaluate_added_goals_and_make_canonical_response(
let external_constraints =
self.compute_external_query_constraints(certainty, normalization_nested_goals);
let (var_values, mut external_constraints) = (self.var_values, external_constraints)
.fold_with(&mut EagerResolver::new(self.delegate));
let (var_values, mut external_constraints) =
eager_resolve_vars(self.delegate, (self.var_values, external_constraints));
// Remove any trivial or duplicated region constraints once we've resolved regions
let mut unique = HashSet::default();
@@ -474,7 +473,7 @@ pub(in crate::solve) fn make_canonical_state<D, T, I>(
{
let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) };
let state = inspect::State { var_values, data };
let state = state.fold_with(&mut EagerResolver::new(delegate));
let state = eager_resolve_vars(delegate, state);
Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state)
}
@@ -925,6 +925,22 @@ fn visit_const(&mut self, c: I::Const) -> Self::Result {
}
}
}
fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result {
if p.has_non_region_infer() || p.has_placeholders() {
p.super_visit_with(self)
} else {
ControlFlow::Continue(())
}
}
fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
if c.has_non_region_infer() || c.has_placeholders() {
c.super_visit_with(self)
} else {
ControlFlow::Continue(())
}
}
}
let mut visitor = ContainsTermOrNotNameable {
@@ -15,9 +15,9 @@
use rustc_macros::extension;
use rustc_middle::traits::ObligationCause;
use rustc_middle::traits::solve::{Certainty, Goal, GoalSource, NoSolution, QueryResult};
use rustc_middle::ty::{TyCtxt, TypeFoldable, VisitorResult, try_visit};
use rustc_middle::ty::{TyCtxt, VisitorResult, try_visit};
use rustc_middle::{bug, ty};
use rustc_next_trait_solver::resolve::EagerResolver;
use rustc_next_trait_solver::resolve::eager_resolve_vars;
use rustc_next_trait_solver::solve::inspect::{self, instantiate_canonical_state};
use rustc_next_trait_solver::solve::{GenerateProofTree, MaybeCause, SolverDelegateEvalExt as _};
use rustc_span::{DUMMY_SP, Span};
@@ -187,8 +187,7 @@ pub fn instantiate_nested_goals_and_opt_impl_args(
let _ = term_hack.constrain(infcx, span, param_env);
}
let opt_impl_args =
opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx)));
let opt_impl_args = opt_impl_args.map(|impl_args| eager_resolve_vars(infcx, impl_args));
let goals = instantiated_goals
.into_iter()
@@ -392,7 +391,7 @@ fn new(
infcx,
depth,
orig_values,
goal: uncanonicalized_goal.fold_with(&mut EagerResolver::new(infcx)),
goal: eager_resolve_vars(infcx, uncanonicalized_goal),
result,
evaluation_kind: evaluation.kind,
normalizes_to_term_hack,
+8
View File
@@ -711,6 +711,14 @@ fn fold_const(&mut self, c: I::Const) -> I::Const {
c.super_fold_with(self)
}
}
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.has_param() { p.super_fold_with(self) } else { p }
}
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
if c.has_param() { c.super_fold_with(self) } else { c }
}
}
impl<'a, I: Interner> ArgFolder<'a, I> {
+8
View File
@@ -152,6 +152,10 @@ fn fold_const(&mut self, c: I::Const) -> I::Const {
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
p.super_fold_with(self)
}
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
c.super_fold_with(self)
}
}
/// This trait is implemented for every folding traversal. There is a fold
@@ -190,6 +194,10 @@ fn try_fold_const(&mut self, c: I::Const) -> Result<I::Const, Self::Error> {
fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
p.try_super_fold_with(self)
}
fn try_fold_clauses(&mut self, c: I::Clauses) -> Result<I::Clauses, Self::Error> {
c.try_super_fold_with(self)
}
}
///////////////////////////////////////////////////////////////////////////
+12
View File
@@ -511,6 +511,18 @@ fn as_projection_clause(self) -> Option<ty::Binder<I, ty::ProjectionPredicate<I>
fn instantiate_supertrait(self, cx: I, trait_ref: ty::Binder<I, ty::TraitRef<I>>) -> Self;
}
pub trait Clauses<I: Interner<Clauses = Self>>:
Copy
+ Debug
+ Hash
+ Eq
+ TypeSuperVisitable<I>
+ TypeSuperFoldable<I>
+ Flags
+ SliceLike<Item = I::Clause>
{
}
/// Common capabilities of placeholder kinds
pub trait PlaceholderLike: Copy + Debug + Hash + Eq {
fn universe(self) -> ty::UniverseIndex;
+2 -2
View File
@@ -12,7 +12,7 @@
use crate::lang_items::TraitSolverLangItem;
use crate::relate::Relate;
use crate::solve::{CanonicalInput, ExternalConstraintsData, PredefinedOpaquesData, QueryResult};
use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable};
use crate::visit::{Flags, TypeVisitable};
use crate::{self as ty, search_graph};
#[cfg_attr(feature = "nightly", rustc_diagnostic_item = "type_ir_interner")]
@@ -146,7 +146,7 @@ fn mk_tracked<T: Debug + Clone>(
type ParamEnv: ParamEnv<Self>;
type Predicate: Predicate<Self>;
type Clause: Clause<Self>;
type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable<Self> + Flags;
type Clauses: Clauses<Self>;
fn with_global_cache<R>(self, f: impl FnOnce(&mut search_graph::GlobalCache<Self>) -> R) -> R;
+2 -2
View File
@@ -120,8 +120,8 @@ fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result {
p.super_visit_with(self)
}
fn visit_clauses(&mut self, p: I::Clauses) -> Self::Result {
p.super_visit_with(self)
fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
c.super_visit_with(self)
}
fn visit_error(&mut self, _guar: I::ErrorGuaranteed) -> Self::Result {