From 5cd5b90a38b99cd119d82d3c45d99fb520f7b11b Mon Sep 17 00:00:00 2001 From: Jonathan Brouwer Date: Sun, 22 Feb 2026 18:14:05 +0100 Subject: [PATCH] Port `rustc_autodiff` to the attribute parsers --- .../src/attributes/autodiff.rs | 117 ++++++++++++++++++ .../rustc_attr_parsing/src/attributes/mod.rs | 1 + compiler/rustc_attr_parsing/src/context.rs | 2 + compiler/rustc_codegen_llvm/src/intrinsic.rs | 8 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 115 ----------------- .../rustc_hir/src/attrs/data_structures.rs | 3 + .../rustc_hir/src/attrs/encode_cross_crate.rs | 1 + .../rustc_hir/src/attrs/pretty_printing.rs | 1 + .../src/cross_crate_inline.rs | 7 +- compiler/rustc_passes/src/check_attr.rs | 19 +-- compiler/rustc_passes/src/errors.rs | 8 -- compiler/rustc_span/src/symbol.rs | 9 ++ 12 files changed, 141 insertions(+), 150 deletions(-) create mode 100644 compiler/rustc_attr_parsing/src/attributes/autodiff.rs diff --git a/compiler/rustc_attr_parsing/src/attributes/autodiff.rs b/compiler/rustc_attr_parsing/src/attributes/autodiff.rs new file mode 100644 index 000000000000..118a4103b1a9 --- /dev/null +++ b/compiler/rustc_attr_parsing/src/attributes/autodiff.rs @@ -0,0 +1,117 @@ +use std::str::FromStr; + +use rustc_ast::LitKind; +use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode}; +use rustc_feature::{AttributeTemplate, template}; +use rustc_hir::attrs::{AttributeKind, RustcAutodiff}; +use rustc_hir::{MethodKind, Target}; +use rustc_span::{Symbol, sym}; +use thin_vec::ThinVec; + +use crate::attributes::prelude::Allow; +use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser}; +use crate::context::{AcceptContext, Stage}; +use crate::parser::{ArgParser, MetaItemOrLitParser}; +use crate::target_checking::AllowedTargets; + +pub(crate) struct RustcAutodiffParser; + +impl SingleAttributeParser for RustcAutodiffParser { + const PATH: &[Symbol] = &[sym::rustc_autodiff]; + const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost; + const ON_DUPLICATE: OnDuplicate = OnDuplicate::Error; + const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[ + Allow(Target::Fn), + Allow(Target::Method(MethodKind::Inherent)), + Allow(Target::Method(MethodKind::Trait { body: true })), + Allow(Target::Method(MethodKind::TraitImpl)), + ]); + const TEMPLATE: AttributeTemplate = template!( + List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"], + "https://doc.rust-lang.org/std/autodiff/index.html" + ); + + fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option { + let list = match args { + ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)), + ArgParser::List(list) => list, + ArgParser::NameValue(_) => { + cx.expected_list_or_no_args(cx.attr_span); + return None; + } + }; + + let mut items = list.mixed().peekable(); + + // Parse name + let Some(mode) = items.next() else { + cx.expected_at_least_one_argument(list.span); + return None; + }; + let Some(mode) = mode.meta_item() else { + cx.expected_identifier(mode.span()); + return None; + }; + let Ok(()) = mode.args().no_args() else { + cx.expected_identifier(mode.span()); + return None; + }; + let Some(mode) = mode.path().word() else { + cx.expected_identifier(mode.span()); + return None; + }; + let Ok(mode) = DiffMode::from_str(mode.as_str()) else { + cx.expected_specific_argument(mode.span, DiffMode::all_modes()); + return None; + }; + + // Parse width + let width = if let Some(width) = items.peek() + && let MetaItemOrLitParser::Lit(width) = width + && let LitKind::Int(width, _) = width.kind + && let Ok(width) = width.0.try_into() + { + _ = items.next(); + width + } else { + 1 + }; + + // Parse activities + let mut activities = ThinVec::new(); + for activity in items { + let MetaItemOrLitParser::MetaItemParser(activity) = activity else { + cx.expected_specific_argument(activity.span(), DiffActivity::all_activities()); + return None; + }; + let Ok(()) = activity.args().no_args() else { + cx.expected_specific_argument(activity.span(), DiffActivity::all_activities()); + return None; + }; + let Some(activity) = activity.path().word() else { + cx.expected_specific_argument(activity.span(), DiffActivity::all_activities()); + return None; + }; + let Ok(activity) = DiffActivity::from_str(activity.as_str()) else { + cx.expected_specific_argument(activity.span, DiffActivity::all_activities()); + return None; + }; + + activities.push(activity); + } + let Some(ret_activity) = activities.pop() else { + cx.expected_specific_argument( + list.span.with_lo(list.span.hi()), + DiffActivity::all_activities(), + ); + return None; + }; + + Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff { + mode, + width, + input_activity: activities, + ret_activity, + })))) + } +} diff --git a/compiler/rustc_attr_parsing/src/attributes/mod.rs b/compiler/rustc_attr_parsing/src/attributes/mod.rs index 8ee453d7f464..223c88972d75 100644 --- a/compiler/rustc_attr_parsing/src/attributes/mod.rs +++ b/compiler/rustc_attr_parsing/src/attributes/mod.rs @@ -30,6 +30,7 @@ mod prelude; pub(crate) mod allow_unstable; +pub(crate) mod autodiff; pub(crate) mod body; pub(crate) mod cfg; pub(crate) mod cfg_select; diff --git a/compiler/rustc_attr_parsing/src/context.rs b/compiler/rustc_attr_parsing/src/context.rs index b82607e7c450..802ee56f504b 100644 --- a/compiler/rustc_attr_parsing/src/context.rs +++ b/compiler/rustc_attr_parsing/src/context.rs @@ -19,6 +19,7 @@ use crate::AttributeParser; // Glob imports to avoid big, bitrotty import lists use crate::attributes::allow_unstable::*; +use crate::attributes::autodiff::*; use crate::attributes::body::*; use crate::attributes::cfi_encoding::*; use crate::attributes::codegen_attrs::*; @@ -204,6 +205,7 @@ mod late { Single, Single, Single, + Single, Single, Single, Single, diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 854fbbcea3ee..6a7ee711ff8a 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -6,7 +6,6 @@ Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange, }; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; -use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -15,6 +14,7 @@ use rustc_data_structures::assert_matches; use rustc_hir as hir; use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::find_attr; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; use rustc_middle::ty::offload_meta::OffloadMetadata; @@ -1367,7 +1367,9 @@ fn codegen_autodiff<'ll, 'tcx>( let val_arr = get_args_from_tuple(bx, args[2], fn_diff); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else { + let Some(Some(mut diff_attrs)) = + find_attr!(tcx, fn_diff.def_id(), RustcAutodiff(attr) => attr.clone()) + else { bug!("could not find autodiff attrs") }; @@ -1389,7 +1391,7 @@ fn codegen_autodiff<'ll, 'tcx>( &diff_symbol, llret_ty, &val_arr, - diff_attrs.clone(), + &diff_attrs, result, fnc_tree, ); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 43f039cc5ebf..1ceb01337b11 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,8 +1,4 @@ -use std::str::FromStr; - use rustc_abi::{Align, ExternAbi}; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; -use rustc_ast::{LitKind, MetaItem, MetaItemInner}; use rustc_hir::attrs::{ AttributeKind, EiiImplResolution, InlineAttr, Linkage, RtsanSetting, UsedBy, }; @@ -14,7 +10,6 @@ }; use rustc_middle::mir::mono::Visibility; use rustc_middle::query::Providers; -use rustc_middle::span_bug; use rustc_middle::ty::{self as ty, TyCtxt}; use rustc_session::lint; use rustc_session::parse::feature_err; @@ -614,116 +609,6 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option { tcx.codegen_fn_attrs(tcx.trait_item_of(def_id)?).alignment } -/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)] -/// macros. There are two forms. The pure one without args to mark primal functions (the functions -/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the -/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never -/// panic, unless we introduced a bug when parsing the autodiff macro. -//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed. -pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { - #[allow(deprecated)] - let attrs = tcx.get_attrs(id, sym::rustc_autodiff); - - let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::>(); - - // check for exactly one autodiff attribute on placeholder functions. - // There should only be one, since we generate a new placeholder per ad macro. - let attr = match &attrs[..] { - [] => return None, - [attr] => attr, - _ => { - span_bug!(attrs[1].span(), "cg_ssa: rustc_autodiff should only exist once per source"); - } - }; - - let list = attr.meta_item_list().unwrap_or_default(); - - // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions - if list.is_empty() { - return Some(AutoDiffAttrs::source()); - } - - let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities"); - }; - let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode { - p1.segments.first().unwrap().ident - } else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain mode"); - }; - - // parse mode - let mode = match mode.as_str() { - "Forward" => DiffMode::Forward, - "Reverse" => DiffMode::Reverse, - _ => { - span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode"); - } - }; - - let width: u32 = match width_meta { - MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => { - let w = p1.segments.first().unwrap().ident; - match w.as_str().parse() { - Ok(val) => val, - Err(_) => { - span_bug!(w.span, "rustc_autodiff width should fit u32"); - } - } - } - MetaItemInner::Lit(lit) => { - if let LitKind::Int(val, _) = lit.kind { - match val.get().try_into() { - Ok(val) => val, - Err(_) => { - span_bug!(lit.span, "rustc_autodiff width should fit u32"); - } - } - } else { - span_bug!(lit.span, "rustc_autodiff width should be an integer"); - } - } - }; - - // First read the ret symbol from the attribute - let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain the return activity"); - }; - let ret_symbol = p1.segments.first().unwrap().ident; - - // Then parse it into an actual DiffActivity - let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else { - span_bug!(ret_symbol.span, "invalid return activity"); - }; - - // Now parse all the intermediate (input) activities - let mut arg_activities: Vec = vec![]; - for arg in input_activities { - let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p2, .. }) = arg { - match p2.segments.first() { - Some(x) => x.ident, - None => { - span_bug!( - arg.span(), - "rustc_autodiff attribute must contain the input activity" - ); - } - } - } else { - span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity"); - }; - - match DiffActivity::from_str(arg_symbol.as_str()) { - Ok(arg_activity) => arg_activities.push(arg_activity), - Err(_) => { - span_bug!(arg_symbol.span, "invalid input activity"); - } - } - } - - Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities }) -} - pub(crate) fn provide(providers: &mut Providers) { *providers = Providers { codegen_fn_attrs, diff --git a/compiler/rustc_hir/src/attrs/data_structures.rs b/compiler/rustc_hir/src/attrs/data_structures.rs index 89b064be3447..91409108a753 100644 --- a/compiler/rustc_hir/src/attrs/data_structures.rs +++ b/compiler/rustc_hir/src/attrs/data_structures.rs @@ -1286,6 +1286,9 @@ pub enum AttributeKind { /// Represents `#[rustc_as_ptr]` (used by the `dangling_pointers_from_temporaries` lint). RustcAsPtr(Span), + /// Represents `#[rustc_autodiff]`. + RustcAutodiff(Option>), + /// Represents `#[rustc_default_body_unstable]`. RustcBodyStability { stability: DefaultBodyStability, diff --git a/compiler/rustc_hir/src/attrs/encode_cross_crate.rs b/compiler/rustc_hir/src/attrs/encode_cross_crate.rs index c50d38b6d673..cd41a2b9b28c 100644 --- a/compiler/rustc_hir/src/attrs/encode_cross_crate.rs +++ b/compiler/rustc_hir/src/attrs/encode_cross_crate.rs @@ -102,6 +102,7 @@ pub fn encode_cross_crate(&self) -> EncodeCrossCrate { RustcAllowConstFnUnstable(..) => No, RustcAllowIncoherentImpl(..) => No, RustcAsPtr(..) => Yes, + RustcAutodiff(..) => Yes, RustcBodyStability { .. } => No, RustcBuiltinMacro { .. } => Yes, RustcCaptureAnalysis => No, diff --git a/compiler/rustc_hir/src/attrs/pretty_printing.rs b/compiler/rustc_hir/src/attrs/pretty_printing.rs index 2767d4f95054..9d14f9de3078 100644 --- a/compiler/rustc_hir/src/attrs/pretty_printing.rs +++ b/compiler/rustc_hir/src/attrs/pretty_printing.rs @@ -6,6 +6,7 @@ use rustc_ast::ast::{Path, join_path_idents}; use rustc_ast::attr::data_structures::CfgEntry; use rustc_ast::attr::version::RustcVersion; +use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode}; use rustc_ast::token::{CommentKind, DocFragmentKind}; use rustc_ast::{AttrId, AttrStyle, IntTy, UintTy}; use rustc_ast_pretty::pp::Printer; diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs index 7435fbe8d38a..19ffffdd1eca 100644 --- a/compiler/rustc_mir_transform/src/cross_crate_inline.rs +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -8,7 +8,6 @@ use rustc_middle::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_session::config::{InliningThreshold, OptLevel}; -use rustc_span::sym; use crate::{inline, pass_manager as pm}; @@ -37,11 +36,7 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { } // FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880 - #[allow(deprecated)] - if tcx.has_attr(def_id, sym::autodiff_forward) - || tcx.has_attr(def_id, sym::autodiff_reverse) - || tcx.has_attr(def_id, sym::rustc_autodiff) - { + if find_attr!(tcx, def_id, RustcAutodiff(..)) { return true; } diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index 89bd39b77e64..7663a69cc430 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -53,7 +53,6 @@ use rustc_trait_selection::error_reporting::InferCtxtErrorExt; use rustc_trait_selection::infer::{TyCtxtInferExt, ValuePairs}; use rustc_trait_selection::traits::ObligationCtxt; -use tracing::debug; use crate::errors; @@ -299,6 +298,7 @@ fn check_attributes( | AttributeKind::RustcAllocatorZeroed | AttributeKind::RustcAllocatorZeroedVariant { .. } | AttributeKind::RustcAsPtr(..) + | AttributeKind::RustcAutodiff(..) | AttributeKind::RustcBodyStability { .. } | AttributeKind::RustcBuiltinMacro { .. } | AttributeKind::RustcCaptureAnalysis @@ -390,9 +390,6 @@ fn check_attributes( Attribute::Unparsed(attr_item) => { style = Some(attr_item.style); match attr.path().as_slice() { - [sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => { - self.check_autodiff(hir_id, attr, span, target) - } [ // ok sym::allow @@ -402,8 +399,6 @@ fn check_attributes( | sym::forbid // internal | sym::rustc_on_unimplemented - | sym::rustc_layout - | sym::rustc_autodiff // crate-level attrs, are checked below | sym::feature, .. @@ -1863,18 +1858,6 @@ fn check_mix_no_mangle_export(&self, hir_id: HirId, attrs: &[Attribute]) { } } - /// Checks if `#[autodiff]` is applied to an item other than a function item. - fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) { - debug!("check_autodiff"); - match target { - Target::Fn => {} - _ => { - self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span }); - self.abort.set(true); - } - } - } - fn check_loop_match(&self, hir_id: HirId, attr_span: Span, target: Target) { let node_span = self.tcx.hir_span(hir_id); diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs index b9ada150d030..0cf0d1a5c80f 100644 --- a/compiler/rustc_passes/src/errors.rs +++ b/compiler/rustc_passes/src/errors.rs @@ -19,14 +19,6 @@ #[diag("`#[diagnostic::do_not_recommend]` can only be placed on trait implementations")] pub(crate) struct IncorrectDoNotRecommendLocation; -#[derive(Diagnostic)] -#[diag("`#[autodiff]` should be applied to a function")] -pub(crate) struct AutoDiffAttr { - #[primary_span] - #[label("not a function")] - pub attr_span: Span, -} - #[derive(Diagnostic)] #[diag("`#[loop_match]` should be applied to a loop")] pub(crate) struct LoopMatchAttr { diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 422a15b060cc..731a83853072 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -157,6 +157,8 @@ Abi, AcqRel, Acquire, + Active, + ActiveOnly, Alignment, Arc, ArcWeak, @@ -213,6 +215,12 @@ Deref, DispatchFromDyn, Display, + Dual, + DualOnly, + Dualv, + DualvOnly, + Duplicated, + DuplicatedOnly, DynTrait, Enum, Eq, @@ -310,6 +318,7 @@ Slice, SliceIndex, Some, + Source, SpanCtxt, Str, String,