From c021d2ddd49f8ff07d27fa772d88a3eb229e63ec Mon Sep 17 00:00:00 2001 From: beetrees Date: Fri, 28 Mar 2025 20:09:07 +0000 Subject: [PATCH] Fallback `{float}` to `f32` when `f32: From<{float}>` --- compiler/rustc_hir/src/lang_items.rs | 3 + compiler/rustc_hir_typeck/src/fallback.rs | 53 ++++++++- .../src/fn_ctxt/inspect_obligations.rs | 101 +++++++++++++++++- compiler/rustc_infer/src/infer/mod.rs | 4 + compiler/rustc_middle/src/ty/sty.rs | 8 ++ library/core/src/convert/mod.rs | 1 + tests/ui/float/f32-into-f32.rs | 9 ++ tests/ui/float/trait-f16-or-f32.rs | 13 +++ tests/ui/float/trait-f16-or-f32.stderr | 20 ++++ 9 files changed, 207 insertions(+), 5 deletions(-) create mode 100644 tests/ui/float/f32-into-f32.rs create mode 100644 tests/ui/float/trait-f16-or-f32.rs create mode 100644 tests/ui/float/trait-f16-or-f32.stderr diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index c144f0b7dbc5..75c708e33929 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -443,6 +443,9 @@ fn hash_stable(&self, _: &mut Hcx, hasher: &mut StableHasher) { FieldBase, sym::field_base, field_base, Target::AssocTy, GenericRequirement::Exact(0); FieldType, sym::field_type, field_type, Target::AssocTy, GenericRequirement::Exact(0); FieldOffset, sym::field_offset, field_offset, Target::AssocConst, GenericRequirement::Exact(0); + + // Used to fallback `{float}` to `f32` when `f32: From<{float}>` + From, sym::From, from_trait, Target::Trait, GenericRequirement::Exact(1); } /// The requirement imposed on the generics of a lang item diff --git a/compiler/rustc_hir_typeck/src/fallback.rs b/compiler/rustc_hir_typeck/src/fallback.rs index 5aadf37720d0..704abb9c39d9 100644 --- a/compiler/rustc_hir_typeck/src/fallback.rs +++ b/compiler/rustc_hir_typeck/src/fallback.rs @@ -11,7 +11,7 @@ use rustc_hir::def::{DefKind, Res}; use rustc_hir::def_id::DefId; use rustc_hir::intravisit::{InferKind, Visitor}; -use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable}; +use rustc_middle::ty::{self, FloatVid, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable}; use rustc_session::lint; use rustc_span::def_id::LocalDefId; use rustc_span::{DUMMY_SP, Span}; @@ -55,6 +55,7 @@ fn fallback_types(&self) -> bool { let (diverging_fallback, diverging_fallback_ty) = self.calculate_diverging_fallback(&unresolved_variables); + let fallback_to_f32 = self.calculate_fallback_to_f32(&unresolved_variables); // We do fallback in two passes, to try to generate // better error messages. @@ -62,8 +63,12 @@ fn fallback_types(&self) -> bool { let mut fallback_occurred = false; for ty in unresolved_variables { debug!("unsolved_variable = {:?}", ty); - fallback_occurred |= - self.fallback_if_possible(ty, &diverging_fallback, diverging_fallback_ty); + fallback_occurred |= self.fallback_if_possible( + ty, + &diverging_fallback, + diverging_fallback_ty, + &fallback_to_f32, + ); } fallback_occurred @@ -73,7 +78,8 @@ fn fallback_types(&self) -> bool { /// /// - Unconstrained ints are replaced with `i32`. /// - /// - Unconstrained floats are replaced with `f64`. + /// - Unconstrained floats are replaced with `f64`, except when there is a trait predicate + /// `f32: From<{float}>`, in which case `f32` is used as the fallback instead. /// /// - Non-numerics may get replaced with `()` or `!`, depending on how they /// were categorized by [`Self::calculate_diverging_fallback`], crate's @@ -89,6 +95,7 @@ fn fallback_if_possible( ty: Ty<'tcx>, diverging_fallback: &UnordSet>, diverging_fallback_ty: Ty<'tcx>, + fallback_to_f32: &UnordSet, ) -> bool { // Careful: we do NOT shallow-resolve `ty`. We know that `ty` // is an unsolved variable, and we determine its fallback @@ -111,6 +118,7 @@ fn fallback_if_possible( let fallback = match ty.kind() { _ if let Some(e) = self.tainted_by_errors() => Ty::new_error(self.tcx, e), ty::Infer(ty::IntVar(_)) => self.tcx.types.i32, + ty::Infer(ty::FloatVar(vid)) if fallback_to_f32.contains(vid) => self.tcx.types.f32, ty::Infer(ty::FloatVar(_)) => self.tcx.types.f64, _ if diverging_fallback.contains(&ty) => { self.diverging_fallback_has_occurred.set(true); @@ -125,6 +133,38 @@ fn fallback_if_possible( true } + /// Existing code relies on `f32: From` (usually written as `T: Into`) resolving `T` to + /// `f32` when the type of `T` is inferred from an unsuffixed float literal. Using the default + /// fallback of `f64`, this would break when adding `impl From for f32`, as there are now + /// two float type which could be `T`, meaning that the fallback of `f64` would be used and + /// compilation error would occur as `f32` does not implement `From`. To avoid breaking + /// existing code, we instead fallback `T` to `f32` when there is a trait predicate + /// `f32: From`. This means code like the following will continue to compile: + /// + /// ```rust + /// fn foo>(_: T) {} + /// + /// foo(1.0); + /// ``` + fn calculate_fallback_to_f32(&self, unresolved_variables: &[Ty<'tcx>]) -> UnordSet { + let roots: UnordSet = self.from_float_for_f32_root_vids(); + if roots.is_empty() { + // Most functions have no `f32: From<{float}>` predicates, so short-circuit and return + // an empty set when this is the case. + return UnordSet::new(); + } + // Calculate all the unresolved variables that need to fallback to `f32` here. This ensures + // we don't need to find root variables in `fallback_if_possible`: see the comment at the + // top of that function for details. + let fallback_to_f32 = unresolved_variables + .iter() + .flat_map(|ty| ty.float_vid()) + .filter(|vid| roots.contains(&self.root_float_var(*vid))) + .collect(); + debug!("calculate_fallback_to_f32: fallback_to_f32={:?}", fallback_to_f32); + fallback_to_f32 + } + fn calculate_diverging_fallback( &self, unresolved_variables: &[Ty<'tcx>], @@ -362,6 +402,11 @@ fn root_vid(&self, ty: Ty<'tcx>) -> Option { Some(self.root_var(self.shallow_resolve(ty).ty_vid()?)) } + /// If `ty` is an unresolved float type variable, returns its root vid. + pub(crate) fn root_float_vid(&self, ty: Ty<'tcx>) -> Option { + Some(self.root_float_var(self.shallow_resolve(ty).float_vid()?)) + } + /// Given a set of diverging vids and coercions, walk the HIR to gather a /// set of suggestions which can be applied to preserve fallback to unit. fn try_to_suggest_annotations( diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs index 1ab7ac4c2e36..dcaace299ada 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs @@ -1,5 +1,7 @@ //! A utility module to inspect currently ambiguous obligations in the current context. +use rustc_data_structures::unord::UnordSet; +use rustc_hir::def_id::DefId; use rustc_infer::traits::{self, ObligationCause, PredicateObligations}; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; use rustc_span::Span; @@ -96,6 +98,69 @@ pub(crate) fn obligations_for_self_ty_next( }); obligations_for_self_ty } + + /// Only needed for the `From<{float}>` for `f32` type fallback. + #[instrument(skip(self), level = "debug")] + pub(crate) fn from_float_for_f32_root_vids(&self) -> UnordSet { + if self.next_trait_solver() { + self.from_float_for_f32_root_vids_next() + } else { + let Some(from_trait) = self.tcx.lang_items().from_trait() else { + return UnordSet::new(); + }; + self.fulfillment_cx + .borrow_mut() + .pending_obligations() + .into_iter() + .filter_map(|obligation| { + self.predicate_from_float_for_f32_root_vid(from_trait, obligation.predicate) + }) + .collect() + } + } + + fn predicate_from_float_for_f32_root_vid( + &self, + from_trait: DefId, + predicate: ty::Predicate<'tcx>, + ) -> Option { + // The predicates we are looking for look like + // `TraitPredicate(>, polarity:Positive)`. + // They will have no bound variables. + match predicate.kind().no_bound_vars() { + Some(ty::PredicateKind::Clause(ty::ClauseKind::Trait(ty::TraitPredicate { + polarity: ty::PredicatePolarity::Positive, + trait_ref, + }))) if trait_ref.def_id == from_trait + && self.shallow_resolve(trait_ref.self_ty()).kind() + == &ty::Float(ty::FloatTy::F32) => + { + self.root_float_vid(trait_ref.args.type_at(1)) + } + _ => None, + } + } + + fn from_float_for_f32_root_vids_next(&self) -> UnordSet { + let Some(from_trait) = self.tcx.lang_items().from_trait() else { + return UnordSet::new(); + }; + let obligations = self.fulfillment_cx.borrow().pending_obligations(); + debug!(?obligations); + let mut vids = UnordSet::new(); + for obligation in obligations { + let mut visitor = FindFromFloatForF32RootVids { + fcx: self, + from_trait, + vids: &mut vids, + span: obligation.cause.span, + }; + + let goal = obligation.as_goal(); + self.visit_proof_tree(goal, &mut visitor); + } + vids + } } struct NestedObligationsForSelfTy<'a, 'tcx> { @@ -105,7 +170,7 @@ struct NestedObligationsForSelfTy<'a, 'tcx> { obligations_for_self_ty: &'a mut PredicateObligations<'tcx>, } -impl<'a, 'tcx> ProofTreeVisitor<'tcx> for NestedObligationsForSelfTy<'a, 'tcx> { +impl<'tcx> ProofTreeVisitor<'tcx> for NestedObligationsForSelfTy<'_, 'tcx> { fn span(&self) -> Span { self.root_cause.span } @@ -144,3 +209,37 @@ fn visit_goal(&mut self, inspect_goal: &InspectGoal<'_, 'tcx>) { } } } + +struct FindFromFloatForF32RootVids<'a, 'tcx> { + fcx: &'a FnCtxt<'a, 'tcx>, + from_trait: DefId, + vids: &'a mut UnordSet, + span: Span, +} + +impl<'tcx> ProofTreeVisitor<'tcx> for FindFromFloatForF32RootVids<'_, 'tcx> { + fn span(&self) -> Span { + self.span + } + + fn config(&self) -> InspectConfig { + // Avoid hang from exponentially growing proof trees (see `cycle-modulo-ambig-aliases.rs`). + // 3 is more than enough for all occurences in practice (a.k.a. `Into`). + InspectConfig { max_depth: 3 } + } + + fn visit_goal(&mut self, inspect_goal: &InspectGoal<'_, 'tcx>) { + if let Some(vid) = self + .fcx + .predicate_from_float_for_f32_root_vid(self.from_trait, inspect_goal.goal().predicate) + { + self.vids.insert(vid); + } else if let Some(candidate) = inspect_goal.unique_applicable_candidate() { + let start_len = self.vids.len(); + let _ = candidate.goal().infcx().commit_if_ok(|_| { + candidate.visit_nested_no_probe(self); + if self.vids.len() > start_len { Ok(()) } else { Err(()) } + }); + } + } +} diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index a38d4e819e29..0e2934760f4a 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -1166,6 +1166,10 @@ pub fn sub_unification_table_root_var(&self, var: ty::TyVid) -> ty::TyVid { self.inner.borrow_mut().type_variables().sub_unification_table_root_var(var) } + pub fn root_float_var(&self, var: ty::FloatVid) -> ty::FloatVid { + self.inner.borrow_mut().float_unification_table().find(var) + } + pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid { self.inner.borrow_mut().const_unification_table().find(var).vid } diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index 9164f7b57e64..c781d129a160 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -1180,6 +1180,14 @@ pub fn ty_vid(self) -> Option { } } + #[inline] + pub fn float_vid(self) -> Option { + match self.kind() { + &Infer(FloatVar(vid)) => Some(vid), + _ => None, + } + } + #[inline] pub fn is_ty_or_numeric_infer(self) -> bool { matches!(self.kind(), Infer(_)) diff --git a/library/core/src/convert/mod.rs b/library/core/src/convert/mod.rs index ef4ab15f93c0..4a4c7ee388f9 100644 --- a/library/core/src/convert/mod.rs +++ b/library/core/src/convert/mod.rs @@ -577,6 +577,7 @@ pub const trait Into: Sized { /// [`from`]: From::from /// [book]: ../../book/ch09-00-error-handling.html #[rustc_diagnostic_item = "From"] +#[lang = "From"] #[stable(feature = "rust1", since = "1.0.0")] #[rustc_on_unimplemented(on( all(Self = "&str", T = "alloc::string::String"), diff --git a/tests/ui/float/f32-into-f32.rs b/tests/ui/float/f32-into-f32.rs new file mode 100644 index 000000000000..1b3f0926bdde --- /dev/null +++ b/tests/ui/float/f32-into-f32.rs @@ -0,0 +1,9 @@ +//@ revisions: old-solver next-solver +//@[next-solver] compile-flags: -Znext-solver +//@ run-pass + +fn foo(_: impl Into) {} + +fn main() { + foo(1.0); +} diff --git a/tests/ui/float/trait-f16-or-f32.rs b/tests/ui/float/trait-f16-or-f32.rs new file mode 100644 index 000000000000..72f0a4fbde47 --- /dev/null +++ b/tests/ui/float/trait-f16-or-f32.rs @@ -0,0 +1,13 @@ +//@ check-fail + +#![feature(f16)] + +trait Trait {} +impl Trait for f16 {} +impl Trait for f32 {} + +fn foo(_: impl Trait) {} + +fn main() { + foo(1.0); //~ ERROR the trait bound `f64: Trait` is not satisfied +} diff --git a/tests/ui/float/trait-f16-or-f32.stderr b/tests/ui/float/trait-f16-or-f32.stderr new file mode 100644 index 000000000000..8af81231bd94 --- /dev/null +++ b/tests/ui/float/trait-f16-or-f32.stderr @@ -0,0 +1,20 @@ +error[E0277]: the trait bound `f64: Trait` is not satisfied + --> $DIR/trait-f16-or-f32.rs:12:9 + | +LL | foo(1.0); + | --- ^^^ the trait `Trait` is not implemented for `f64` + | | + | required by a bound introduced by this call + | + = help: the following other types implement trait `Trait`: + f16 + f32 +note: required by a bound in `foo` + --> $DIR/trait-f16-or-f32.rs:9:16 + | +LL | fn foo(_: impl Trait) {} + | ^^^^^ required by this bound in `foo` + +error: aborting due to 1 previous error + +For more information about this error, try `rustc --explain E0277`.