Port rustc_autodiff to the attribute parsers

This commit is contained in:
Jonathan Brouwer
2026-02-22 18:14:05 +01:00
parent 90c93ab7c1
commit 5cd5b90a38
12 changed files with 141 additions and 150 deletions
@@ -0,0 +1,117 @@
use std::str::FromStr;
use rustc_ast::LitKind;
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
use rustc_feature::{AttributeTemplate, template};
use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
use rustc_hir::{MethodKind, Target};
use rustc_span::{Symbol, sym};
use thin_vec::ThinVec;
use crate::attributes::prelude::Allow;
use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser};
use crate::context::{AcceptContext, Stage};
use crate::parser::{ArgParser, MetaItemOrLitParser};
use crate::target_checking::AllowedTargets;
pub(crate) struct RustcAutodiffParser;
impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
const PATH: &[Symbol] = &[sym::rustc_autodiff];
const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost;
const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
Allow(Target::Fn),
Allow(Target::Method(MethodKind::Inherent)),
Allow(Target::Method(MethodKind::Trait { body: true })),
Allow(Target::Method(MethodKind::TraitImpl)),
]);
const TEMPLATE: AttributeTemplate = template!(
List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
"https://doc.rust-lang.org/std/autodiff/index.html"
);
fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
let list = match args {
ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
ArgParser::List(list) => list,
ArgParser::NameValue(_) => {
cx.expected_list_or_no_args(cx.attr_span);
return None;
}
};
let mut items = list.mixed().peekable();
// Parse name
let Some(mode) = items.next() else {
cx.expected_at_least_one_argument(list.span);
return None;
};
let Some(mode) = mode.meta_item() else {
cx.expected_identifier(mode.span());
return None;
};
let Ok(()) = mode.args().no_args() else {
cx.expected_identifier(mode.span());
return None;
};
let Some(mode) = mode.path().word() else {
cx.expected_identifier(mode.span());
return None;
};
let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
cx.expected_specific_argument(mode.span, DiffMode::all_modes());
return None;
};
// Parse width
let width = if let Some(width) = items.peek()
&& let MetaItemOrLitParser::Lit(width) = width
&& let LitKind::Int(width, _) = width.kind
&& let Ok(width) = width.0.try_into()
{
_ = items.next();
width
} else {
1
};
// Parse activities
let mut activities = ThinVec::new();
for activity in items {
let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Ok(()) = activity.args().no_args() else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Some(activity) = activity.path().word() else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
cx.expected_specific_argument(activity.span, DiffActivity::all_activities());
return None;
};
activities.push(activity);
}
let Some(ret_activity) = activities.pop() else {
cx.expected_specific_argument(
list.span.with_lo(list.span.hi()),
DiffActivity::all_activities(),
);
return None;
};
Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
mode,
width,
input_activity: activities,
ret_activity,
}))))
}
}
@@ -30,6 +30,7 @@
mod prelude;
pub(crate) mod allow_unstable;
pub(crate) mod autodiff;
pub(crate) mod body;
pub(crate) mod cfg;
pub(crate) mod cfg_select;
@@ -19,6 +19,7 @@
use crate::AttributeParser;
// Glob imports to avoid big, bitrotty import lists
use crate::attributes::allow_unstable::*;
use crate::attributes::autodiff::*;
use crate::attributes::body::*;
use crate::attributes::cfi_encoding::*;
use crate::attributes::codegen_attrs::*;
@@ -204,6 +205,7 @@ mod late {
Single<ReexportTestHarnessMainParser>,
Single<RustcAbiParser>,
Single<RustcAllocatorZeroedVariantParser>,
Single<RustcAutodiffParser>,
Single<RustcBuiltinMacroParser>,
Single<RustcDefPathParser>,
Single<RustcDeprecatedSafe2024Parser>,
+5 -3
View File
@@ -6,7 +6,6 @@
Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange,
};
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -15,6 +14,7 @@
use rustc_data_structures::assert_matches;
use rustc_hir as hir;
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_hir::find_attr;
use rustc_middle::mir::BinOp;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
use rustc_middle::ty::offload_meta::OffloadMetadata;
@@ -1367,7 +1367,9 @@ fn codegen_autodiff<'ll, 'tcx>(
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
let Some(Some(mut diff_attrs)) =
find_attr!(tcx, fn_diff.def_id(), RustcAutodiff(attr) => attr.clone())
else {
bug!("could not find autodiff attrs")
};
@@ -1389,7 +1391,7 @@ fn codegen_autodiff<'ll, 'tcx>(
&diff_symbol,
llret_ty,
&val_arr,
diff_attrs.clone(),
&diff_attrs,
result,
fnc_tree,
);
@@ -1,8 +1,4 @@
use std::str::FromStr;
use rustc_abi::{Align, ExternAbi};
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::{LitKind, MetaItem, MetaItemInner};
use rustc_hir::attrs::{
AttributeKind, EiiImplResolution, InlineAttr, Linkage, RtsanSetting, UsedBy,
};
@@ -14,7 +10,6 @@
};
use rustc_middle::mir::mono::Visibility;
use rustc_middle::query::Providers;
use rustc_middle::span_bug;
use rustc_middle::ty::{self as ty, TyCtxt};
use rustc_session::lint;
use rustc_session::parse::feature_err;
@@ -614,116 +609,6 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Align> {
tcx.codegen_fn_attrs(tcx.trait_item_of(def_id)?).alignment
}
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed.
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
#[allow(deprecated)]
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();
// check for exactly one autodiff attribute on placeholder functions.
// There should only be one, since we generate a new placeholder per ad macro.
let attr = match &attrs[..] {
[] => return None,
[attr] => attr,
_ => {
span_bug!(attrs[1].span(), "cg_ssa: rustc_autodiff should only exist once per source");
}
};
let list = attr.meta_item_list().unwrap_or_default();
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
if list.is_empty() {
return Some(AutoDiffAttrs::source());
}
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
};
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
p1.segments.first().unwrap().ident
} else {
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode");
};
// parse mode
let mode = match mode.as_str() {
"Forward" => DiffMode::Forward,
"Reverse" => DiffMode::Reverse,
_ => {
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
}
};
let width: u32 = match width_meta {
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
let w = p1.segments.first().unwrap().ident;
match w.as_str().parse() {
Ok(val) => val,
Err(_) => {
span_bug!(w.span, "rustc_autodiff width should fit u32");
}
}
}
MetaItemInner::Lit(lit) => {
if let LitKind::Int(val, _) = lit.kind {
match val.get().try_into() {
Ok(val) => val,
Err(_) => {
span_bug!(lit.span, "rustc_autodiff width should fit u32");
}
}
} else {
span_bug!(lit.span, "rustc_autodiff width should be an integer");
}
}
};
// First read the ret symbol from the attribute
let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity else {
span_bug!(attr.span(), "rustc_autodiff attribute must contain the return activity");
};
let ret_symbol = p1.segments.first().unwrap().ident;
// Then parse it into an actual DiffActivity
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
span_bug!(ret_symbol.span, "invalid return activity");
};
// Now parse all the intermediate (input) activities
let mut arg_activities: Vec<DiffActivity> = vec![];
for arg in input_activities {
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p2, .. }) = arg {
match p2.segments.first() {
Some(x) => x.ident,
None => {
span_bug!(
arg.span(),
"rustc_autodiff attribute must contain the input activity"
);
}
}
} else {
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
};
match DiffActivity::from_str(arg_symbol.as_str()) {
Ok(arg_activity) => arg_activities.push(arg_activity),
Err(_) => {
span_bug!(arg_symbol.span, "invalid input activity");
}
}
}
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
}
pub(crate) fn provide(providers: &mut Providers) {
*providers = Providers {
codegen_fn_attrs,
@@ -1286,6 +1286,9 @@ pub enum AttributeKind {
/// Represents `#[rustc_as_ptr]` (used by the `dangling_pointers_from_temporaries` lint).
RustcAsPtr(Span),
/// Represents `#[rustc_autodiff]`.
RustcAutodiff(Option<Box<RustcAutodiff>>),
/// Represents `#[rustc_default_body_unstable]`.
RustcBodyStability {
stability: DefaultBodyStability,
@@ -102,6 +102,7 @@ pub fn encode_cross_crate(&self) -> EncodeCrossCrate {
RustcAllowConstFnUnstable(..) => No,
RustcAllowIncoherentImpl(..) => No,
RustcAsPtr(..) => Yes,
RustcAutodiff(..) => Yes,
RustcBodyStability { .. } => No,
RustcBuiltinMacro { .. } => Yes,
RustcCaptureAnalysis => No,
@@ -6,6 +6,7 @@
use rustc_ast::ast::{Path, join_path_idents};
use rustc_ast::attr::data_structures::CfgEntry;
use rustc_ast::attr::version::RustcVersion;
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
use rustc_ast::token::{CommentKind, DocFragmentKind};
use rustc_ast::{AttrId, AttrStyle, IntTy, UintTy};
use rustc_ast_pretty::pp::Printer;
@@ -8,7 +8,6 @@
use rustc_middle::query::Providers;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::{InliningThreshold, OptLevel};
use rustc_span::sym;
use crate::{inline, pass_manager as pm};
@@ -37,11 +36,7 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
}
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
#[allow(deprecated)]
if tcx.has_attr(def_id, sym::autodiff_forward)
|| tcx.has_attr(def_id, sym::autodiff_reverse)
|| tcx.has_attr(def_id, sym::rustc_autodiff)
{
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
return true;
}
+1 -18
View File
@@ -53,7 +53,6 @@
use rustc_trait_selection::error_reporting::InferCtxtErrorExt;
use rustc_trait_selection::infer::{TyCtxtInferExt, ValuePairs};
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::debug;
use crate::errors;
@@ -299,6 +298,7 @@ fn check_attributes(
| AttributeKind::RustcAllocatorZeroed
| AttributeKind::RustcAllocatorZeroedVariant { .. }
| AttributeKind::RustcAsPtr(..)
| AttributeKind::RustcAutodiff(..)
| AttributeKind::RustcBodyStability { .. }
| AttributeKind::RustcBuiltinMacro { .. }
| AttributeKind::RustcCaptureAnalysis
@@ -390,9 +390,6 @@ fn check_attributes(
Attribute::Unparsed(attr_item) => {
style = Some(attr_item.style);
match attr.path().as_slice() {
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
self.check_autodiff(hir_id, attr, span, target)
}
[
// ok
sym::allow
@@ -402,8 +399,6 @@ fn check_attributes(
| sym::forbid
// internal
| sym::rustc_on_unimplemented
| sym::rustc_layout
| sym::rustc_autodiff
// crate-level attrs, are checked below
| sym::feature,
..
@@ -1863,18 +1858,6 @@ fn check_mix_no_mangle_export(&self, hir_id: HirId, attrs: &[Attribute]) {
}
}
/// Checks if `#[autodiff]` is applied to an item other than a function item.
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
debug!("check_autodiff");
match target {
Target::Fn => {}
_ => {
self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span });
self.abort.set(true);
}
}
}
fn check_loop_match(&self, hir_id: HirId, attr_span: Span, target: Target) {
let node_span = self.tcx.hir_span(hir_id);
-8
View File
@@ -19,14 +19,6 @@
#[diag("`#[diagnostic::do_not_recommend]` can only be placed on trait implementations")]
pub(crate) struct IncorrectDoNotRecommendLocation;
#[derive(Diagnostic)]
#[diag("`#[autodiff]` should be applied to a function")]
pub(crate) struct AutoDiffAttr {
#[primary_span]
#[label("not a function")]
pub attr_span: Span,
}
#[derive(Diagnostic)]
#[diag("`#[loop_match]` should be applied to a loop")]
pub(crate) struct LoopMatchAttr {
+9
View File
@@ -157,6 +157,8 @@
Abi,
AcqRel,
Acquire,
Active,
ActiveOnly,
Alignment,
Arc,
ArcWeak,
@@ -213,6 +215,12 @@
Deref,
DispatchFromDyn,
Display,
Dual,
DualOnly,
Dualv,
DualvOnly,
Duplicated,
DuplicatedOnly,
DynTrait,
Enum,
Eq,
@@ -310,6 +318,7 @@
Slice,
SliceIndex,
Some,
Source,
SpanCtxt,
Str,
String,