diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs index 0b543f091f73..060447ba7206 100644 --- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs +++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs @@ -497,6 +497,10 @@ fn fold_const(&mut self, mut ct: ty::Const<'tcx>) -> ty::Const<'tcx> { fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c } + } } impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> { diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs index 4b0ace8c554d..a95f24b5b95d 100644 --- a/compiler/rustc_infer/src/infer/resolve.rs +++ b/compiler/rustc_infer/src/infer/resolve.rs @@ -55,6 +55,14 @@ fn fold_const(&mut self, ct: Const<'tcx>) -> Const<'tcx> { ct.super_fold_with(self) } } + + fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { + if !p.has_non_region_infer() { p } else { p.super_fold_with(self) } + } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if !c.has_non_region_infer() { c } else { c.super_fold_with(self) } + } } /// The opportunistic region resolver opportunistically resolves regions diff --git a/compiler/rustc_middle/src/ty/erase_regions.rs b/compiler/rustc_middle/src/ty/erase_regions.rs index 45a0b1288db8..f4fead7e9526 100644 --- a/compiler/rustc_middle/src/ty/erase_regions.rs +++ b/compiler/rustc_middle/src/ty/erase_regions.rs @@ -86,4 +86,12 @@ fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { p } } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) { + c.super_fold_with(self) + } else { + c + } + } } diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs index 8d6871d2f1fe..b2057fa36d7f 100644 --- a/compiler/rustc_middle/src/ty/fold.rs +++ b/compiler/rustc_middle/src/ty/fold.rs @@ -177,6 +177,10 @@ fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c } + } } impl<'tcx> TyCtxt<'tcx> { diff --git a/compiler/rustc_middle/src/ty/predicate.rs b/compiler/rustc_middle/src/ty/predicate.rs index 551d816941b6..bc2ac42b6b1f 100644 --- a/compiler/rustc_middle/src/ty/predicate.rs +++ b/compiler/rustc_middle/src/ty/predicate.rs @@ -238,6 +238,8 @@ pub fn as_region_outlives_clause( } } +impl<'tcx> rustc_type_ir::inherent::Clauses> for ty::Clauses<'tcx> {} + #[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)] impl<'tcx> ExistentialPredicate<'tcx> { /// Compares via an ordering that will not change if modules are reordered or other changes are diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index 7a7c33fb34db..000ba7b6fa79 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -570,6 +570,19 @@ fn fold_with>>(self, folder: &mut F) -> Self { } } +impl<'tcx> TypeFoldable> for ty::Clauses<'tcx> { + fn try_fold_with>>( + self, + folder: &mut F, + ) -> Result { + folder.try_fold_clauses(self) + } + + fn fold_with>>(self, folder: &mut F) -> Self { + folder.fold_clauses(self) + } +} + impl<'tcx> TypeVisitable> for ty::Predicate<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_predicate(*self) @@ -615,6 +628,19 @@ fn super_visit_with>>(&self, visitor: &mut V) -> V:: } } +impl<'tcx> TypeSuperFoldable> for ty::Clauses<'tcx> { + fn try_super_fold_with>>( + self, + folder: &mut F, + ) -> Result { + ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v)) + } + + fn super_fold_with>>(self, folder: &mut F) -> Self { + ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v)) + } +} + impl<'tcx> TypeFoldable> for ty::Const<'tcx> { fn try_fold_with>>( self, @@ -775,7 +801,6 @@ fn fold_with>>( } list_fold! { - ty::Clauses<'tcx> : mk_clauses, &'tcx ty::List> : mk_poly_existential_predicates, &'tcx ty::List> : mk_place_elems, &'tcx ty::List> : mk_patterns, diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs index addeb3e2b78e..e5ca2bda4592 100644 --- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs +++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs @@ -572,4 +572,15 @@ fn fold_const(&mut self, c: I::Const) -> I::Const { fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + match self.canonicalize_mode { + CanonicalizeMode::Input { keep_static: true } + | CanonicalizeMode::Response { max_input_universe: _ } => {} + CanonicalizeMode::Input { keep_static: false } => { + panic!("erasing 'static in env") + } + } + if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c } + } } diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs index 39abec2d7d8d..c3c57eccd6ef 100644 --- a/compiler/rustc_next_trait_solver/src/resolve.rs +++ b/compiler/rustc_next_trait_solver/src/resolve.rs @@ -11,7 +11,7 @@ // EAGER RESOLUTION /// Resolves ty, region, and const vars to their inferred values or their root vars. -pub struct EagerResolver<'a, D, I = ::Interner> +struct EagerResolver<'a, D, I = ::Interner> where D: SolverDelegate, I: Interner, @@ -22,8 +22,20 @@ pub struct EagerResolver<'a, D, I = ::Interner> cache: DelayedMap, } +pub fn eager_resolve_vars>( + delegate: &D, + value: T, +) -> T { + if value.has_infer() { + let mut folder = EagerResolver::new(delegate); + value.fold_with(&mut folder) + } else { + value + } +} + impl<'a, D: SolverDelegate> EagerResolver<'a, D> { - pub fn new(delegate: &'a D) -> Self { + fn new(delegate: &'a D) -> Self { EagerResolver { delegate, cache: Default::default() } } } @@ -90,4 +102,8 @@ fn fold_const(&mut self, c: I::Const) -> I::Const { fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { if p.has_infer() { p.super_fold_with(self) } else { p } } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + if c.has_infer() { c.super_fold_with(self) } else { c } + } } diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs index 66d4cd23112d..cb7c498b94e2 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs @@ -22,7 +22,7 @@ use crate::canonicalizer::Canonicalizer; use crate::delegate::SolverDelegate; -use crate::resolve::EagerResolver; +use crate::resolve::eager_resolve_vars; use crate::solve::eval_ctxt::CurrentGoalKind; use crate::solve::{ CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal, @@ -61,8 +61,7 @@ pub(super) fn canonicalize_goal( // so we only canonicalize the lookup table and ignore // duplicate entries. let opaque_types = self.delegate.clone_opaque_types_lookup_table(); - let (goal, opaque_types) = - (goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate)); + let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types)); let mut orig_values = Default::default(); let canonical = Canonicalizer::canonicalize_input( @@ -162,8 +161,8 @@ pub(in crate::solve) fn evaluate_added_goals_and_make_canonical_response( let external_constraints = self.compute_external_query_constraints(certainty, normalization_nested_goals); - let (var_values, mut external_constraints) = (self.var_values, external_constraints) - .fold_with(&mut EagerResolver::new(self.delegate)); + let (var_values, mut external_constraints) = + eager_resolve_vars(self.delegate, (self.var_values, external_constraints)); // Remove any trivial or duplicated region constraints once we've resolved regions let mut unique = HashSet::default(); @@ -474,7 +473,7 @@ pub(in crate::solve) fn make_canonical_state( { let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) }; let state = inspect::State { var_values, data }; - let state = state.fold_with(&mut EagerResolver::new(delegate)); + let state = eager_resolve_vars(delegate, state); Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state) } diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs index 2924208173d2..b1a842e04af5 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs @@ -925,6 +925,22 @@ fn visit_const(&mut self, c: I::Const) -> Self::Result { } } } + + fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result { + if p.has_non_region_infer() || p.has_placeholders() { + p.super_visit_with(self) + } else { + ControlFlow::Continue(()) + } + } + + fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result { + if c.has_non_region_infer() || c.has_placeholders() { + c.super_visit_with(self) + } else { + ControlFlow::Continue(()) + } + } } let mut visitor = ContainsTermOrNotNameable { diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs index fda2c97ed56f..1193a9059ca9 100644 --- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs +++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs @@ -15,9 +15,9 @@ use rustc_macros::extension; use rustc_middle::traits::ObligationCause; use rustc_middle::traits::solve::{Certainty, Goal, GoalSource, NoSolution, QueryResult}; -use rustc_middle::ty::{TyCtxt, TypeFoldable, VisitorResult, try_visit}; +use rustc_middle::ty::{TyCtxt, VisitorResult, try_visit}; use rustc_middle::{bug, ty}; -use rustc_next_trait_solver::resolve::EagerResolver; +use rustc_next_trait_solver::resolve::eager_resolve_vars; use rustc_next_trait_solver::solve::inspect::{self, instantiate_canonical_state}; use rustc_next_trait_solver::solve::{GenerateProofTree, MaybeCause, SolverDelegateEvalExt as _}; use rustc_span::{DUMMY_SP, Span}; @@ -187,8 +187,7 @@ pub fn instantiate_nested_goals_and_opt_impl_args( let _ = term_hack.constrain(infcx, span, param_env); } - let opt_impl_args = - opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx))); + let opt_impl_args = opt_impl_args.map(|impl_args| eager_resolve_vars(infcx, impl_args)); let goals = instantiated_goals .into_iter() @@ -392,7 +391,7 @@ fn new( infcx, depth, orig_values, - goal: uncanonicalized_goal.fold_with(&mut EagerResolver::new(infcx)), + goal: eager_resolve_vars(infcx, uncanonicalized_goal), result, evaluation_kind: evaluation.kind, normalizes_to_term_hack, diff --git a/compiler/rustc_type_ir/src/binder.rs b/compiler/rustc_type_ir/src/binder.rs index 1e6ca0dcf5d3..55c0a3bba9f2 100644 --- a/compiler/rustc_type_ir/src/binder.rs +++ b/compiler/rustc_type_ir/src/binder.rs @@ -711,6 +711,14 @@ fn fold_const(&mut self, c: I::Const) -> I::Const { c.super_fold_with(self) } } + + fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { + if p.has_param() { p.super_fold_with(self) } else { p } + } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + if c.has_param() { c.super_fold_with(self) } else { c } + } } impl<'a, I: Interner> ArgFolder<'a, I> { diff --git a/compiler/rustc_type_ir/src/fold.rs b/compiler/rustc_type_ir/src/fold.rs index ce1188070ca7..a5eb8699e5fc 100644 --- a/compiler/rustc_type_ir/src/fold.rs +++ b/compiler/rustc_type_ir/src/fold.rs @@ -152,6 +152,10 @@ fn fold_const(&mut self, c: I::Const) -> I::Const { fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate { p.super_fold_with(self) } + + fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses { + c.super_fold_with(self) + } } /// This trait is implemented for every folding traversal. There is a fold @@ -190,6 +194,10 @@ fn try_fold_const(&mut self, c: I::Const) -> Result { fn try_fold_predicate(&mut self, p: I::Predicate) -> Result { p.try_super_fold_with(self) } + + fn try_fold_clauses(&mut self, c: I::Clauses) -> Result { + c.try_super_fold_with(self) + } } /////////////////////////////////////////////////////////////////////////// diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index dde55effc3d0..fa88bcb891a9 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -511,6 +511,18 @@ fn as_projection_clause(self) -> Option fn instantiate_supertrait(self, cx: I, trait_ref: ty::Binder>) -> Self; } +pub trait Clauses>: + Copy + + Debug + + Hash + + Eq + + TypeSuperVisitable + + TypeSuperFoldable + + Flags + + SliceLike +{ +} + /// Common capabilities of placeholder kinds pub trait PlaceholderLike: Copy + Debug + Hash + Eq { fn universe(self) -> ty::UniverseIndex; diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 7e88114df460..a9917192144f 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -12,7 +12,7 @@ use crate::lang_items::TraitSolverLangItem; use crate::relate::Relate; use crate::solve::{CanonicalInput, ExternalConstraintsData, PredefinedOpaquesData, QueryResult}; -use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; +use crate::visit::{Flags, TypeVisitable}; use crate::{self as ty, search_graph}; #[cfg_attr(feature = "nightly", rustc_diagnostic_item = "type_ir_interner")] @@ -146,7 +146,7 @@ fn mk_tracked( type ParamEnv: ParamEnv; type Predicate: Predicate; type Clause: Clause; - type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable + Flags; + type Clauses: Clauses; fn with_global_cache(self, f: impl FnOnce(&mut search_graph::GlobalCache) -> R) -> R; diff --git a/compiler/rustc_type_ir/src/visit.rs b/compiler/rustc_type_ir/src/visit.rs index ccb84e259112..fc3864dd5ae6 100644 --- a/compiler/rustc_type_ir/src/visit.rs +++ b/compiler/rustc_type_ir/src/visit.rs @@ -120,8 +120,8 @@ fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result { p.super_visit_with(self) } - fn visit_clauses(&mut self, p: I::Clauses) -> Self::Result { - p.super_visit_with(self) + fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result { + c.super_visit_with(self) } fn visit_error(&mut self, _guar: I::ErrorGuaranteed) -> Self::Result {